From 8bdbfca68287232e5bf5793145f987569ecd312e Mon Sep 17 00:00:00 2001 From: Junkai-Wu Date: Fri, 6 Jun 2025 14:39:20 +0800 Subject: [PATCH] v4.0 update. (#2371) --- CHANGELOG.md | 62 +- CMakeLists.txt | 10 +- Doxyfile | 4 +- README.md | 61 +- ..._gmma_ss_warpspecialized_with_prefetch.hpp | 17 +- .../65_distributed_gemm.cu | 22 +- ...zed_grouped_gemm_with_blockwise_scaling.cu | 10 +- ...th_blockwise_scaling_with_sparse_groups.cu | 10 +- .../hopper_fp8_commandline.hpp | 13 +- .../blackwell_gemm_streamk.cu | 3 +- .../75_blackwell_grouped_gemm.cu | 2 +- .../75_blackwell_grouped_gemm_block_scaled.cu | 2 +- .../76_blackwell_conv_dgrad.cu | 6 +- .../76_blackwell_conv_fprop.cu | 6 +- .../76_blackwell_conv_wgrad.cu | 2 +- .../77_blackwell_fmha/77_blackwell_fmha.cu | 221 ++- .../77_blackwell_fmha/77_blackwell_mla.cu | 10 +- examples/77_blackwell_fmha/CMakeLists.txt | 19 + .../collective/fmha_fusion.hpp | 4 +- ..._fmha_fwd_epilogue_tma_warpspecialized.hpp | 14 +- ..._fmha_fwd_mainloop_tma_warpspecialized.hpp | 59 +- ...00_fmha_bwd_kernel_tma_warpspecialized.hpp | 112 +- ...00_fmha_fwd_kernel_tma_warpspecialized.hpp | 5 +- .../kernel/sm100_fmha_mla_reduction.hpp | 4 +- .../reference/fmha_fwd_reference.hpp | 2 +- .../reference/reference_abs_error.hpp | 2 + ...9d_blackwell_geforce_nvfp4_grouped_gemm.cu | 2 +- .../82_blackwell_distributed_gemm.cu | 22 +- .../REQUIREMENTS.md | 4 +- examples/88_hopper_fmha/88_hopper_fmha.cu | 1192 +++++++++++++ examples/88_hopper_fmha/CMakeLists.txt | 50 + examples/88_hopper_fmha/README.md | 77 + ...mha_collective_bwd_tma_warpspecialized.hpp | 863 ++++++++++ .../collective/fmha_collective_load.hpp | 140 ++ .../collective/fmha_collective_softmax.hpp | 305 ++++ .../collective/fmha_collective_tma.hpp | 526 ++++++ .../fmha_collective_tma_warpspecialized.hpp | 560 +++++++ .../88_hopper_fmha/collective/fmha_common.hpp | 245 +++ .../collective/fmha_epilogue.hpp | 156 ++ .../collective/fmha_epilogue_bwd.hpp | 157 ++ .../88_hopper_fmha/collective/fmha_fusion.hpp | 283 ++++ .../device/device_universal.hpp | 278 +++ .../88_hopper_fmha/device/fmha_device_bwd.hpp | 299 ++++ .../kernel/fmha_kernel_builder.hpp | 158 ++ .../kernel/fmha_kernel_bwd_convert.hpp | 143 ++ .../kernel/fmha_kernel_bwd_sum_OdO.hpp | 134 ++ .../88_hopper_fmha/kernel/fmha_kernel_tma.hpp | 222 +++ .../fmha_kernel_tma_warpspecialized.hpp | 418 +++++ .../88_hopper_fmha/kernel/fmha_options.hpp | 81 +- .../kernel/fmha_tile_scheduler.hpp | 204 +++ .../reference/fmha_bwd_reference.hpp | 357 ++++ .../reference/fmha_reference.hpp | 156 ++ .../reference/reference_abs_error.hpp | 129 ++ examples/CMakeLists.txt | 1 + examples/cute/tutorial/CMakeLists.txt | 4 + examples/cute/tutorial/tiled_copy_if.cu | 297 ++++ .../python/CuTeDSL/ampere/smem_allocator.py | 200 +++ .../python/CuTeDSL/cute/ffi/CMakeLists.txt | 51 + .../python/CuTeDSL/cute/ffi/jit_argument.py | 305 ++++ examples/python/CuTeDSL/cute/ffi/tensor.cpp | 82 + examples/python/CuTeDSL/hopper/dense_gemm.py | 1486 +++++++++++++++++ .../CuTeDSL/notebooks/hello_world.ipynb | 11 +- include/cute/algorithm/axpby.hpp | 5 +- include/cute/algorithm/cooperative_copy.hpp | 1 - include/cute/algorithm/copy.hpp | 120 +- include/cute/algorithm/prefetch.hpp | 11 +- include/cute/algorithm/tensor_algorithms.hpp | 12 + include/cute/algorithm/tensor_reduce.hpp | 107 ++ include/cute/atom/copy_atom.hpp | 54 +- include/cute/pointer_base.hpp | 69 + .../cutlass/arch/grid_dependency_control.h | 3 +- ...100_implicit_gemm_umma_warpspecialized.hpp | 1 - ..._implicit_gemm_gmma_ss_warpspecialized.hpp | 17 +- .../conv/device/conv_universal_adapter.hpp | 25 + .../cutlass/detail/blockwise_scale_layout.hpp | 19 +- include/cutlass/detail/layout.hpp | 8 +- .../collective/builders/sm100_builder.inl | 3 + .../collective/builders/sm90_builder.inl | 3 + .../collective/builders/sm90_common.inl | 8 + .../cutlass/epilogue/collective/detail.hpp | 10 + .../collective/sm100_epilogue_nosmem.hpp | 42 +- .../collective/sm70_epilogue_vectorized.hpp | 12 +- ...90_visitor_compute_tma_warpspecialized.hpp | 20 +- .../sm90_visitor_load_tma_warpspecialized.hpp | 70 +- ...sm90_visitor_store_tma_warpspecialized.hpp | 10 +- include/cutlass/epilogue/thread/activation.h | 17 + .../threadblock/fusion/visitor_load.hpp | 8 +- .../device/dist_gemm_universal_wrapper.hpp | 2 +- .../distributed/device/full_barrier.hpp | 2 +- .../sm100_blockscaled_sparse_umma_builder.inl | 2 +- .../sm100_blockscaled_umma_builder.inl | 2 +- .../builders/sm100_pipeline_carveout.inl | 6 + .../builders/sm100_sparse_umma_builder.inl | 2 +- .../builders/sm100_umma_builder.inl | 2 +- .../builders/sm120_blockwise_mma_builder.inl | 264 +++ .../collective/builders/sm120_mma_builder.inl | 2 + .../gemm/collective/collective_builder.hpp | 1 + .../gemm/collective/collective_mma.hpp | 2 + ..._blockscaled_mma_array_warpspecialized.hpp | 190 ++- .../sm100_blockscaled_mma_warpspecialized.hpp | 7 +- ...blockscaled_sparse_mma_warpspecialized.hpp | 5 +- .../sm100_mma_array_warpspecialized.hpp | 86 +- ...rray_warpspecialized_blockwise_scaling.hpp | 99 +- ...100_mma_array_warpspecialized_emulated.hpp | 21 +- .../collective/sm100_mma_warpspecialized.hpp | 1 - ..._mma_warpspecialized_blockwise_scaling.hpp | 97 +- .../sm100_mma_warpspecialized_emulated.hpp | 17 +- .../sm100_sparse_mma_warpspecialized.hpp | 1 - .../sm120_blockscaled_mma_array_tma.hpp | 17 +- .../collective/sm120_blockscaled_mma_tma.hpp | 9 +- .../sm120_blockscaled_sparse_mma_tma.hpp | 8 +- .../sm120_mma_array_tma_blockwise_scaling.hpp | 1001 +++++++++++ .../cutlass/gemm/collective/sm120_mma_tma.hpp | 9 +- .../sm120_mma_tma_blockwise_scaling.hpp | 779 +++++++++ .../gemm/collective/sm120_sparse_mma_tma.hpp | 36 +- .../gemm/collective/sm70_mma_twostage.hpp | 1 - .../gemm/collective/sm80_mma_multistage.hpp | 7 +- ...ma_gmma_rs_warpspecialized_mixed_input.hpp | 143 +- ..._mma_array_tma_gmma_ss_warpspecialized.hpp | 17 +- ..._array_tma_gmma_ss_warpspecialized_fp8.hpp | 11 +- ..._warpspecialized_fp8_blockwise_scaling.hpp | 63 +- ...mma_multistage_gmma_rs_warpspecialized.hpp | 37 +- ...mma_multistage_gmma_ss_warpspecialized.hpp | 21 +- .../sm90_mma_tma_gmma_rs_warpspecialized.hpp | 49 +- ...ma_gmma_rs_warpspecialized_mixed_input.hpp | 115 +- .../gemm/collective/sm90_mma_tma_gmma_ss.hpp | 11 +- .../sm90_mma_tma_gmma_ss_warpspecialized.hpp | 17 +- ...90_mma_tma_gmma_ss_warpspecialized_fp8.hpp | 11 +- ..._warpspecialized_fp8_blockwise_scaling.hpp | 35 +- ...sparse_mma_tma_gmma_ss_warpspecialized.hpp | 21 +- ...se_mma_tma_gmma_ss_warpspecialized_fp8.hpp | 21 +- include/cutlass/gemm/dispatch_policy.hpp | 84 +- .../gemm/group_array_problem_shape.hpp | 4 +- .../sm100_gemm_array_tma_warpspecialized.hpp | 141 +- ...rray_tma_warpspecialized_mma_transform.hpp | 7 +- .../kernel/sm100_gemm_tma_warpspecialized.hpp | 36 + .../sm100_sparse_gemm_tma_warpspecialized.hpp | 36 + .../kernel/sm100_tile_scheduler_group.hpp | 15 +- ...0_gemm_tma_warpspecialized_cooperative.hpp | 3 +- ...sm90_gemm_tma_warpspecialized_pingpong.hpp | 3 +- .../gemm/kernel/sm90_tile_scheduler_group.hpp | 39 +- .../kernel/sm90_tile_scheduler_stream_k.hpp | 1 - .../gemm/kernel/tile_scheduler_params.h | 48 +- include/cutlass/numeric_conversion.h | 118 +- .../cpp/blackwell_cluster_launch_control.md | 6 +- media/docs/cpp/cute/02_layout_algebra.md | 4 +- media/docs/cpp/cute/0y_predication.md | 328 ++-- media/docs/pythonDSL/cute_dsl.rst | 12 +- .../cute_dsl_general/autotuning_gemm.rst | 4 - .../pythonDSL/cute_dsl_general/debugging.rst | 4 - .../cute_dsl_general/dsl_code_generation.rst | 4 - .../cute_dsl_general/dsl_control_flow.rst | 5 +- .../cute_dsl_general/dsl_dynamic_layout.rst | 4 - .../cute_dsl_general/dsl_introduction.rst | 5 +- .../dsl_jit_arg_generation.rst | 7 +- .../cute_dsl_general/dsl_jit_caching.rst | 6 +- .../framework_integration.rst | 16 +- media/docs/pythonDSL/faqs.rst | 3 +- media/docs/pythonDSL/limitations.rst | 3 - media/docs/pythonDSL/overview.rst | 2 +- media/docs/pythonDSL/quick_start.rst | 9 + .../CuTeDSL/base_dsl/_mlir_helpers/arith.py | 2 +- python/CuTeDSL/base_dsl/ast_helpers.py | 8 +- python/CuTeDSL/base_dsl/ast_preprocessor.py | 105 +- python/CuTeDSL/base_dsl/dsl.py | 22 +- python/CuTeDSL/base_dsl/jit_executor.py | 51 +- python/CuTeDSL/base_dsl/runtime/cuda.py | 2 +- python/CuTeDSL/base_dsl/typing.py | 45 +- python/CuTeDSL/cutlass/cute/core.py | 240 ++- python/CuTeDSL/cutlass/cute/nvgpu/common.py | 5 + .../CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py | 27 +- python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py | 9 +- .../cutlass/cute/nvgpu/warpgroup/mma.py | 27 +- python/CuTeDSL/cutlass/cute/runtime.py | 2 +- python/CuTeDSL/cutlass/cute/testing.py | 15 +- python/CuTeDSL/cutlass/cute/typing.py | 24 +- python/CuTeDSL/cutlass/utils/pipeline.py | 99 +- .../CuTeDSL/cutlass/utils/smem_allocator.py | 2 +- python/CuTeDSL/cutlass_dsl/cutlass.py | 31 +- .../cutlass_dsl/cutlass_ast_decorators.py | 6 +- python/cutlass_library/__init__.py | 3 +- python/cutlass_library/emit_kernel_listing.py | 87 +- python/cutlass_library/gemm_operation.py | 10 +- python/cutlass_library/generator.py | 314 +++- python/cutlass_library/library.py | 43 +- python/cutlass_library/manifest.py | 29 +- test/self_contained_includes/CMakeLists.txt | 1 - test/unit/conv/device_3x/dgrad/CMakeLists.txt | 42 + ..._implicit_gemm_f16_f16_f16_tensorop_f16.cu | 338 ++++ ...mm_f16_f16_f16_tensorop_f16_with_fusion.cu | 190 +++ ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 338 ++++ ..._implicit_gemm_f16_f16_f16_tensorop_f16.cu | 338 ++++ ...mm_f16_f16_f16_tensorop_f16_with_fusion.cu | 190 +++ ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 338 ++++ ...d_implicit_gemm_f8_f8_bf16_tensorop_f32.cu | 338 ++++ ...ad_implicit_gemm_f8_f8_f16_tensorop_f32.cu | 338 ++++ ...gemm_f8_f8_f16_tensorop_f32_with_fusion.cu | 237 +++ ...ad_implicit_gemm_f8_f8_f32_tensorop_f32.cu | 338 ++++ ...rad_implicit_gemm_f8_f8_f8_tensorop_f32.cu | 338 ++++ ..._implicit_gemm_f16_f16_f16_tensorop_f16.cu | 338 ++++ ...mm_f16_f16_f16_tensorop_f16_with_fusion.cu | 143 ++ ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 338 ++++ ...d_implicit_gemm_f8_f8_bf16_tensorop_f32.cu | 338 ++++ ...ad_implicit_gemm_f8_f8_f16_tensorop_f32.cu | 338 ++++ ...gemm_f8_f8_f16_tensorop_f32_with_fusion.cu | 237 +++ ...ad_implicit_gemm_f8_f8_f32_tensorop_f32.cu | 338 ++++ ...rad_implicit_gemm_f8_f8_f8_tensorop_f32.cu | 338 ++++ test/unit/conv/device_3x/fprop/CMakeLists.txt | 49 + ..._implicit_gemm_f16_f16_f16_tensorop_f16.cu | 246 +++ ...mm_f16_f16_f16_tensorop_f16_with_fusion.cu | 236 +++ ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 292 ++++ ...op_implicit_gemm_s8_s8_s32_tensorop_s32.cu | 339 ++++ ...gemm_s8_s8_s32_tensorop_s32_with_fusion.cu | 378 +++++ ...mplicit_gemm_tf32_tf32_f32_tensorop_f32.cu | 338 ++++ ..._tf32_tf32_f32_tensorop_f32_with_fusion.cu | 190 +++ ..._implicit_gemm_f16_f16_f16_tensorop_f16.cu | 338 ++++ ...mm_f16_f16_f16_tensorop_f16_with_fusion.cu | 237 +++ ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 338 ++++ ...op_implicit_gemm_s8_s8_s32_tensorop_s32.cu | 339 ++++ ...gemm_s8_s8_s32_tensorop_s32_with_fusion.cu | 378 +++++ ...mplicit_gemm_tf32_tf32_f32_tensorop_f32.cu | 338 ++++ ..._tf32_tf32_f32_tensorop_f32_with_fusion.cu | 190 +++ ..._implicit_gemm_f16_f16_f16_tensorop_f16.cu | 338 ++++ ...mm_f16_f16_f16_tensorop_f16_with_fusion.cu | 331 ++++ ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 338 ++++ ...op_implicit_gemm_s8_s8_s32_tensorop_s32.cu | 338 ++++ ...gemm_s8_s8_s32_tensorop_s32_with_fusion.cu | 473 ++++++ ...mplicit_gemm_tf32_tf32_f32_tensorop_f32.cu | 338 ++++ ..._tf32_tf32_f32_tensorop_f32_with_fusion.cu | 190 +++ test/unit/conv/device_3x/testbed_conv.hpp | 54 +- test/unit/conv/device_3x/wgrad/CMakeLists.txt | 25 + ..._implicit_gemm_f16_f16_f16_tensorop_f16.cu | 338 ++++ ...mm_f16_f16_f16_tensorop_f16_with_fusion.cu | 96 ++ ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 338 ++++ ..._implicit_gemm_f16_f16_f16_tensorop_f16.cu | 338 ++++ ...mm_f16_f16_f16_tensorop_f16_with_fusion.cu | 96 ++ ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 338 ++++ ..._implicit_gemm_f16_f16_f16_tensorop_f16.cu | 338 ++++ ...mm_f16_f16_f16_tensorop_f16_with_fusion.cu | 96 ++ ..._implicit_gemm_f16_f16_f32_tensorop_f32.cu | 326 ++++ test/unit/cute/core/CMakeLists.txt | 2 +- test/unit/cute/core/tensor_algs.cpp | 200 +++ test/unit/cute/core/transform.cpp | 49 - ...0_gemm_i8_i8_i8_tensor_op_s32_ptr_array.cu | 4 +- test/unit/nvrtc/thread/nvrtc_contraction.cu | 5 +- .../library/include/cutlass/library/library.h | 5 + .../library/src/grouped_gemm_operation_3x.hpp | 93 +- .../grouped_gemm_operation_profiler.h | 31 +- .../include/cutlass/profiler/options.h | 12 +- .../block_scaled_gemm_operation_profiler.cu | 2 +- .../src/blockwise_gemm_operation_profiler.cu | 2 +- tools/profiler/src/gemm_operation_profiler.cu | 2 +- .../src/grouped_gemm_operation_profiler.cu | 231 ++- tools/profiler/src/options.cu | 16 +- 254 files changed, 29751 insertions(+), 1980 deletions(-) create mode 100644 examples/88_hopper_fmha/88_hopper_fmha.cu create mode 100644 examples/88_hopper_fmha/CMakeLists.txt create mode 100644 examples/88_hopper_fmha/README.md create mode 100644 examples/88_hopper_fmha/collective/fmha_collective_bwd_tma_warpspecialized.hpp create mode 100644 examples/88_hopper_fmha/collective/fmha_collective_load.hpp create mode 100644 examples/88_hopper_fmha/collective/fmha_collective_softmax.hpp create mode 100644 examples/88_hopper_fmha/collective/fmha_collective_tma.hpp create mode 100644 examples/88_hopper_fmha/collective/fmha_collective_tma_warpspecialized.hpp create mode 100644 examples/88_hopper_fmha/collective/fmha_common.hpp create mode 100644 examples/88_hopper_fmha/collective/fmha_epilogue.hpp create mode 100644 examples/88_hopper_fmha/collective/fmha_epilogue_bwd.hpp create mode 100644 examples/88_hopper_fmha/collective/fmha_fusion.hpp create mode 100644 examples/88_hopper_fmha/device/device_universal.hpp create mode 100644 examples/88_hopper_fmha/device/fmha_device_bwd.hpp create mode 100644 examples/88_hopper_fmha/kernel/fmha_kernel_builder.hpp create mode 100644 examples/88_hopper_fmha/kernel/fmha_kernel_bwd_convert.hpp create mode 100644 examples/88_hopper_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp create mode 100644 examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp create mode 100644 examples/88_hopper_fmha/kernel/fmha_kernel_tma_warpspecialized.hpp rename include/cute/tensor_predicate.hpp => examples/88_hopper_fmha/kernel/fmha_options.hpp (63%) create mode 100644 examples/88_hopper_fmha/kernel/fmha_tile_scheduler.hpp create mode 100644 examples/88_hopper_fmha/reference/fmha_bwd_reference.hpp create mode 100644 examples/88_hopper_fmha/reference/fmha_reference.hpp create mode 100644 examples/88_hopper_fmha/reference/reference_abs_error.hpp create mode 100644 examples/cute/tutorial/tiled_copy_if.cu create mode 100644 examples/python/CuTeDSL/ampere/smem_allocator.py create mode 100644 examples/python/CuTeDSL/cute/ffi/CMakeLists.txt create mode 100644 examples/python/CuTeDSL/cute/ffi/jit_argument.py create mode 100644 examples/python/CuTeDSL/cute/ffi/tensor.cpp create mode 100644 examples/python/CuTeDSL/hopper/dense_gemm.py create mode 100644 include/cute/algorithm/tensor_reduce.hpp create mode 100644 include/cutlass/gemm/collective/builders/sm120_blockwise_mma_builder.inl create mode 100644 include/cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp create mode 100644 include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv1d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv1d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv1d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_bf16_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32_with_fusion.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_f32_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_f8_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_bf16_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32_with_fusion.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_f32_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_f8_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32_with_fusion.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32_with_fusion.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32_with_fusion.cu create mode 100644 test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu create mode 100644 test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu create mode 100644 test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu create mode 100644 test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu create mode 100644 test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu create mode 100644 test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu create mode 100644 test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu create mode 100644 test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu create mode 100644 test/unit/cute/core/tensor_algs.cpp delete mode 100644 test/unit/cute/core/transform.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 813a04be..6fae64bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,25 +1,31 @@ # Changelog # CUTLASS 4.x -## [4.0.0](https://github.com/NVIDIA/cutlass/tree/main) (2025-05-09) + +## [4.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.0.0) (2025-06-03) ### CuTe DSL * CuTe DSL, a Python DSL centered around CuTe's abstractions - [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL) - - [DSL quick start](./media/docs/pythonDSL/quick_start.rst) - - [DSL Overview](./media/docs/pythonDSL/overview.rst) + - [DSL quick start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html) + - [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html) * [Overhauled documentation with an new dedicated website](https://docs.nvidia.com/cutlass) * Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels - [Blackwell persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py) - [Blackwell grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py) - [Blackwell fused multi-head attention forward pass](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha.py) + - [Hopper GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm.py) - [Ampere GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py) - [FlashAttention-2 implementation targeting Ampere and Ada class GPUs (SM80, SM86, SM89)](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py) + - [SmemAllocator to facilitate shared memory allocation and management](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/smem_allocator.py) + - [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/jit_argument.py) * [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks) +* API updates + - Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer`` ### CUTLASS C++ * Support [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/) which was introduced in CUDA 12.9 - - 100f, 101f, 120f were added to support Family Specific Architecture Features which allows running the same binary on different chips belonging to the same Family (e.g. sm100) without recompiling. + - 100f, 101f, 120f were added to support Family Specific Architecture Features which allows running the same binary on different chips belonging to the same Family (e.g. sm100) without recompiling. Note 101a is supported since CUTLASS 3.9 * Instruction shapes and redundant accumulation type have been removed from CUTLASS 3.x-style library kernel names to disambiguate kernels and shorten names. - For example: @@ -30,9 +36,25 @@ - Added non-power-of-two tile sizes. - Improved performance for K-major scale factors. - The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell versions. +* Support LSE output in Blackwell FMHA Forward kernel in example 77. +* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support. + - Enable runtime datatype for Blackwell grouped GEMM. Profiler support is also added. + - Enable kernel parameter exploration for Blackwell grouped GEMM - raster_order, swizzle. +* Add [Blackwell SM100 implicit GEMM conv fprop/dgrad/wgrad unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/). +* Add dynamic and preferred cluster support for convolution kernels. +* Support for Blackwell SM120 blockwise dense gemm in cutlass core library, as well as cutlass profiler. +* Fix profiler issues which cause no output or not supported error for some kernels. +* Optimization porting for BlockScaled collectives and kernel layers. +* New [Hopper FMHA example](https://github.com/NVIDIA/cutlass/tree/main/examples/88_hopper_fmha/), similar in design to the existing [Blackwell FMHA](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). +* CuTe changes: + - Rework `cute::copy_if` so that the predicate tensor is also a true CuTe Tensor rather than a lambda and introduces transform-tensors to avoid any extra register or load/store overhead in using bool-tensors. + - New [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/tiled_copy_if.cu) to show the usage of copy_if in tile copy. + - Add [CuTe C++ reduce op](https://github.com/NVIDIA/cutlass/tree/main/include/cute/algorithm/tensor_reduce.hpp). + - Add several [unit tests](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/core/tensor_algs.cpp) for CuTe tensor algorithms. * Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! * Optimal code generation with CUDA toolkit versions 12.9. + # CUTLASS 3.x ## [3.9.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.9.2) (2025-05-03) @@ -82,7 +104,7 @@ - Sorting performance results by GFLOPs/second: Users can now sort the final performance report based on GFLOPs/second, making it easier to identify the most efficient kernels. - Exhaustive search for best kernel performance in GFLOPs/second: The profiler now searches for the best-performing kernel across a range of problem sizes, swizzle sizes, rasterization orders, and dynamic cluster configurations to maximize performance. - Performance search under a fixed GEMM shape: Enables exhaustive tuning within a fixed GEMM shape, exploring various kernel parameters to find the best configuration. - - More detailed introductions and examples to leverage this feature can be found in [profiler.md](./media/docs/cpp/profiler.md#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss). + - More detailed introductions and examples to leverage this feature can be found in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#exhaustive-search-mode-and-top-k-output-ranking-according-to-performance-in-gflopss). * Support `void` as the D element in sm100 kernel epilogues. * Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! * Optimal code generation with CUDA toolkit versions 12.8U1. @@ -101,7 +123,7 @@ - [Pipelines that implement Blackwell specific synchronization](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/pipeline/sm100_pipeline.hpp). - [Cluster launch control API supporting preferred and fallback cluster shapes](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/cluster_launch.hpp). - Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types. - - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/cpp/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp). + - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_cluster_launch_control.html) to implement dynamic persistence scheduling for [GEMMs](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp). - Extensions to testbeds and reference check code for unit tests and CUTLASS profiler. * Full support for Blackwell SM100 kernels in CUTLASS 3.x API: - [Blackwell specific kernel layers](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that @@ -139,11 +161,11 @@ - A set of new [Hopper grouped GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes. - A new [Hopper FP8 GEMM with groupwise scaling](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu). * Documentation updates: - - [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/cpp/quickstart.md#instantiating-a-blackwell-sm100-gemm-kernel). - - Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/cpp/blackwell_functionality.md) - - A new [functionality documentation](./media/docs/cpp/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures. - - Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#target-architecture). - - Updates to [profiler documentation](./media/docs/cpp/profiler.md) for testing mixed input GEMM kernels on Hopper. + - [Quickstart - instantiating a Blackwell block-scaled GEMM](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#instantiating-a-blackwell-sm100-gemm-kernel). + - Detailed [Blackwell block-scaled GEMM functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html) + - A new [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures. + - Updates to [compatibility](https://docs.nvidia.com/cutlass/overview.html#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](https://docs.nvidia.com/cutlass/overview.html#target-architecture). + - Updates to [profiler documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) for testing mixed input GEMM kernels on Hopper. ## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11) - [Hopper blockwise scaling FP8 GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439). @@ -156,7 +178,7 @@ + Fix `cute::SM80_CP_ASYNC_CACHEALWAYS`, `cute::SM80_CP_ASYNC_CACHEGLOBAL`, `cute::SM80_CP_ASYNC_CACHEALWAYS_ZFILL`, `cute::SM80_CP_ASYNC_CACHEGLOBAL_ZFILL` to avoid implicitly selecting `ZFILL` behavior on predication. + Remove `cute::copy_vec` in favor of `cute::copy_aligned` and `cute::copy(AutoVectorizingCopyWithAssumedAlignment,...)`. + A refactor of default epilogue struct `DefaultEpilogue` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/default_epilogue.hpp) to avoid reading non-void `ElementC` value for `ElementC = void` kernel. -- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](./media/docs/cpp/profiler.md#cutlass-profiler). +- New CUTLASS profiler flags: `profiling-duration`, `min-iterations`, and `kernels-file` documented in [profiler.md](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#cutlass-profiler). - Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! - Optimal code generation with CUDA toolkit versions 12.6. @@ -168,14 +190,14 @@ + [INT8](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu) + [TF32](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu) - A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API. -- [An improved mixed input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/cpp/README.md) and a [lookup table implementation](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode. +- [An improved mixed input GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode. - [EVT nodes for Top-K selection and softmax](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](https://github.com/NVIDIA/cutlass/tree/main/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu). -- [Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/cpp/dependent_kernel_launch.md). -- [A new debugging tool, synclog](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/cpp/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details. +- [Programmatic Dependent Launch](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html). +- [A new debugging tool, synclog](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details. - A new TMA-enabled [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support. - A SIMT-enabled pointer-array [epilogue](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp). - A new [Ping-Pong kernel schedule for Grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations. -- [A new instantiation strategy for CUTLASS profiler kernels](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/cpp/profiler.md#instantiating-more-kernels-with-hopper). +- [A new instantiation strategy for CUTLASS profiler kernels](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html#instantiating-more-kernels-with-hopper). - A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/bfloat16.h) - Fixed use of isnan on Windows for [`half_t`](https://github.com/NVIDIA/cutlass/tree/main/test/unit/core/functional.cu). - Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! @@ -198,7 +220,7 @@ - Support for residual add (beta != 0) in convolution kernels. - A new convolution [epilogue](https://github.com/NVIDIA/cutlass/tree/main/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output. - A refactor of [include files throughout CUTLASS core directories](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](https://github.com/NVIDIA/cutlass/tree/main/test/self_contained_includes/CMakeLists.txt). -- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/cpp/ide_setup.md) and [expanded code style guide](./media/docs/cpp/programming_guidelines.md). +- [A guide for setting up VSCode to work well with CUTLASS](https://docs.nvidia.com/cutlass/media/docs/cpp/ide_setup.html) and [expanded code style guide](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html). - Better support for MSVC as a host compiler. - Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2. - Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1. @@ -206,13 +228,13 @@ ## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09) - Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](https://github.com/NVIDIA/cutlass/tree/main/include/cute/atom/copy_traits_sm90_im2col.hpp) - + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/cpp/gemm_api_3x.md). + + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html). + Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/conv/convnd_problem_shape.hpp). + Support for [Fprop](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](https://github.com/NVIDIA/cutlass/tree/main/test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms + [CUTLASS profiler support](https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API. + NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until 3.7 release. Your feedback is welcome on the design! - Support for [Ada (SM89) FP8 tensor cores via the 2.x API](https://github.com/NVIDIA/cutlass/tree/main/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer. -- [Ampere gather/scatter convolution example](https://github.com/NVIDIA/cutlass/tree/main/examples/59_ampere_gather_scatter_conv/cpp/README.md) in CuTe and CUTLASS 3.x +- [Ampere gather/scatter convolution example](https://github.com/NVIDIA/cutlass/tree/main/examples/59_ampere_gather_scatter_conv/README.md) in CuTe and CUTLASS 3.x + Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs. + Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores. - 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices. @@ -279,7 +301,7 @@ * Updates and bugfixes from the community (thanks!) ## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14) -* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](https://github.com/NVIDIA/cutlass/tree/main/python/cpp/README.md) and new [examples](https://github.com/NVIDIA/cutlass/tree/main/examples/python). +* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](https://github.com/NVIDIA/cutlass/tree/main/python/README.md) and new [examples](https://github.com/NVIDIA/cutlass/tree/main/examples/python). * New [efficient epilogues](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper. * Support for [fused epilogues](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such Bias, ReLU and GELU, using the new efficient epilogues. * New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA. diff --git a/CMakeLists.txt b/CMakeLists.txt index f141fd40..38dcca9f 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -175,11 +175,13 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) endif() if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 101 101a 120 120a) + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 120 120a) + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101 101a) endif() if (CUDA_VERSION VERSION_GREATER_EQUAL 12.9) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 101f 120f) + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 120f) + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101f) endif() set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.") @@ -344,6 +346,10 @@ if(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) endif() +if (CUTLASS_NVCC_ARCHS MATCHES 100f OR CUTLASS_NVCC_ARCHS MATCHES 101f) +list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SM100_FAMILY_ARCHS_ENABLED) +endif() + set(CUTLASS_SKIP_REDUCTION_INIT OFF CACHE BOOL "Disable init reduction workspace") # diff --git a/Doxyfile b/Doxyfile index 20c00ffd..9d241c0c 100644 --- a/Doxyfile +++ b/Doxyfile @@ -1056,7 +1056,7 @@ HTML_STYLESHEET = # defined cascading style sheet that is included after the standard style sheets # created by doxygen. Using this option one can overrule certain style aspects. # This is preferred over using HTML_STYLESHEET since it does not replace the -# standard style sheet and is therefore more robust against future updates. +# standard style sheet and is therefor more robust against future updates. # Doxygen will copy the style sheet file to the output directory. For an example # see the documentation. # This tag requires that the tag GENERATE_HTML is set to YES. @@ -1940,7 +1940,7 @@ PREDEFINED = EXPAND_AS_DEFINED = # If the SKIP_FUNCTION_MACROS tag is set to YES then doxygen's preprocessor will -# remove all references to function-like macros that are alone on a line, have an +# remove all refrences to function-like macros that are alone on a line, have an # all uppercase name, and do not end with a semicolon. Such function macros are # typically used for boiler-plate code, and will confuse the parser if not # removed. diff --git a/README.md b/README.md index 62dacaf6..6485fb05 100644 --- a/README.md +++ b/README.md @@ -40,21 +40,22 @@ designs, and bringing optimized solutions into production. CuTe DSL is currently in public beta and will graduate out of beta by end of summer 2025. To get started quickly - please refer : - - [CUTLASS C++ Quick Start Guide](./media/docs/cpp/quickstart.md). - - [CuTe DSL Quick Start Guide](./media/docs/pythonDSL/quick_start.rst). + - [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html). + - [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html). # What's New in CUTLASS 4.0 ## CuTe DSL * CuTe DSL, a Python DSL centered around CuTe's abstractions - [Core DSL implementation files](https://github.com/NVIDIA/cutlass/tree/main/python/CuTeDSL) - - [DSL Quick Start](./media/docs/pythonDSL/quick_start.rst) - - [DSL Overview](./media/docs/pythonDSL/overview.rst) + - [DSL Quick Start](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html) + - [DSL Overview](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/overview.html) * [Overhauled documentation with an new dedicated website](https://docs.nvidia.com/cutlass) * Set of examples demonstrating how to use CuTe DSL to write peak-performance kernels - [Blackwell persistent dense GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py) - [Blackwell grouped GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py) - [Blackwell fused multi-head attention forward pass](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/fmha.py) + - [Hopper GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/hopper/dense_gemm.py) - [Ampere GEMM](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py) - [FlashAttention-2 implementation targeting Ampere and Ada class GPUs (SM80, SM86, SM89)](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py) * [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks) @@ -71,13 +72,15 @@ To get started quickly - please refer : - Added non-power-of-two tile sizes. - Improved performance for K-major scale factors. - The argument `mma_promotion_interval` has been removed from non-grouped GEMM to align with the grouped and Blackwell versions. +* Support LSE output in Blackwell FMHA Forward kernel. +* Improve Blackwell and Hopper grouped GEMM performance, functionality, and profiler support. * Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! * Optimal code generation with CUDA toolkit versions 12.9. Note: CUTLASS 4.x builds are known to be down on Windows platforms for all CUDA toolkits. CUTLASS team is working on a fix. -**See the [CHANGELOG](CHANGELOG.md) for details of all past releases and updates.** +**See the [CHANGELOG](https://docs.nvidia.com/cutlass/CHANGELOG.html) for details of all past releases and updates.** # Performance @@ -119,7 +122,7 @@ Layouts can also be combined and manipulated via functional composition, on whic CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design and improves code composability and readability. More documentation specific to CuTe can be found in its -[dedicated documentation directory](./media/docs/cpp/cute/00_quickstart.md). +[dedicated documentation directory](https://docs.nvidia.com/cutlass/media/docs/cpp/cute/00_quickstart.html). # Compatibility @@ -202,7 +205,7 @@ NVIDIA Blackwell GeForce RTX 50 series GPUs. As a result, kernels compiled for Blackwell SM100 architecture with arch conditional features (using `sm100a`) are not compatible with RTX 50 series GPUs. -Please refer to the [functionality documentation](./media/docs/cpp/functionality.md) +Please refer to the [functionality documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) for details on which kernels require which target architectures. # Documentation @@ -210,22 +213,22 @@ for details on which kernels require which target architectures. CUTLASS is described in the following documents and the accompanying [Doxygen documentation](https://nvidia.github.io/cutlass). -- [Quick Start Guide](./media/docs/cpp/quickstart.md) - basics of building and running CUTLASS -- [Functionality](./media/docs/cpp/functionality.md) - summarizes functionality available in CUTLASS -- [Efficient GEMM in CUDA](./media/docs/cpp/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA -- [CUTLASS 3.x Design](./media/docs/cpp/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components -- [GEMM API 3.x](./media/docs/cpp/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts -- [GEMM API 2.x](./media/docs/cpp/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts -- [Implicit GEMM Convolution](./media/docs/cpp/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS -- [Code Organization](./media/docs/cpp/code_organization.md) - describes the organization and contents of the CUTLASS project -- [Terminology](./media/docs/cpp/terminology.md) - describes terms used in the code -- [Programming Guidelines](./media/docs/cpp/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++ -- [Fundamental types](./media/docs/cpp/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays -- [Layouts](./media/docs/cpp/layout.md) - describes layouts of matrices and tensors in memory -- [Tile Iterators](./media/docs/cpp/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory -- [CUTLASS Profiler](./media/docs/cpp/profiler.md) - command-line driven profiling application -- [CUTLASS Utilities](./media/docs/cpp/utilities.md) - additional templates used to facilitate rapid development -- [Dependent kernel launch](./media/docs/cpp/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent +- [Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html) - basics of building and running CUTLASS +- [Functionality](https://docs.nvidia.com/cutlass/media/docs/cpp/functionality.html) - summarizes functionality available in CUTLASS +- [Efficient GEMM in CUDA](https://docs.nvidia.com/cutlass/media/docs/cpp/efficient_gemm.html) - describes how GEMM kernels may be implemented efficiently in CUDA +- [CUTLASS 3.x Design](https://docs.nvidia.com/cutlass/media/docs/cpp/cutlass_3x_design.html) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components +- [GEMM API 3.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api_3x.html) - describes the CUTLASS 3.x GEMM model and C++ template concepts +- [GEMM API 2.x](https://docs.nvidia.com/cutlass/media/docs/cpp/gemm_api.html) - describes the CUTLASS 2.x GEMM model and C++ template concepts +- [Implicit GEMM Convolution](https://docs.nvidia.com/cutlass/media/docs/cpp/implicit_gemm_convolution.html) - describes 2-D and 3-D convolution in CUTLASS +- [Code Organization](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html) - describes the organization and contents of the CUTLASS project +- [Terminology](https://docs.nvidia.com/cutlass/media/docs/cpp/terminology.html) - describes terms used in the code +- [Programming Guidelines](https://docs.nvidia.com/cutlass/media/docs/cpp/programming_guidelines.html) - guidelines for writing efficient modern CUDA C++ +- [Fundamental types](https://docs.nvidia.com/cutlass/media/docs/cpp/fundamental_types.html) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays +- [Layouts](https://docs.nvidia.com/cutlass/media/docs/cpp/layout.html) - describes layouts of matrices and tensors in memory +- [Tile Iterators](https://docs.nvidia.com/cutlass/media/docs/cpp/tile_iterator_concept.html) - describes C++ concepts for iterating over tiles of matrices in memory +- [CUTLASS Profiler](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) - command-line driven profiling application +- [CUTLASS Utilities](https://docs.nvidia.com/cutlass/media/docs/cpp/utilities.html) - additional templates used to facilitate rapid development +- [Dependent kernel launch](https://docs.nvidia.com/cutlass/media/docs/cpp/dependent_kernel_launch.html) - describes a new feature in Hopper which allows overlapping dependent kernels in the same stream, and how it is used in CUTLASS. # Resources @@ -245,7 +248,7 @@ projects. Client applications should target CUTLASS's `include/` directory in th paths. CUTLASS unit tests, examples, and utilities can be build with CMake. -The minimum version of CMake is given in the [Quickstart guide](./media/docs/cpp/quickstart.md). +The minimum version of CMake is given in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html). Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed on your system. @@ -290,7 +293,7 @@ CUTLASS is arranged as a header-only library along with Utilities, Tools, Exampl and template concepts defined in the CUTLASS project. A detailed explanation of the source code organization may be found in the -[CUTLASS documentation](./media/docs/cpp/code_organization.md), but several main components are summarized below. +[CUTLASS documentation](https://docs.nvidia.com/cutlass/media/docs/cpp/code_organization.html), but several main components are summarized below. ## CUTLASS Template Library @@ -364,7 +367,7 @@ tools/ The `test/unit/` directory consist of unit tests implemented with Google Test that demonstrate basic usage of Core API components and complete tests of the CUTLASS GEMM computations. -Instructions for building and running the Unit tests are described in the [Quickstart guide](./media/docs/cpp/quickstart.md). +Instructions for building and running the Unit tests are described in the [Quickstart guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html). # Performance Profiling @@ -580,9 +583,9 @@ reference_device: Passed ## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler - Please follow the links for more CMake examples on selectively compiling CUTLASS kernels: - - [GEMM CMake Examples](./media/docs/cpp/quickstart.md#gemm-cmake-examples) - - [Implicit GEMM convolution CMake Examples](./media/docs/cpp/quickstart.md#convolution-cmake-examples) -- [Further details about the CUTLASS Profiler are described here.](./media/docs/cpp/profiler.md) + - [GEMM CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#gemm-cmake-examples) + - [Implicit GEMM convolution CMake Examples](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html#convolution-cmake-examples) +- [Further details about the CUTLASS Profiler are described here.](https://docs.nvidia.com/cutlass/media/docs/cpp/profiler.html) # About diff --git a/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp b/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp index a8b220d1..4d9325b4 100644 --- a/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp +++ b/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp @@ -42,7 +42,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" #include "cutlass/arch/grid_dependency_control.h" @@ -288,7 +287,7 @@ struct CollectiveMma< constexpr int tma_alignment_bits = 128; auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M,N,K,L] = problem_shape_MNKL; - + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; bool implementable = cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; @@ -445,7 +444,7 @@ struct CollectiveMma< copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); ++k_tile_iter; - if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) { + if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) { launch_dep_grids = true; cutlass::arch::launch_dependent_grids(); } @@ -453,7 +452,7 @@ struct CollectiveMma< // Advance smem_pipe_write ++smem_pipe_write; } - if (!disable_gdc && !launch_dep_grids) { + if (!disable_gdc && !launch_dep_grids) { cutlass::arch::launch_dependent_grids(); } } @@ -533,7 +532,7 @@ struct CollectiveMma< copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); ++k_tile_iter; - if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) { + if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) { launch_dep_grids = true; cutlass::arch::launch_dependent_grids(); } @@ -541,7 +540,7 @@ struct CollectiveMma< // Advance smem_pipe_write ++smem_pipe_write; } - if (!disable_gdc && !launch_dep_grids) { + if (!disable_gdc && !launch_dep_grids) { cutlass::arch::launch_dependent_grids(); } } @@ -634,9 +633,9 @@ struct CollectiveMma< // Issue the epilogue waits if (lane_predicate) { /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all + * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was + * then would just be acquired since the phase was * still inverted from make_producer_start_state */ pipeline.producer_tail(smem_pipe_write); @@ -854,7 +853,7 @@ struct CollectiveMma< k_tile_count -= prologue_mma_count; smem_pipe_release.advance(k_tile_count); - + // Wait on all GMMAs to complete warpgroup_wait<0>(); diff --git a/examples/65_distributed_gemm/65_distributed_gemm.cu b/examples/65_distributed_gemm/65_distributed_gemm.cu index 06d18cef..9a7a4c30 100644 --- a/examples/65_distributed_gemm/65_distributed_gemm.cu +++ b/examples/65_distributed_gemm/65_distributed_gemm.cu @@ -133,7 +133,7 @@ using TP = _8; static constexpr int TP_ = TP{}; #if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \ - (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)) + (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)) // Distributed GEMM tiling/sharding schedule // Choices: @@ -252,7 +252,8 @@ HostTensorB tensor_B_arr[TP_]; HostTensorD tensor_C_arr[TP_]; HostTensorD tensor_D_arr[TP_]; -#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)) +#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && + // (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)) ///////////////////////////////////////////////////////////////////////////////////////////////// /// Testbed utility types @@ -345,7 +346,7 @@ struct Result { }; #if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && \ - (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)) + (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)) ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM setup and evaluation @@ -803,17 +804,18 @@ int run(Options &options) { return 0; } -#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)) +#endif // (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && + // (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)) /////////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, char const **args) { - // CUTLASS must be compiled with CUDA Toolkit 12.4 or newer to run this example + // CUTLASS must be compiled with CUDA Toolkit 12.6 or newer to run this example // and must have compute capability at least 90. - // Some necessary cuda graph APIs were only introduced in CUDA 12.4. - if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) { - std::cerr << "This example requires CUDA 12.4 or newer." << std::endl; + // Some necessary cuda graph APIs were only introduced in CUDA 12.6. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 6)) { + std::cerr << "This example requires CUDA 12.6 or newer." << std::endl; // Returning zero so this test passes on older Toolkits. Its actions are no-op. return 0; } @@ -857,11 +859,11 @@ int main(int argc, char const **args) { // Evaluate CUTLASS kernels // -#if (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)) +#if (defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6))) run(options); #else std::cerr - << "This example must be compiled with `sm90a` and CUDA Toolkit 12.4 or later." << std::endl; + << "This example must be compiled with `sm90a` and CUDA Toolkit 12.6 or later." << std::endl; return 0; #endif diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu index d14360de..3db531ec 100644 --- a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu +++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu @@ -250,8 +250,6 @@ cutlass::DeviceAllocation block_beta; /// Testbed utility types ///////////////////////////////////////////////////////////////////////////////////////////////// -using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams>::RasterOrderOptions; - /// Result structure struct Result { @@ -518,7 +516,7 @@ GemmArguments args_from_options(const OptionType &options, bool host_problem_sha fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; } - arguments.scheduler.raster_order = options.raster; + arguments.scheduler.raster_order = options.raster_order; // The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8) arguments.scheduler.max_swizzle_size = options.swizzle; @@ -690,10 +688,10 @@ int run(OptionType &options, bool host_problem_shapes_available = true) std::string raster = "Heuristic"; - if (options.raster == RasterOrderOptions::AlongN) { + if (options.raster_order == RasterOrderOptions::AlongN) { raster = "Along N"; } - else if (options.raster == RasterOrderOptions::AlongM) { + else if (options.raster_order == RasterOrderOptions::AlongM) { raster = "Along M"; } @@ -747,7 +745,7 @@ int main(int argc, char const **args) { // Parse options // - Options options; + Options options; options.parse(argc, args); diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu index 2ea42bbf..1977b698 100644 --- a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu +++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu @@ -253,8 +253,6 @@ cutlass::DeviceAllocation block_beta; /// Testbed utility types ///////////////////////////////////////////////////////////////////////////////////////////////// -using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams>::RasterOrderOptions; - /// Result structure struct Result { @@ -523,7 +521,7 @@ GemmArguments args_from_options(const OptionType &options, bool host_problem_sha fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; } - arguments.scheduler.raster_order = options.raster; + arguments.scheduler.raster_order = options.raster_order; // The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8) arguments.scheduler.max_swizzle_size = options.swizzle; @@ -699,10 +697,10 @@ int run(OptionType &options, bool host_problem_shapes_available = true) std::string raster = "Heuristic"; - if (options.raster == RasterOrderOptions::AlongN) { + if (options.raster_order == RasterOrderOptions::AlongN) { raster = "Along N"; } - else if (options.raster == RasterOrderOptions::AlongM) { + else if (options.raster_order == RasterOrderOptions::AlongM) { raster = "Along M"; } @@ -755,7 +753,7 @@ int main(int argc, char const **args) { // Parse options // - Options options; + Options options; options.parse(argc, args); diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp index 19497176..dacbb324 100644 --- a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp +++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp @@ -30,10 +30,9 @@ **************************************************************************************************/ // Command line options parsing -template +using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; +template struct Options { - - using RasterOrderOptions = _RasterOrderOptions; using ProblemShape = _ProblemShape; bool help = false; @@ -50,7 +49,7 @@ struct Options { int const m_alignment = 128; int const n_alignment = 128; - RasterOrderOptions raster; + RasterOrderOptions raster_order; int swizzle; // Parses the command line @@ -74,13 +73,13 @@ struct Options { cmd.get_cmd_line_argument("raster", raster_char); if (raster_char == 'N' || raster_char == 'n') { - raster = RasterOrderOptions::AlongN; + raster_order = RasterOrderOptions::AlongN; } else if (raster_char == 'M' || raster_char == 'm') { - raster = RasterOrderOptions::AlongM; + raster_order = RasterOrderOptions::AlongM; } else if (raster_char == 'H' || raster_char == 'h') { - raster = RasterOrderOptions::Heuristic; + raster_order = RasterOrderOptions::Heuristic; } cmd.get_cmd_line_argument("swizzle", swizzle, 1); diff --git a/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu b/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu index 8f6def99..c6e1d753 100644 --- a/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu +++ b/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu @@ -543,7 +543,7 @@ int run(Options &options) { int main(int argc, char const **args) { - // CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example + // CUTLASS must be compiled with CUDA 12.8 Toolkit or newer to run this example // and must have compute capability at least 100. if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; @@ -560,7 +560,6 @@ int main(int argc, char const **args) { std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl; return 0; } - // // Parse options // diff --git a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu index f9d5e842..097a8693 100644 --- a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu +++ b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu @@ -237,7 +237,7 @@ cutlass::DeviceAllocation block_beta; /// Testbed utility types ///////////////////////////////////////////////////////////////////////////////////////////////// -using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams::RasterOrderOptions; +using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; // Command line options parsing struct Options { diff --git a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu index f052b5f2..50d37945 100644 --- a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu +++ b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu @@ -300,7 +300,7 @@ auto make_iterator(T* ptr) { /// Testbed utility types ///////////////////////////////////////////////////////////////////////////////////////////////// -using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams::RasterOrderOptions; +using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; // Command line options parsing struct Options { diff --git a/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu b/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu index daadcd56..67313926 100644 --- a/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu +++ b/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu @@ -490,7 +490,7 @@ int run(Options &options) int main(int argc, char const **args) { - // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example // and must have compute capability at least 90. if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; @@ -503,11 +503,11 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 10 && (props.minor != 0 || props.minor != 1)) { std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl; return 0; - } - + } // // Parse options // diff --git a/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu b/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu index 8598637e..b1d7bc15 100644 --- a/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu +++ b/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu @@ -490,7 +490,7 @@ int run(Options &options) int main(int argc, char const **args) { - // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example // and must have compute capability at least 90. if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; @@ -503,11 +503,11 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 10 && (props.minor != 0 || props.minor != 1)) { std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl; return 0; - } - + } // // Parse options // diff --git a/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu b/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu index d99cdacc..abac47ae 100644 --- a/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu +++ b/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu @@ -499,11 +499,11 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 10 && (props.minor != 0 || props.minor != 1)) { std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl; return 0; } - // // Parse options // diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha.cu b/examples/77_blackwell_fmha/77_blackwell_fmha.cu index c8792122..8fc74c1f 100644 --- a/examples/77_blackwell_fmha/77_blackwell_fmha.cu +++ b/examples/77_blackwell_fmha/77_blackwell_fmha.cu @@ -117,15 +117,17 @@ struct Options { int q = 256; int k = 256; int d = 128; + int warmup_iterations = 1; int iterations = 3; + int tensor_ring_buffers = 1; bool verify = false; bool verbose = false; bool causal = false; bool residual = false; bool varlen = false; + bool persistent = false; int sm_count = 0; - std::string kernel_filter; InitStyle init_style_q = InitStyle::kRandom; @@ -189,10 +191,15 @@ struct Options { if (b == -1) b = 16384 / k; if (b == 0) b = 1; + cmd.get_cmd_line_argument("warmup_iterations", warmup_iterations, defaults.warmup_iterations); cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations); + cmd.get_cmd_line_argument("tensor_ring_buffers", tensor_ring_buffers, defaults.tensor_ring_buffers); + verify = cmd.check_cmd_line_flag("verify"); verbose = cmd.check_cmd_line_flag("verbose"); varlen = cmd.check_cmd_line_flag("varlen"); + persistent = cmd.check_cmd_line_flag("persistent"); + std::string mask; cmd.get_cmd_line_argument("mask", mask, ""); if (mask == "no" || mask == "") { @@ -210,7 +217,6 @@ struct Options { causal = false; } cmd.get_cmd_line_argument("sm-count", sm_count, defaults.sm_count); - get_init_style_argument(cmd, "init-style", init_style_q, defaults.init_style_q); get_init_style_argument(cmd, "init-style", init_style_k, defaults.init_style_q); get_init_style_argument(cmd, "init-style", init_style_v, defaults.init_style_q); @@ -235,10 +241,13 @@ struct Options { << " --q= Sets the Q extent\n" << " --k= Sets the K extent\n" << " --d= Sets the D extentn" + << " --tensor_ring_buffers= Sets the number of tensor ring buffers\n" + << " --warmup_iterations= Sets the warmup iterations\n" << " --iterations= Benchmarking iterations\n" << " --verify Verify results\n" << " --verbose Print smem and execution time per kernel\n" << " --mask= Enables masking\n" + << " --persistent Enables persistent scheduler\n" << " --varlen Enables variable sequence length\n" << " B*Q and B*K become the total sequence length\n" << " and are split B-ways, alternatingly +10% and -10%\n" @@ -379,40 +388,55 @@ struct FwdRunner { StrideLSE stride_LSE; uint64_t seed = 0; - DeviceAllocation block_Q; - DeviceAllocation block_K; - DeviceAllocation block_V; - DeviceAllocation block_O; - DeviceAllocation block_LSE; - DeviceAllocation block_ref_O; - DeviceAllocation block_ref_LSE; + struct DeviceBuffer { + DeviceAllocation block_Q; + DeviceAllocation block_K; + DeviceAllocation block_V; + DeviceAllocation block_O; + DeviceAllocation block_LSE; + DeviceAllocation block_ref_O; + DeviceAllocation block_ref_LSE; + DeviceAllocation device_cumulative_seqlen_q; + DeviceAllocation device_cumulative_seqlen_kv; + + DeviceBuffer() = default; + DeviceBuffer(const DeviceBuffer&) = delete; + DeviceBuffer& operator=(const DeviceBuffer&) = delete; + + size_t get_storage_size() const { + return block_Q.get_storage_size() + block_K.get_storage_size() + block_V.get_storage_size() + + block_O.get_storage_size() + block_LSE.get_storage_size() + block_ref_O.get_storage_size() + + block_ref_LSE.get_storage_size() + device_cumulative_seqlen_q.get_storage_size() + + device_cumulative_seqlen_kv.get_storage_size(); + } + }; + + std::vector> buffers; std::vector cumulative_seqlen_q; std::vector cumulative_seqlen_kv; - DeviceAllocation device_cumulative_seqlen_q; - DeviceAllocation device_cumulative_seqlen_kv; // // Methods // - bool verify(const ProblemShapeType& problem_shape) { - Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), + bool verify(const ProblemShapeType& problem_shape, DeviceBuffer& buffer) { + Tensor mQ = make_tensor(make_gmem_ptr(buffer.block_Q.get()), select<0,2,3>(problem_shape), stride_Q); - Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), + Tensor mK = make_tensor(make_gmem_ptr(buffer.block_K.get()), select<1,2,3>(problem_shape), stride_K); - Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), + Tensor mV = make_tensor(make_gmem_ptr(buffer.block_V.get()), select<1,2,3>(problem_shape), stride_V); - Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()), + Tensor mO = make_tensor(make_gmem_ptr(buffer.block_ref_O.get()), select<0,2,3>(problem_shape), stride_O); - Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()), + Tensor mLSE = make_tensor(make_gmem_ptr(buffer.block_ref_LSE.get()), select<0,3>(problem_shape), stride_LSE); @@ -431,7 +455,7 @@ struct FwdRunner { // Check if output from CUTLASS kernel and reference kernel are equal or not double max_diff = 0; double mean_diff = 0; - reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff); + reference_abs_diff(buffer.block_O, buffer.block_ref_O, max_diff, mean_diff); bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); if (! passed_O) { @@ -439,14 +463,13 @@ struct FwdRunner { << " mean " << mean_diff << std::endl; } - // reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff); + reference_abs_diff(buffer.block_LSE, buffer.block_ref_LSE, max_diff, mean_diff); - bool passed_LSE = true; // future work - // bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); - // if ( ! passed_LSE) { - // std::cerr << "failed LSE: max diff " << max_diff - // << " mean " << mean_diff << std::endl; - // } + bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + if ( ! passed_LSE) { + std::cerr << "failed LSE: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } return passed_O && passed_LSE; } @@ -559,50 +582,70 @@ struct FwdRunner { get<1,1>(stride_LSE) = 0; } - block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0); - block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0); - block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0); - block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0); - block_LSE.reset(size(shape_LSE)); - block_ref_O.reset(size(shape_QO)); - block_ref_LSE.reset(size(shape_LSE)); + auto buffer_init_fn = [&](auto& buffer) { + buffer.block_Q.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0); + buffer.block_K.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0); + buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0); + buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0); + buffer.block_LSE.reset(size(shape_LSE)); - initialize_block(block_Q, seed + 2023, options.init_style_q); - initialize_block(block_K, seed + 2022, options.init_style_k); - initialize_block(block_V, seed + 2021, options.init_style_v); + initialize_block(buffer.block_Q, seed + 2023, options.init_style_q); + initialize_block(buffer.block_K, seed + 2022, options.init_style_k); + initialize_block(buffer.block_V, seed + 2021, options.init_style_v); - if ( ! cumulative_seqlen_q.empty()) { - device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size()); - device_cumulative_seqlen_q.copy_from_host( - cumulative_seqlen_q.data(), cumulative_seqlen_q.size()); - } - if ( ! cumulative_seqlen_kv.empty()) { - device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size()); - device_cumulative_seqlen_kv.copy_from_host( - cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size()); + if ( ! cumulative_seqlen_q.empty()) { + buffer.device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size()); + buffer.device_cumulative_seqlen_q.copy_from_host( + cumulative_seqlen_q.data(), cumulative_seqlen_q.size()); + } + if ( ! cumulative_seqlen_kv.empty()) { + buffer.device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size()); + buffer.device_cumulative_seqlen_kv.copy_from_host( + cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size()); + } + }; + + buffers.push_back(std::make_unique()); + buffer_init_fn(*buffers.back()); + + int tensor_ring_buffers = options.tensor_ring_buffers; + for (int i = 1; i < tensor_ring_buffers; i++) { + buffers.push_back(std::make_unique()); + buffer_init_fn(*buffers.back()); } if constexpr (kIsVarlen) { - get<0>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get(); - get<1>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get(); + get<0>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_q.get(); + get<1>(problem_shape).cumulative_length = buffers[0]->device_cumulative_seqlen_kv.get(); } return problem_shape; } + auto get_arguments(const ProblemShapeType& problem_shape, const cutlass::KernelHardwareInfo& hw_info, int buffer_index) { + auto problem_shape_ = problem_shape; + if constexpr (kIsVarlen) { + get<0>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_q.get(); + get<1>(problem_shape_).cumulative_length = buffers[buffer_index]->device_cumulative_seqlen_kv.get(); + } + typename Operation::Arguments arguments{ + problem_shape_, + { buffers[buffer_index]->block_Q.get(), stride_Q, + buffers[buffer_index]->block_K.get(), stride_K, + buffers[buffer_index]->block_V.get(), stride_V }, + { buffers[buffer_index]->block_O.get(), stride_O, + buffers[buffer_index]->block_LSE.get(), stride_LSE }, + hw_info + }; + return arguments; + } + ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { ProblemShapeType problem_shape = initialize(options); - typename Operation::Arguments arguments{ - problem_shape, - { block_Q.get(), stride_Q, - block_K.get(), stride_K, - block_V.get(), stride_V }, - { block_O.get(), stride_O, - block_LSE.get(), stride_LSE }, - hw_info - }; + int buffer_index = 0; + typename Operation::Arguments arguments = get_arguments(problem_shape, hw_info, buffer_index); Operation op; @@ -630,11 +673,21 @@ struct FwdRunner { } // Run - status = op.run(); - if (status != cutlass::Status::kSuccess) { - std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " - << cudaGetErrorString(cudaGetLastError()) << std::endl; - return example_result; + for (int i = 0; i < options.warmup_iterations; i++) { + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + buffer_index = (buffer_index + 1) % buffers.size(); + arguments = get_arguments(problem_shape, hw_info, buffer_index); + status = op.update(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: " + << std::endl; + return example_result; + } } cudaError_t result = cudaDeviceSynchronize(); @@ -672,6 +725,14 @@ struct FwdRunner { << cudaGetErrorString(cudaGetLastError()) << std::endl; return example_result; } + buffer_index = (buffer_index + 1) % buffers.size(); + arguments = get_arguments(problem_shape, hw_info, buffer_index); + status = op.update(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to update the CUTLASS kernel's parameters. Last CUDA error is: " + << std::endl; + return example_result; + } } // @@ -734,10 +795,10 @@ struct FwdRunner { // Verify that the result is correct bool passed = true; if (options.verify) { - passed = verify(problem_shape); + passed = verify(problem_shape, *buffers[0]); if (passed) example_result.verified = true; } - + if (!passed) { std::cerr << "Reference check failed" << std::endl; return example_result; @@ -789,10 +850,14 @@ void run_fwd_128(Mask fusion, Options const & options, cutlass::KernelHardwareIn using HeadDim = _128; - // Persistent Tile Scheduler - run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option{}); - // Individual Tile Scheduler - run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option{}); + if (options.persistent) { + // Persistent Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option{}); + } + else { + // Individual Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option{}); + } } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -818,10 +883,14 @@ void run_fwd_64(Mask fusion, Options const & options, cutlass::KernelHardwareInf using HeadDim = _64; - // Persistent Tile Scheduler - run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option{}); - // Individual Tile Scheduler - run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option{}); + if (options.persistent) { + // Persistent Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option{}); + } + else { + // Individual Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option{}); + } } @@ -845,10 +914,14 @@ void run_fwd_32(Mask fusion, Options const & options, cutlass::KernelHardwareInf using HeadDim = _32; #ifdef FP8 - // Persistent Tile Scheduler - run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option{}); - // Individual Tile Scheduler - run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option{}); + if (options.persistent) { + // Persistent Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 persistent", Option{}); + } + else { + // Individual Tile Scheduler + run(Shape<_256, _128, HeadDim>{}, "tma ws 256x128 acc fp32 individual", Option{}); + } #endif } diff --git a/examples/77_blackwell_fmha/77_blackwell_mla.cu b/examples/77_blackwell_fmha/77_blackwell_mla.cu index baa70fce..ca024623 100644 --- a/examples/77_blackwell_fmha/77_blackwell_mla.cu +++ b/examples/77_blackwell_fmha/77_blackwell_mla.cu @@ -59,7 +59,7 @@ using namespace cutlass::fmha::kernel; /////////////////////////////////////////////////////////////////////////////////////////////////// enum class InitStyle { - kOne, kLinearStride128, kLinearStride1, kRandom, kNone + kOne, kLinearStride128, kLinearStride1, kRandom, kRandomLarge, kNone }; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -98,6 +98,9 @@ struct Options { if (s == "r") { dst = InitStyle::kRandom; } + else if (s == "l") { + dst = InitStyle::kRandomLarge; + } else if (s == "1") { dst = InitStyle::kOne; } @@ -203,6 +206,11 @@ void initialize_block( block.get(), block.size(), seed, (Element) -1, (Element) 1); break; } + case InitStyle::kRandomLarge: { + cutlass::reference::device::BlockFillRandomGaussian( + block.get(), block.size(), seed, (Element) -1, (Element) 100); + break; + } case InitStyle::kLinearStride1: { std::vector data(block.size()); for (size_t i = 0; i < block.size() / 128; i ++) { diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt index f04ebe41..8a30510b 100644 --- a/examples/77_blackwell_fmha/CMakeLists.txt +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -144,4 +144,23 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC target_compile_definitions(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE ${PREC_MACRO} SKIP_ATOMIC) target_compile_options(77_blackwell_fmha_bwd_sat_${PREC} PRIVATE -Xptxas -v) endforeach() + + # Add a target that builds all examples + add_custom_target(77_blackwell_fmha_all + DEPENDS + 77_blackwell_fmha_fp8 + 77_blackwell_fmha_fp16 + 77_blackwell_fmha_gen_fp8 + 77_blackwell_fmha_gen_fp16 + 77_blackwell_mla_2sm_fp8 + 77_blackwell_mla_2sm_fp16 + 77_blackwell_mla_2sm_cpasync_fp8 + 77_blackwell_mla_2sm_cpasync_fp16 + 77_blackwell_mla_b2b_2sm_fp8 + 77_blackwell_mla_b2b_2sm_fp16 + 77_blackwell_fmha_bwd_fp8 + 77_blackwell_fmha_bwd_fp16 + 77_blackwell_fmha_bwd_sat_fp8 + 77_blackwell_fmha_bwd_sat_fp16 + ) endif() diff --git a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp index f31c8024..6478b5d5 100644 --- a/examples/77_blackwell_fmha/collective/fmha_fusion.hpp +++ b/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -157,8 +157,8 @@ struct CausalMask : NoMask { TileShape const& tile_shape, ProblemSize const& problem_size) { - int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); - return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); + int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); + return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); } template diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp index 82400801..2740c6b8 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp @@ -42,7 +42,7 @@ template< class ElementAcc, class TileShape, // Q, D, _ class StrideO, // Q, D, B - class StrideLSE // Q, B + class StrideLSE_ // Q, B > struct Sm100FmhaFwdEpilogueTmaWarpspecialized { @@ -54,6 +54,7 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized { // using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{})); using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{})); using SmemLayoutO_ = SmemLayoutO; + using StrideLSE = StrideLSE_; struct TensorStorage { @@ -79,6 +80,9 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized { struct Params { TMA_O tma_store_o; + + ElementAcc* ptr_LSE; + StrideLSE dLSE; }; template @@ -110,7 +114,9 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized { ); return { - tma_store_o + tma_store_o, + args.ptr_LSE, + args.dLSE }; } @@ -119,6 +125,10 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized { cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor()); } + const Params& params; + + CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {} + template CUTLASS_DEVICE auto store( diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp index 1eaea0ce..58102767 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -531,7 +531,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); // Each thread owns a single row - using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem @@ -613,7 +613,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { NumericArrayConverter convert; const int kReleasePipeCount = 10; // must be multiple of 2 - + order_s.wait(); CUTLASS_PRAGMA_UNROLL @@ -637,7 +637,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { } tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); - + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { order_s.arrive(); } @@ -691,7 +691,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y; - + row_sum = local_row_sum; if (final_call) { @@ -787,14 +787,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { // good values would be either 32 or 64 const int kCorrectionTileSize = 32 / sizeof(ElementOut); - using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem typename CollectiveMmaPV::TiledMma mma; Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); Tensor tOsO = mma.get_slice(0).partition_C(sO); - + Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int{}))); Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int{}))); Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int{}))); @@ -809,7 +809,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{})); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); - + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _)); Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _)); Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _)); @@ -824,9 +824,9 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); Tensor tTMrO = make_tensor(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); - + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); - + #ifndef ONLY_SOFTMAX CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size(tTMrO); j += 2) { @@ -872,24 +872,24 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { // good values would be either 32 or 64 const int kCorrectionTileSize = 16; - using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem typename CollectiveMmaPV::TiledMma mma; Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); - + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); tOtO_i.data() = tOtO_i.data().get() + tmem_O; - + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); - + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); @@ -899,7 +899,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { float2 scale_f32x2 = make_float2(scale, scale); Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); - + auto copy_in = [&](int i) { Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); @@ -942,7 +942,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { } } - template + template CUTLASS_DEVICE auto correction( BlkCoord const& blk_coord, @@ -951,7 +951,8 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, - PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state) { + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); @@ -961,7 +962,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); - + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); @@ -1060,13 +1061,25 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { // F2FP // store to smem Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), repeat_like(typename CollectiveEpilogue::StrideLSE{}, _1{}), epilogue.params.dLSE); + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord); + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx, get<2>(blk_coord)) = lse; + } + } + cutlass::arch::fence_view_async_tmem_load(); pipeline_o.consumer_release(pipeline_o_consumer_state); ++pipeline_o_consumer_state; - + pipeline_epi.producer_commit(pipeline_epi_producer_state); ++pipeline_epi_producer_state; @@ -1083,6 +1096,16 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO); + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx, get<2>(blk_coord)) = lse; + } + } + cutlass::arch::fence_view_async_tmem_load(); pipeline_o.consumer_release(pipeline_o_consumer_state); diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp index e1bd43d5..8c5401a9 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -118,7 +118,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { using TensorStrideContiguousK = Stride>; using TensorStrideContiguousMN = Stride<_1, int, Stride>; - + // compute S using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, @@ -381,7 +381,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q, D, HB), args.mainloop.stride_dq_acc), SmemLayoutDQ{}(_, _, _0{}) ); - + return Params{ args.problem_shape, args.mainloop, @@ -452,7 +452,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{}); ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{}); - + auto tSTgK = cta_mma_kq.partition_A(gK); auto tSTgQ = cta_mma_kq.partition_B(gQ); auto tDPTgV = cta_mma_vdo.partition_A(gV); @@ -477,7 +477,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); // set up lse and sum_odo - + auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord; pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); @@ -495,7 +495,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } // load Q - if (cute::elect_one_sync()) { + if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), @@ -520,7 +520,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { &mLSE(gmem_idx, blk_coord_batch), gmem_idx < Q ); - + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_lse_producer_state; @@ -529,7 +529,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV); - + // load V if (cute::elect_one_sync()) { cute::copy( @@ -540,7 +540,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } // load dO - if (cute::elect_one_sync()) { + if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), @@ -573,7 +573,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); // load Q - if (cute::elect_one_sync()) { + if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), @@ -584,7 +584,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { ++pipeline_load_mma_q_producer_state; pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); - + // load LSE smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; @@ -593,15 +593,15 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { &mLSE(gmem_idx, blk_coord_batch), gmem_idx < Q ); - + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_lse_producer_state; pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); - // load dO - if (cute::elect_one_sync()) { + // load dO + if (cute::elect_one_sync()) { cute::copy( mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), @@ -612,7 +612,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { ++pipeline_load_mma_do_producer_state; pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); - + // load sum_OdO smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; @@ -621,7 +621,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { &mSumOdO(gmem_idx, blk_coord_batch), gmem_idx < Q ); - + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_load_compute_sum_odo_producer_state; @@ -639,23 +639,23 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { int iter_count, MainloopArguments const& mainloop_args, TensorStorage& shared_tensors, - PipelineLoadMmaQ& pipeline_load_mma_q, - typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, - PipelineLoadMmaDO& pipeline_load_mma_do, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, - PipelineMmaComputeS& pipeline_mma_compute_s, + PipelineMmaComputeS& pipeline_mma_compute_s, typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, - PipelineMmaComputeDP& pipeline_mma_compute_dp, + PipelineMmaComputeDP& pipeline_mma_compute_dp, typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, - PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, - PipelineComputeMmaP& pipeline_compute_mma_p, + PipelineComputeMmaP& pipeline_compute_mma_p, typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, - PipelineComputeMmaDS& pipeline_compute_mma_ds, + PipelineComputeMmaDS& pipeline_compute_mma_ds, typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { - + auto [Q, K, D, HB] = problem_shape; auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); @@ -685,7 +685,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); tDVrP.data() = TmemAllocation::kP; Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); - + TiledMmaKQ tiled_mma_kq; TiledMmaVDO tiled_mma_vdo; TiledMmaDSK tiled_mma_dsk; @@ -923,6 +923,8 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { TensorC const& coord, TensorShape const& tensor_shape) { + Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); + auto copy_op = make_cotiled_copy( Copy_Atom, Element>{}, make_layout(make_shape(_1{}, Int{})), @@ -930,21 +932,11 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { ); auto thr_copy = copy_op.get_slice(_0{}); - auto tCg = thr_copy.partition_D(gmem); - auto tCr = thr_copy.partition_S(quantize(regs)); - auto tCc = thr_copy.partition_D(coord); + Tensor tCg = thr_copy.partition_D(gmem); + Tensor tCr = thr_copy.partition_S(quantize(regs)); + Tensor tPc = thr_copy.partition_D(preds); - constexpr int R = decltype(tCr.layout())::rank; - auto tCg_v = group_modes<1, R>(tCg); - auto tCr_v = group_modes<1, R>(tCr); - auto tCc_v = group_modes<1, R>(tCc); - auto tCp_v = make_tensor(shape<1>(tCc_v)); - - for (int i = 0; i < size(tCp_v); ++i) { - tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape); - } - - copy_if(copy_op, tCp_v, tCr_v, tCg_v); + copy_if(copy_op, tPc, tCr, tCg); } @@ -1073,7 +1065,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { - + auto [Q, K, D, HB] = problem_shape; // in tmem, S & P overlap @@ -1114,7 +1106,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { Tensor tTR_cST = split_wg(thread_t2r.partition_D(cST)); Tensor tTR_rST = make_tensor(shape(tTR_cST)); Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); - + Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); Tensor tTR_cDPT = split_wg(tTR_cDPT_p); Tensor tTR_rDPT = make_tensor(shape(tTR_cDPT)); @@ -1152,20 +1144,20 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { fn(cute::false_type{}); } }; - + dispatch_bool(std::is_base_of_v && warp_uniform(iter_index == get<1>(blk_coord)), [&](auto is_causal_masked_tile) { // compute P = softmax(S, LSE) cute::copy(tiled_t2r, tTR_tST, tTR_rST); - + if constexpr (std::is_base_of_v && decltype(is_causal_masked_tile)::value) { Mask{}.apply_mask(tTR_rST, [&](int i) { auto c_transpose = tTR_cST(i); return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); }, problem_shape); } - + ElementAcc log2_e = static_cast(M_LOG2E); float2 softmax_scale_log2_e; softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; @@ -1184,16 +1176,16 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { tTR_rST(i) = ::exp2f(out.x); tTR_rST(i+1) = ::exp2f(out.y); } - + auto tRT_rST = quantize(tTR_rST); auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST)); - + cutlass::arch::fence_view_async_tmem_load(); cutlass::arch::NamedBarrier( kNumComputeWarps * NumThreadsPerWarp, cutlass::arch::ReservedNamedBarriers::TransformBarrier ).arrive_and_wait(); - + cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP); }); @@ -1293,9 +1285,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, PipelineReduceTmaStore& pipeline_reduce_tma_store, typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) { - + using X = Underscore; - + auto [Q, K, D, HB] = problem_shape; auto [blk_coord_q, blk_coord_k, blk_coord_batch] = blk_coord; @@ -1307,7 +1299,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { tDQtDQ.data() = TmemAllocation::kDQ; Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); - auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step{}) (_, _, _, _0{}, blk_coord_batch); Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); @@ -1376,7 +1368,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { iter_index += 1; } } - + CUTLASS_DEVICE void operator()(Params const& params, char* smem) { int warp_idx = cutlass::canonical_warp_idx_sync(); @@ -1561,7 +1553,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; - + auto pipeline_load_mma_q_producer_state = make_producer_start_state(); auto pipeline_load_mma_do_producer_state = make_producer_start_state(); auto pipeline_load_compute_lse_producer_state = make_producer_start_state(); @@ -1587,7 +1579,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { if (role == WarpRole::Load) { warpgroup_reg_set(); - + load( blk_coord, problem_shape, @@ -1596,7 +1588,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { params.mainloop, params.mainloop_params, shared_storage.tensors, - pipeline_load_mma_q, pipeline_load_mma_q_producer_state, + pipeline_load_mma_q, pipeline_load_mma_q_producer_state, pipeline_load_mma_do, pipeline_load_mma_do_producer_state, pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state @@ -1608,7 +1600,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); __syncwarp(); - + mma( blk_coord, problem_shape, @@ -1616,7 +1608,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { iter_count, params.mainloop, shared_storage.tensors, - pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, + pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, @@ -1629,7 +1621,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } else if (role == WarpRole::Compute) { warpgroup_reg_set(); - + compute( blk_coord, problem_shape, @@ -1660,7 +1652,7 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } else if (role == WarpRole::Reduce) { warpgroup_reg_set(); - + reduce( blk_coord, problem_shape, @@ -1677,9 +1669,9 @@ struct Sm100FmhaBwdKernelTmaWarpSpecialized { } else { warpgroup_reg_set(); - + /* no-op */ - + } } diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp index fbb8d362..e297e731 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp @@ -356,7 +356,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state(); CollectiveMainloop mainloop; - CollectiveEpilogue epilogue; + CollectiveEpilogue epilogue{params.epilogue}; if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) { warpgroup_reg_set(); @@ -407,7 +407,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { pipeline_s0_corr, pipeline_s0_corr_consumer_state, pipeline_s1_corr, pipeline_s1_corr_consumer_state, pipeline_mma_corr, pipeline_mma_corr_consumer_state, - pipeline_corr_epi, pipeline_corr_epi_producer_state + pipeline_corr_epi, pipeline_corr_epi_producer_state, + epilogue ); diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp index c6a05750..98f40ce8 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp @@ -146,7 +146,7 @@ struct Sm100FmhaMlaReductionKernel { ElementAcc sum_lse = 0; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kNLsePerThread; ++i) { - sum_lse = sum_lse + expf(local_lse[i] - params.scale * lse_max); + sum_lse = sum_lse + expf(local_lse[i] - lse_max); } CUTLASS_PRAGMA_UNROLL @@ -156,7 +156,7 @@ struct Sm100FmhaMlaReductionKernel { sum_lse = __shfl_sync(0xffffffff, sum_lse, 0); - ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits::infinity() : logf(sum_lse) + params.scale * lse_max; + ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits::infinity() : logf(sum_lse) + lse_max; if (threadIdx.x == 0 and params.ptr_lse != nullptr) { gLSE(0) = global_lse; } diff --git a/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp index b7c6b412..68718c6b 100644 --- a/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp +++ b/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp @@ -127,7 +127,7 @@ void __global__ fmha_reference_kernel( mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast(acc * scale); } - if (threadIdx.x == 0) { + if (threadIdx.x == 0 && mLSE.data() != nullptr) { mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS; } diff --git a/examples/77_blackwell_fmha/reference/reference_abs_error.hpp b/examples/77_blackwell_fmha/reference/reference_abs_error.hpp index 6d833ad1..a4d4b262 100644 --- a/examples/77_blackwell_fmha/reference/reference_abs_error.hpp +++ b/examples/77_blackwell_fmha/reference/reference_abs_error.hpp @@ -75,6 +75,8 @@ struct DeviceAllocation { size_t size() const { return size_; } + size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); } + void copy_from_host(const T* ptr, size_t sz) { auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); assert(ret == cudaSuccess); diff --git a/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu b/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu index c86580db..2b55c465 100644 --- a/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu +++ b/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu @@ -280,7 +280,7 @@ auto make_iterator(T* ptr) { /// Testbed utility types ///////////////////////////////////////////////////////////////////////////////////////////////// -using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm100GroupParams::RasterOrderOptions; +using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; // Command line options parsing struct Options { diff --git a/examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu b/examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu index 573c25cb..acac2576 100644 --- a/examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu +++ b/examples/82_blackwell_distributed_gemm/82_blackwell_distributed_gemm.cu @@ -133,7 +133,7 @@ using TP = _8; static constexpr int TP_ = TP{}; #if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \ - (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)) + (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) // Distributed GEMM tiling/sharding schedule // Choices: @@ -254,7 +254,8 @@ HostTensorB tensor_B_arr[TP_]; HostTensorD tensor_C_arr[TP_]; HostTensorD tensor_D_arr[TP_]; -#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)) +#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && + // (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) ///////////////////////////////////////////////////////////////////////////////////////////////// /// Testbed utility types @@ -347,7 +348,7 @@ struct Result { }; #if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && \ - (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)) + (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM setup and evaluation @@ -805,17 +806,16 @@ int run(Options &options) { return 0; } -#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)) +#endif // (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && + // (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) /////////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, char const **args) { - // CUTLASS must be compiled with CUDA Toolkit 12.4 or newer to run this example - // and must have compute capability at least 90. - // Some necessary cuda graph APIs were only introduced in CUDA 12.4. - if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) { - std::cerr << "This example requires CUDA 12.4 or newer." << std::endl; + // CUTLASS must be compiled with CUDA Toolkit 12.8 or newer to run Blackwell kernels. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; // Returning zero so this test passes on older Toolkits. Its actions are no-op. return 0; } @@ -861,11 +861,11 @@ int main(int argc, char const **args) { // Evaluate CUTLASS kernels // -#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8))) run(options); #else std::cerr - << "This example must be compiled with `sm100a` and CUDA Toolkit 12.4 or later." << std::endl; + << "This example must be compiled with `sm100a` and CUDA Toolkit 12.8 or later." << std::endl; return 0; #endif diff --git a/examples/82_blackwell_distributed_gemm/REQUIREMENTS.md b/examples/82_blackwell_distributed_gemm/REQUIREMENTS.md index 3943716b..996956a4 100644 --- a/examples/82_blackwell_distributed_gemm/REQUIREMENTS.md +++ b/examples/82_blackwell_distributed_gemm/REQUIREMENTS.md @@ -14,8 +14,8 @@ cmake $PATH -DCUTLASS_NVCC_ARCHS="100a" -DCUTLASS_ENABLE_GDC_FOR_SM100=1 ### Minimum software Like all other CUTLASS examples, the NVIDIA driver, runtime, and CUDA Toolkit are required. -This example specifically requires CUDA Toolkit 12.6 or newer, due to some of the necessary -CUDA graph APIs. +This example specifically requires CUDA Toolkit 12.8 or newer, since that is the first version +supporting the Blackwell architecture. ### Hardware / driver settings diff --git a/examples/88_hopper_fmha/88_hopper_fmha.cu b/examples/88_hopper_fmha/88_hopper_fmha.cu new file mode 100644 index 00000000..7c09d7d3 --- /dev/null +++ b/examples/88_hopper_fmha/88_hopper_fmha.cu @@ -0,0 +1,1192 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Example implementation of fused multi-head attention for Hopper using CUTLASS 3. + + This example showcases the use of CUTLASS to build forward and backward fused + multi-head attention (FMHA) collectives from existing CUTLASS collectives targeting + the NVIDIA Hopper architecture. + + Background and motivation + ------------------------- + CUTLASS is a highly flexible library that provides open-source building blocks + for tensor core programming for GEMM or GEMM-like problems. Fused multi-head + attention (FMHA) is a foundational kernel for large language models (LLMs) since it + makes long sequence lengths feasible from a memory-usage perspective. It also + improves computational efficiency since it transforms an outer-product-like and + a matrix-vector-like GEMM into a fused operation with much higher arithmetic + intensity. For more details, see Dao et al, 2022; Dao, 2023. + Implementing this kernel in CUTLASS enabled easy customization and high + performance. + + Introduction + ------------ + The example targets the NVIDIA Hopper architecture, and takes advantage of + warpgroup-wide tensor cores, the Tensor Memory Accelerator (TMA), just like + GEMMs do. It provides both a forward and a backward pass (often abbreviated + fwd and bwd in the code), and an optional FP8 mode for the forward pass. + The code is structured into four layers: The runner (and the reference kernels) + takes care of initialization, measurement, and testing; the device layer + orchestrates kernel calls and partitions workspace; the kernel layer (just + like the CUTLASS kernel layer); and the collective layer (most of the logic + of FMHA is implemented here). + + Details + ------- + This example contains a considerable amount of code. For a more detailed + look at it, please refer to the README.md. + + Example usage: + $ ./examples/88_hopper_fmha/88_hopper_fmha \ + --b=2048 --h=2048 --d=2048 --q=2048 --k=2048 +*/ + +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "collective/fmha_fusion.hpp" +#include "device/fmha_device_bwd.hpp" +#include "device/device_universal.hpp" +#include "kernel/fmha_kernel_builder.hpp" +#include "reference/fmha_reference.hpp" +#include "reference/fmha_bwd_reference.hpp" +#include "reference/reference_abs_error.hpp" + +using namespace cute; +using namespace cutlass::fmha::kernel; +using namespace cutlass::fmha::collective; +using namespace cutlass::fmha; + +// Uncomment for FP8 +// #define FP8 + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Command line options parsing +struct Options { + + bool help; + bool error; + + int b, h, q, k, d; + int iterations; + bool verify; + bool verbose; + bool causal; + bool residual; + bool bwd; + + Options(): + help(false), + error(false), + b(16), h(16), q(1024), k(1024), d(128), + iterations(3), verify(false), + causal(false), residual(false), bwd(false), verbose(false) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + Options defaults; + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("d", d, defaults.d); + cmd.get_cmd_line_argument("h", h, -1); + if (h == -1) h = 2048 / d; + + cmd.get_cmd_line_argument("q", q, -1); + cmd.get_cmd_line_argument("k", k, -1); + if (q == -1) q = k; + if (k == -1) k = q; + if (q == -1 && k == -1) q = k = defaults.q; + + cmd.get_cmd_line_argument("b", b, -1); + if (b == -1) b = 16384 / k; + if (b == 0) b = 1; + + cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations); + verify = cmd.check_cmd_line_flag("verify"); + verbose = cmd.check_cmd_line_flag("verbose"); + + std::string mask; + cmd.get_cmd_line_argument("mask", mask, ""); + if (mask == "no" || mask == "") { + causal = residual = false; + } + else if (mask == "causal") { + residual = false; + causal = true; + } + else if (mask == "residual") { + residual = true; + causal = false; + } + + bwd = cmd.check_cmd_line_flag("bwd"); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "88_hopper_fmha\n\n" + << " This example showcases the use of CUTLASS's collective operation builders to easily construct\n" + << " fused multi-head attention forward-pass kernels targeting NVIDIA's Hopper architecture.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --b= Sets the B extent\n" + << " --h= Sets the H extent\n" + << " --q= Sets the Q extent\n" + << " --k= Sets the K extent\n" + << " --d= Sets the D extent\n" + << " --iterations= Benchmarking iterations\n" + << " --verify Verify results\n" + << " --verbose Print smem and execution time per kernel\n" + << " --mask= Enables masking\n" + << " --bwd Runs the backwards pass\n" + << "\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023, bool init_one=false) { + + if (init_one) { + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, (Element) 1, (Element) 1); + } else { + cutlass::reference::device::BlockFillRandomGaussian( + block.get(), block.size(), seed, (Element) 0, (Element) 1); + } + + return true; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ExampleResult { + bool passed = false; + bool verified = false; + float runtime_ms = 0; + double tflops_s = 0; + size_t smem_size = 0; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class TileShape, + class DispatchPolicy, + class ActiveFusion, + class... KernelOptions +> +struct FwdRunner { + +#ifdef FP8 + using Element = cutlass::float_e4m3_t; + using ElementAccumulatorQK = find_option_t; +#else + using Element = cutlass::half_t; + using ElementAccumulatorQK = float; +#endif + + using ElementAccumulatorPV = float; + + // B H Q K D + using ProblemShapeType = cute::tuple; + + + using StrideQ = cute::tuple>; // Q D (B H) + using StrideK = cute::tuple>; // K D (B H) + using StrideV = std::conditional_t>, + cute::tuple>>; // K D (B H) + using StrideO = cute::tuple>; // Q D (B H) + using StrideLSE = cute::tuple<_1, cute::tuple>; // Q (B H) + + using Operation = cutlass::device::Universal< + typename cutlass::fmha::kernel::FmhaBuilder< + Element, ElementAccumulatorQK, ElementAccumulatorPV, + TileShape, StrideQ, StrideK, StrideV, + ActiveFusion, DispatchPolicy, KernelOptions... + >::Kernel>; + + // + // Data members + // + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideO stride_O; + StrideLSE stride_LSE; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_Q; + cutlass::DeviceAllocation block_K; + cutlass::DeviceAllocation block_V; + cutlass::DeviceAllocation block_O; + cutlass::DeviceAllocation block_LSE; + cutlass::DeviceAllocation block_ref_O; + cutlass::DeviceAllocation block_ref_LSE; + + // + // Methods + // + bool verify(const ProblemShapeType& problem_size) { + auto [B, H, Q, K, D] = problem_size; + + Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), + make_shape(Q, D, make_shape(B, H)), + stride_Q); + + Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), + make_shape(K, D, make_shape(B, H)), + stride_K); + + Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), + make_shape(K, D, make_shape(B, H)), + stride_V); + + Tensor mO = make_tensor(make_gmem_ptr(block_ref_O.get()), + make_shape(Q, D, make_shape(B, H)), + stride_O); + + Tensor mLSE = make_tensor(make_gmem_ptr(block_ref_LSE.get()), + make_shape(Q, make_shape(B, H)), + stride_LSE); + + fmha_reference(problem_size, mQ, mK, mV, mO, mLSE, ActiveFusion{}); + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Reference kernel failed. Last CUDA error: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + const double kMaxDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-2; + const double kMeanDiffThresh = sizeof(Element) == 1 ? 1e-1 : 1e-3; + + // Check if output from CUTLASS kernel and reference kernel are equal or not + double max_diff = 0; + double mean_diff = 0; + reference_abs_diff(block_O, block_ref_O, max_diff, mean_diff); + bool passed_O = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + if (! passed_O) { + std::cerr << "failed O: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } + + reference_abs_diff(block_LSE, block_ref_LSE, max_diff, mean_diff); + bool passed_LSE = (max_diff < kMaxDiffThresh) && (mean_diff < kMeanDiffThresh); + if ( ! passed_LSE) { + std::cerr << "failed LSE: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } + + return passed_O && passed_LSE; + } + + void initialize_stride(cute::tuple const& shape, cute::tuple<_1, cute::tuple>& stride) { + auto [B, H, Q] = shape; + stride = make_stride(_1{}, make_stride(H*Q, Q)); + } + + void initialize_stride(cute::tuple const& shape, cute::tuple>& stride) { + auto [B, H, Q, D] = shape; + stride = make_stride(D, _1{}, make_stride(H*Q*D, Q*D)); + } + + void initialize_stride(cute::tuple const& shape, cute::tuple<_1, int, cute::tuple>& stride) { + auto [B, H, Q, D] = shape; + stride = make_stride(_1{}, Q, make_stride(H*Q*D, Q*D)); + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto [B, H, Q, K, D] = problem_size; + D = cutlass::round_up(D, 8); // Alignment + + auto shape_QO = cute::make_shape(B, H, Q, D); + auto shape_KV = cute::make_shape(B, H, K, D); + auto shape_LSE = cute::make_shape(B, H, Q); + + initialize_stride(shape_QO, stride_Q); + initialize_stride(shape_KV, stride_K); + initialize_stride(shape_KV, stride_V); + initialize_stride(shape_QO, stride_O); + initialize_stride(shape_LSE, stride_LSE); + + block_Q.reset(size(shape_QO)); + block_K.reset(size(shape_KV)); + block_V.reset(size(shape_KV)); + block_O.reset(size(shape_QO)); + block_LSE.reset(size(shape_LSE)); + block_ref_O.reset(size(shape_QO)); + block_ref_LSE.reset(size(shape_LSE)); + + initialize_block(block_Q, seed + 2023, false); + initialize_block(block_K, seed + 2022, false); + initialize_block(block_V, seed + 2021, false); + } + + ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.b, options.h, options.q, options.k, options.d}; + + initialize(problem_size); + + typename Operation::Arguments arguments{ + problem_size, + { block_Q.get(), stride_Q, + block_K.get(), stride_K, + block_V.get(), stride_V }, + { block_O.get(), stride_O, + block_LSE.get(), stride_LSE }, + hw_info + }; + + Operation op; + + ExampleResult example_result; + + example_result.smem_size = Operation::Kernel::SharedStorageSize; + + size_t workspace_size = Operation::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "This kernel is not supported. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + status = op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + // Run + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result = cudaEventCreate(&event); + if (result != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + } + + // Record an event at the start of a series of GEMMs + result = cudaEventRecord(events[0]); + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + for (int i = 0; i < options.iterations; i++) { + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + } + + // + // Stop profiling loop + // + + // Record an event when the GEMMs are complete + result = cudaEventRecord(events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Wait for work on the device to complete. + result = cudaEventSynchronize(events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + runtime_ms /= static_cast(options.iterations); + + double flops = 4.0 * (std::is_same_v ? 0.5 : 1.0); + flops *= static_cast(get<0>(problem_size)); + flops *= static_cast(get<1>(problem_size)); + flops *= static_cast(get<2>(problem_size)); + flops *= static_cast(get<3>(problem_size)); + flops *= static_cast(get<4>(problem_size)); + double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/); + example_result.tflops_s = tflops_s; + example_result.runtime_ms = runtime_ms; + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Verify that the result is correct + bool passed = true; + if (options.verify) { + passed = verify(problem_size); + if (passed) example_result.verified = true; + } + + if (!passed) { + std::cerr << "Reference check failed" << std::endl; + return example_result; + } + + example_result.passed = true; + + return example_result; + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class TileShape, + class DispatchPolicy, + class ActiveFusion, + class... KernelOptions +> +struct BwdRunner { + + using Element = cutlass::half_t; + using ElementAccumulator = float; + + // B H Q K D + using ProblemShapeType = cute::tuple; + + using Operation = cutlass::fmha::device::FmhaBwd; + + // Just like forward + using StrideQ = cute::tuple; // B H Q D + using StrideK = cute::tuple; // B H K D + using StrideV = cute::tuple; // B H K D + using StrideO = cute::tuple; // B H Q D + using StrideLSE = cute::tuple; // B H Q + + // Backwards specific + using StrideDQ = cute::tuple; // B H Q D + using StrideDK = cute::tuple; // B H K D + using StrideDV = cute::tuple; // B H K D + using StrideDO = cute::tuple; // B H Q D + + // + // Data members + // + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideO stride_O; + StrideLSE stride_LSE; + + StrideDQ stride_dQ; + StrideDK stride_dK; + StrideDV stride_dV; + StrideDO stride_dO; + + uint64_t seed = 0; + + cutlass::DeviceAllocation block_Q; + cutlass::DeviceAllocation block_K; + cutlass::DeviceAllocation block_V; + cutlass::DeviceAllocation block_O; + cutlass::DeviceAllocation block_LSE; + + cutlass::DeviceAllocation block_dQ; + cutlass::DeviceAllocation block_dK; + cutlass::DeviceAllocation block_dV; + cutlass::DeviceAllocation block_dO; + + cutlass::DeviceAllocation block_ref_dQ; + cutlass::DeviceAllocation block_ref_dK; + cutlass::DeviceAllocation block_ref_dV; + + // + // Methods + // + bool verify(const ProblemShapeType& problem_size) { + auto [B, H, Q, K, D] = problem_size; + + Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), + make_shape(Q, D, make_shape(B, H)), + make_stride(get<2>(stride_Q), get<3>(stride_Q), make_stride(get<0>(stride_Q), get<1>(stride_Q)))); + + Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), + make_shape(K, D, make_shape(B, H)), + make_stride(get<2>(stride_K), get<3>(stride_K), make_stride(get<0>(stride_K), get<1>(stride_K)))); + + Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), + make_shape(K, D, make_shape(B, H)), + make_stride(get<2>(stride_V), get<3>(stride_V), make_stride(get<0>(stride_V), get<1>(stride_V)))); + + Tensor mO = make_tensor(make_gmem_ptr(block_O.get()), + make_shape(Q, D, make_shape(B, H)), + make_stride(get<2>(stride_O), get<3>(stride_O), make_stride(get<0>(stride_O), get<1>(stride_O)))); + + Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()), + make_shape(Q, make_shape(B, H)), + make_stride(get<2>(stride_LSE), make_stride(get<0>(stride_LSE), get<1>(stride_LSE)))); + + Tensor mDQ = make_tensor(make_gmem_ptr(block_ref_dQ.get()), + make_shape(Q, D, make_shape(B, H)), + make_stride(get<2>(stride_dQ), get<3>(stride_dQ), make_stride(get<0>(stride_dQ), get<1>(stride_dQ)))); + + Tensor mDK = make_tensor(make_gmem_ptr(block_ref_dK.get()), + make_shape(K, D, make_shape(B, H)), + make_stride(get<2>(stride_dK), get<3>(stride_dK), make_stride(get<0>(stride_dK), get<1>(stride_dK)))); + + Tensor mDV = make_tensor(make_gmem_ptr(block_ref_dV.get()), + make_shape(K, D, make_shape(B, H)), + make_stride(get<2>(stride_dV), get<3>(stride_dV), make_stride(get<0>(stride_dV), get<1>(stride_dV)))); + + Tensor mDO = make_tensor(make_gmem_ptr(block_dO.get()), + make_shape(Q, D, make_shape(B, H)), + make_stride(get<2>(stride_dO), get<3>(stride_dO), make_stride(get<0>(stride_dO), get<1>(stride_dO)))); + + + fmha_bwd_reference(problem_size, mQ, mK, mV, mO, mLSE, mDO, mDQ, mDK, mDV, ActiveFusion{}); + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Reference kernel failed. Last CUDA error: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + // Check if output from CUTLASS kernel and reference kernel are equal or not + double max_diff = 0; + double mean_diff = 0; + reference_abs_diff(block_dQ, block_ref_dQ, max_diff, mean_diff); + bool passed_dQ = (max_diff < 1e-2) && (mean_diff < 1e-3); + if (! passed_dQ) { + std::cerr << "failed dQ: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } + + reference_abs_diff(block_dK, block_ref_dK, max_diff, mean_diff); + bool passed_dK = (max_diff < 1e-2) && (mean_diff < 1e-3); + if (! passed_dK) { + std::cerr << "failed dK: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } + + reference_abs_diff(block_dV, block_ref_dV, max_diff, mean_diff); + bool passed_dV = (max_diff < 1e-2) && (mean_diff < 1e-3); + if (! passed_dV) { + std::cerr << "failed dV: max diff " << max_diff + << " mean " << mean_diff << std::endl; + } + + return passed_dQ && passed_dK && passed_dV; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto [B, H, Q, K, D] = problem_size; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + + auto shape_QO = cute::make_shape(B, H, Q, D); + auto shape_KV = cute::make_shape(B, H, K, D); + auto shape_LSE = cute::make_shape(B, H, Q); + + stride_Q = cute::compact_row_major(shape_QO); + stride_K = cute::compact_row_major(shape_KV); + stride_V = cute::compact_row_major(shape_KV); + stride_O = cute::compact_row_major(shape_QO); + stride_LSE = cute::compact_row_major(shape_LSE); + + stride_dQ = stride_Q; + stride_dK = stride_K; + stride_dV = stride_V; + stride_dO = stride_O; + + block_Q.reset(size(shape_QO)); + block_K.reset(size(shape_KV)); + block_V.reset(size(shape_KV)); + block_O.reset(size(shape_QO)); + block_LSE.reset(size(shape_LSE)); + + block_dQ.reset(size(shape_QO)); + block_dK.reset(size(shape_KV)); + block_dV.reset(size(shape_KV)); + block_dO.reset(size(shape_QO)); + + block_ref_dQ.reset(size(shape_QO)); + block_ref_dK.reset(size(shape_KV)); + block_ref_dV.reset(size(shape_KV)); + + initialize_block(block_Q, seed + 2023, false); + initialize_block(block_K, seed + 2022, false); + initialize_block(block_V, seed + 2021, false); + initialize_block(block_dO, seed + 2020, false); + Tensor mQ = make_tensor(make_gmem_ptr(block_Q.get()), + make_shape(Q, D, make_shape(B, H)), + make_stride(get<2>(stride_Q), get<3>(stride_Q), make_stride(get<0>(stride_Q), get<1>(stride_Q)))); + + Tensor mK = make_tensor(make_gmem_ptr(block_K.get()), + make_shape(K, D, make_shape(B, H)), + make_stride(get<2>(stride_K), get<3>(stride_K), make_stride(get<0>(stride_K), get<1>(stride_K)))); + + Tensor mV = make_tensor(make_gmem_ptr(block_V.get()), + make_shape(K, D, make_shape(B, H)), + make_stride(get<2>(stride_V), get<3>(stride_V), make_stride(get<0>(stride_V), get<1>(stride_V)))); + + Tensor mO = make_tensor(make_gmem_ptr(block_O.get()), + make_shape(Q, D, make_shape(B, H)), + make_stride(get<2>(stride_O), get<3>(stride_O), make_stride(get<0>(stride_O), get<1>(stride_O)))); + + Tensor mLSE = make_tensor(make_gmem_ptr(block_LSE.get()), + make_shape(Q, make_shape(B, H)), + make_stride(get<2>(stride_LSE), make_stride(get<0>(stride_LSE), get<1>(stride_LSE)))); + + fmha_reference(problem_size, mQ, mK, mV, mO, mLSE, ActiveFusion{}); + } + + ExampleResult run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.b, options.h, options.q, options.k, options.d}; + + initialize(problem_size); + + typename Operation::Arguments arguments{ + problem_size, + block_Q.get(), stride_Q, + block_K.get(), stride_K, + block_V.get(), stride_V, + block_O.get(), stride_O, + block_LSE.get(), stride_LSE, + block_dO.get(), stride_dO, + block_dQ.get(), stride_dQ, + block_dK.get(), stride_dK, + block_dV.get(), stride_dV, + hw_info + }; + + Operation op; + + ExampleResult example_result; + + example_result.smem_size = Operation::Operation::Kernel::SharedStorageSize; + + size_t workspace_size = Operation::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "This kernel is not supported. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + status = op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + // Run + cudaMemset(block_dQ.get(), 0, block_dQ.size() * sizeof(Element)); + cudaMemset(block_dK.get(), 0, block_dK.size() * sizeof(Element)); + cudaMemset(block_dV.get(), 0, block_dV.size() * sizeof(Element)); + + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result = cudaEventCreate(&event); + if (result != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + } + + // Record an event at the start of a series of GEMMs + result = cudaEventRecord(events[0]); + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + for (int i = 0; i < options.iterations; i++) { + status = op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return example_result; + } + } + + // + // Stop profiling loop + // + + // Record an event when the GEMMs are complete + result = cudaEventRecord(events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Wait for work on the device to complete. + result = cudaEventSynchronize(events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result) << std::endl; + return example_result; + } + + runtime_ms /= static_cast(options.iterations); + + double flops = 10.0 * (std::is_same_v ? 0.5 : 1.0); + flops *= static_cast(get<0>(problem_size)); + flops *= static_cast(get<1>(problem_size)); + flops *= static_cast(get<2>(problem_size)); + flops *= static_cast(get<3>(problem_size)); + flops *= static_cast(get<4>(problem_size)); + double tflops_s = flops * 1e-12 /*tera*/ / (runtime_ms * 1e-3 /*ms*/); + example_result.tflops_s = tflops_s; + example_result.runtime_ms = runtime_ms; + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return example_result; + } + + // Verify that the result is correct + bool passed = true; + if (options.verify) { + passed = verify(problem_size); + if (passed) example_result.verified = true; + } + + if (!passed) { + std::cerr << "Reference check failed" << std::endl; + return example_result; + } + + example_result.passed = true; + + return example_result; + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to print a description of the example run and its result +void print_result(const std::string& description, ExampleResult result, bool verbose) { + std::ios fmt(nullptr); + fmt.copyfmt(std::cout); + std::cout << (result.passed ? (result.verified ? " [OK] " : " [--] ") : "[FAIL] "); + std::cout << std::setw(32) << std::left << description; + std::cout.copyfmt(fmt); + std::cout << " : " << result.tflops_s << " TFLOPS/s" << std::endl; + if (verbose) { + std::cout << " t=" << result.runtime_ms << "ms, " + "smem=" << result.smem_size << "b" << std::endl; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +using KernelTma = cutlass::gemm::KernelTma; +using KernelCooperative = cutlass::gemm::KernelTmaWarpSpecializedCooperative; +using KernelPingpong = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_fwd_32(Fusion fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { + auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) { + FwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + }; + + using HeadDim = _32; + + run(Shape< _64, _128, HeadDim>{}, KernelTma{}, "tma 64x128x32"); + run(Shape< _128, _64, HeadDim>{}, KernelCooperative{}, "tma ws cooperative 128x64x32"); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_fwd_64(Fusion fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { + auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) { + FwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + }; + + using HeadDim = _64; + + run(Shape< _64, _128, HeadDim>{}, KernelTma{}, "tma 64x128x64"); + run(Shape< _128, _64, HeadDim>{}, KernelCooperative{}, "tma ws cooperative 128x64x64"); + run(Shape< _128, _64, HeadDim>{}, KernelPingpong{}, "tma ws ping-pong 128x64x64"); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_fwd_128(Fusion fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { + auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) { + FwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + }; + + using HeadDim = _128; + + run(Shape<_128, _128, HeadDim>{}, KernelCooperative{}, "tma ws cooperative 128x128x128"); +#ifdef FP8 + run(Shape<_128, _256, HeadDim>{}, KernelCooperative{}, "tma ws cooperative 128x256x128 acc fp16", Option{}); + run(Shape<_128, _256, HeadDim>{}, KernelCooperative{}, "tma ws cooperative 128x256x128 acc fp32"); +#endif + run(Shape<_128, _128, HeadDim>{}, KernelPingpong{}, "tma ws ping-pong 128x128x128"); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_fwd_256(Fusion fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { + auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) { + FwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + }; + + using HeadDim = _256; + +#ifdef FP8 + run(Shape<_128, _128, HeadDim>{}, KernelCooperative{}, "tma ws cooperative 128x128x256"); + run(Shape<_128, _128, HeadDim>{}, KernelPingpong{}, "tma ws ping-pong 128x128x256"); +#else + run(Shape<_128, _64, HeadDim>{}, KernelCooperative{}, "tma ws cooperative 128x64x256"); +#endif +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_bwd_32(Fusion fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { + auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) { + BwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + }; + + using HeadDim = _32; + + run(Shape< _64, _128, HeadDim>{}, KernelCooperative{}, "tma ws cooperative 64x128x32"); + run(Shape<_128, _128, HeadDim>{}, KernelCooperative{}, "tma ws cooperative 128x128x32"); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_bwd_64(Fusion fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { + auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) { + BwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + }; + + using HeadDim = _64; + + run(Shape< _64, _128, HeadDim>{}, KernelCooperative{}, "tma ws cooperative 64x128x64"); + run(Shape<_128, _128, HeadDim>{}, KernelCooperative{}, "tma ws cooperative 128x128x64"); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_bwd_128(Fusion fusion, Options const & options, cutlass::KernelHardwareInfo const& hw_info) { + auto run = [&](auto shape, auto kernel, const char* name, auto... kernel_options) { + BwdRunner runner; + auto result = runner.run(options, hw_info); + print_result(name, result, options.verbose); + }; + + using HeadDim = _128; + + run(Shape<_64, _128, HeadDim>{}, KernelCooperative{}, "tma ws cooperative 64x128x128"); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main_single(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || props.major < 9) { + std::cout + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n"; + return 0; + } + + else if (__CUDACC_VER_MAJOR__ < 12 || (props.major != 9 || props.minor != 0)) { + std::cout + << "This example requires a GPU of NVIDIA's Hopper Architecture " + << "(compute capability 90) and CUDA 12.0 or greater.\n"; + return 0; + } + + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + std::cout << "###### B " << options.b << " H " << options.h << " Q " << options.q << " K " << options.k << " D " << options.d << " "; + std::cout << (options.bwd ? "Backward" : "Forward") << " " << (options.causal ? "Causal" : "Full") << " "; + std::cout << "#SM " << hw_info.sm_count << std::endl; + + auto with_fusion = [&](auto fn) { + if (options.causal) { + fn(CausalFusion{}); + } else if (options.residual){ + fn(ResidualFusion{}); + } else { + fn(DefaultFusion{}); + } + }; + + with_fusion([&](auto fusion) { + if (options.bwd) { +#ifndef FP8 + if (options.d <= 32) { + run_bwd_32(fusion, options, hw_info); + } else if (options.d <= 64) { + run_bwd_64(fusion, options, hw_info); + } else if (options.d <= 128) { + run_bwd_128(fusion, options, hw_info); + } else +#endif + { +#ifdef FP8 + std::cout << "Backward is not implemented for FP8." << std::endl; +#else + std::cout << "No backward kernel instantiated for d=" << options.d << std::endl; +#endif + } + } else { +#ifndef FP8 + if (options.d <= 32) { + run_fwd_32(fusion, options, hw_info); + } else + if (options.d <= 64) { + run_fwd_64(fusion, options, hw_info); + } else +#endif + if (options.d <= 128) { + run_fwd_128(fusion, options, hw_info); + } else + if (options.d <= 256) { + run_fwd_256(fusion, options, hw_info); + } + else { + std::cout << "No forward kernel instantiated for d=" << options.d << std::endl; + } + } + }); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + std::vector full_arguments(args, args + argc); + + int result = 0; + + bool recursed = false; + for (size_t i = 1; i < full_arguments.size(); i++) { + if (full_arguments[i].find(',') != std::string::npos) { + auto arg = full_arguments[i]; + size_t eq_pos = arg.find('='); + std::string prefix = eq_pos == std::string::npos ? "" : arg.substr(0, eq_pos+1); + std::string rest = eq_pos == std::string::npos ? arg : arg.substr(eq_pos+1); + for (;;) { + size_t comma_pos = rest.find(','); + std::string current = rest.substr(0, comma_pos); + full_arguments[i] = prefix + current; + std::vector next_args; + for (auto& elem : full_arguments) { next_args.push_back(elem.data()); } + main(argc, next_args.data()); + if (comma_pos == std::string::npos) break; + rest = rest.substr(comma_pos+1); + } + recursed = true; + break; + } + } + + if (! recursed) { + main_single(argc, args); + } + + return result; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/88_hopper_fmha/CMakeLists.txt b/examples/88_hopper_fmha/CMakeLists.txt new file mode 100644 index 00000000..e70d6788 --- /dev/null +++ b/examples/88_hopper_fmha/CMakeLists.txt @@ -0,0 +1,50 @@ +# Copyright (c) 2014 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 88_hopper_fmha + 88_hopper_fmha.cu + ) + +if(NOT WIN32 AND NOT CUTLASS_CLANG_HOST_COMPILE) + +set_property( + SOURCE 88_hopper_fmha.cu + PROPERTY COMPILE_FLAGS "--use_fast_math" + ) + +cutlass_example_add_executable( + 88_hopper_fmha_fp8 + 88_hopper_fmha.cu + ) + +target_compile_definitions( + 88_hopper_fmha_fp8 + PRIVATE FP8) + +endif() diff --git a/examples/88_hopper_fmha/README.md b/examples/88_hopper_fmha/README.md new file mode 100644 index 00000000..ad6b08f8 --- /dev/null +++ b/examples/88_hopper_fmha/README.md @@ -0,0 +1,77 @@ +# CUTLASS Hopper FMHA Example + +This sample showcases how to implement fused multi-head attention (FMHA) using +CUTLASS for the NVIDIA Hopper architecture. At its heart, the forward pass of +FMHA is a GEMM-online softmax-GEMM fusion, whereas the backward pass is a slightly +more complex structure (basically, a GEMM-softmax-2xGEMM-2xGEMM fusion). +For more information please refer to the [Flash Attention 3 paper](https://arxiv.org/abs/2407.08608). + +The forward pass kernel supports head dims 32, 64, 128, and 256 for fp16 and bf16 input data types, +and head dims 128, and 256 for fp8. +All kernels use the Tensor Memory Accelerator for loads. +Kernels with head dims 128 and 256 have warp-specialized cooperative schedules. + +Backward pass kernels (fp16 only) support head dims 32, 64, and 128, and all support +warp-specialized cooperative schedules. + +## Customization + +### Mask Fusion + +Similar to the [Blackwell FMHA example](../77_blackwell_fmha/README.md), attention masks such as +causal masking can be fused into the kernel. To modify the code for such fusions, +`collective/fmha_fusion.hpp` provides the easiest customization point. +The `before_softmax` function is called with the accumulator of the first GEMM and the logical +positions of those elements. It is well-suited for applying masks or activations. + +### MHA Variants + +Using CuTe, it is easy to represent the various attention variants. +Where regular multi-head attention's layout for the head dimension is (numHeads:headStride), +for single-head attention it is simply (1:0) everywhere, +for GQA it is normal in Q and (numHeads/numGroups,numGroups:headStride,0) in KV, +and for MQA it is normal for Q and (numHeads:0) in KV. +As such, beyond general stride handling, no additional work is needed to support these, +and the example will just demonstrate regular multi-head attention. + +### FP8 + +The warp-specialized forward kernel supports FP8 computation with both FP32 and FP16 +accumulation for the Q*K product. They can be enabled in the runner by defining FP8. + +## Performance +Forward pass kernels can generally come close to that of FA3, but backward pass +kernels are more limited in performance and are not expected to reach the same level of performance +as FA3. + +# Copyright + +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/examples/88_hopper_fmha/collective/fmha_collective_bwd_tma_warpspecialized.hpp b/examples/88_hopper_fmha/collective/fmha_collective_bwd_tma_warpspecialized.hpp new file mode 100644 index 00000000..d9d311a5 --- /dev/null +++ b/examples/88_hopper_fmha/collective/fmha_collective_bwd_tma_warpspecialized.hpp @@ -0,0 +1,863 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "../collective/fmha_common.hpp" +#include "../collective/fmha_collective_load.hpp" +#include "../collective/fmha_collective_softmax.hpp" +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::collective { + +template< + typename Element_, + typename ElementAccumulator_, + typename TileShape_, // BlockQO, BlockKV, BlockHead + class Fusion, + class... Options +> +struct FmhaBwdMainloopTmaWarpSpecialized { + + using Element = Element_; + using ElementAccumulator = ElementAccumulator_; + using TileShape = TileShape_; + + static constexpr bool kIsPersistent = false; + + static const int NumLoadWarpGroups = 1; + static constexpr int NumMmaWarpGroups = 2; + static constexpr int StageCountQ = 2 /*K, V*/ * NumMmaWarpGroups; + static constexpr int StageCount = 2 /*Q, dO*/ * 2 /* actual stages */; + + static const int kOuterLoads = 2; + using StagesQ = cutlass::gemm::collective::StageCount; + using Stages = cutlass::gemm::collective::StageCount; + using ClusterShape = Shape<_1, _1, _1>; + static_assert(StagesQ::value >= 2); + static_assert(Stages::value >= 2 * NumMmaWarpGroups); + + // 16B alignment lets us use TMA + static constexpr int Alignment = 16 / sizeof(Element); + + using TileShapeNM = Shape< // (N,M,D) + decltype(tuple_element_t<1, TileShape>{} / Int{}), + tuple_element_t<0, TileShape>, + tuple_element_t<2, TileShape>>; + + using TileShapeND = decltype(select<0,2,1>(TileShapeNM{})); // (N,D,M) + + using TileShapeMD = decltype(select<2,1,0>(TileShapeND{})); // (M,D,N) + + using CollectiveMmaNM = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, cute::tuple>, Alignment, + Element, cute::tuple>, Alignment, + ElementAccumulator, + TileShapeNM, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using CollectiveMmaND = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, cute::tuple>, Alignment, // from register, doesn't matter + Element, cute::tuple<_1, int, cute::tuple>, Alignment, + ElementAccumulator, + TileShapeND, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using CollectiveMmaND_SS = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, cute::tuple>, Alignment, // from register, doesn't matter + Element, cute::tuple<_1, int, cute::tuple>, Alignment, + ElementAccumulator, + TileShapeND, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + + using CollectiveMmaMD = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, cute::tuple<_1, int, cute::tuple>, Alignment, // from smem, might matter (?) + Element, cute::tuple<_1, int, cute::tuple>, Alignment, + ElementAccumulator, + TileShapeMD, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using TiledMmaNM = typename CollectiveMmaNM::TiledMma; + using TiledMmaND_SS = typename CollectiveMmaND_SS::TiledMma; + using TiledMmaND_RS = decltype(convert_to_gmma_rs(typename CollectiveMmaND::TiledMma{})); + using TiledMmaND = TiledMmaND_RS; + using TiledMmaMD = typename CollectiveMmaMD::TiledMma; + + using SmemLayoutQ = typename CollectiveMmaNM::SmemLayoutB; + using SmemLayoutK = typename CollectiveMmaNM::SmemLayoutA; + using SmemLayoutV = typename CollectiveMmaNM::SmemLayoutA; + using SmemLayoutDO = typename CollectiveMmaNM::SmemLayoutB; + + //using SmemLayoutDQ = Layout< + // Shape< + // tuple_element_t<0, TileShapeMD>, + // Shape<_2, _4, decltype(tuple_element_t<1, TileShapeMD>{} / _8{})>, + // _2 + // >, + // Stride< + // _4, + // Stride{} * _4{}), _1, decltype(tuple_element_t<0, TileShapeMD>{} * _8{})>, + // decltype(tuple_element_t<0, TileShapeMD>{} * tuple_element_t<1, TileShapeMD>{}) + // >>; + + using SmemLayoutDQ_0 = Layout< + Shape< + tuple_element_t<0, TileShapeMD>, + tuple_element_t<1, TileShapeMD>, + _2 + >, + Stride< + tuple_element_t<1, TileShapeMD>, + _1, + decltype(tuple_element_t<0, TileShapeMD>{} * tuple_element_t<1, TileShapeMD>{}) + >>; + + using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + cute::GMMA::Major::K, ElementAccumulator, tuple_element_t<0, TileShapeMD>, tuple_element_t<1, TileShapeMD>>()); + using SmemLayoutDQ_1 = decltype(tile_to_shape(SmemAtomDQ{}, make_shape(get<0>(TileShapeMD{}), get<1>(TileShapeMD{}), _2{}), Step<_2, _1, _3>{})); + using SmemLayoutDQ = SmemLayoutDQ_1; + + + using PipelineDQ = cutlass::PipelineAsync<2>; + + + using SmemLayoutDS_0 = decltype(unstageSmemLayout(typename CollectiveMmaMD::SmemLayoutA{}, Int{})); + + using SmemLayoutDS = decltype(tile_to_shape(GMMA::Layout_MN_INTER_Atom{}, make_shape(size<0>(SmemLayoutDS_0{}), size<1>(SmemLayoutDS_0{}), size<2>(SmemLayoutDS_0{})), Step<_1, _2, _3>{})); + using SmemLayoutKp = typename CollectiveMmaMD::SmemLayoutB; + + using SmemLayoutQp = typename CollectiveMmaND::SmemLayoutB; + using SmemLayoutDOp = typename CollectiveMmaND::SmemLayoutB; + + using SmemLayoutLSE = Layout, Int>>; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using MainloopPipelineQ = cutlass::PipelineTmaAsync; + + using PipelineState = typename cutlass::PipelineState; + using PipelineStateQ = typename cutlass::PipelineState; + + using TileShapePV = TileShapeND; // To work with the kernel level + using TiledMmaPV = TiledMmaND; + + static constexpr int kInnerLoadBytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element) + size(SmemLayoutLSE{}(_,_0{})) * sizeof(ElementAccumulator); + static constexpr int kOuterLoadBytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element); + + struct SharedStorage { + // One for each consumer WG + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_kp; + cute::array_aligned> smem_v; + }; + + cute::array_aligned> smem_ds; + + // Loaded by producer, consumed by both WGs + union { + cute::array_aligned> smem_q; + cute::array_aligned> smem_do; + cute::array_aligned> smem_qp; + cute::array_aligned> smem_dop; + }; + + // Accumulated into by both consumers, potentially loaded, potentially written + cute::array_aligned> smem_dq; + + union { + cute::array_aligned> smem_lse; + cute::array_aligned> smem_sumOdO; + }; + }; + + struct Arguments { + const Element* ptr_Q; + cute::tuple dQ; + const Element* ptr_K; + cute::tuple dK; + const Element* ptr_V; + cute::tuple dV; + + const Element* ptr_dO; + cute::tuple dDO; + + const ElementAccumulator* ptr_LSE; + cute::tuple dLSE; + const ElementAccumulator* ptr_sum_OdO; + cute::tuple dSumOdO; + + ElementAccumulator* ptr_dQ; + cute::tuple dDQ; + }; + + using TMA_Q = typename CollectiveMmaNM::Params::TMA_B; + using TMA_K = typename CollectiveMmaNM::Params::TMA_A; + using TMA_V = typename CollectiveMmaNM::Params::TMA_A; + using TMA_DO = typename CollectiveMmaNM::Params::TMA_B; + + using TMA_LSE = decltype(make_tma_copy(SM90_TMA_LOAD{}, make_tensor((const ElementAccumulator*)nullptr, make_shape(1, 1, 1), make_stride(_1{}, 0, 0)), SmemLayoutLSE{}(_,_0{}))); + using TMA_ODO = TMA_LSE; + + using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor((const ElementAccumulator*)nullptr, make_shape(1, 1, 1, 1), make_stride(0, _1{}, 0, 0)), SmemLayoutDQ{}(_,_,_0{}))); + + using LoadQ = CollectiveLoadTma< + LoadKind::kBwdM, + MainloopPipeline, + Element, + SmemLayoutQ, + TMA_Q + >; + + using LoadK = CollectiveLoadTma< + LoadKind::kBwdN, + MainloopPipelineQ, + Element, + SmemLayoutK, + TMA_K + >; + + using LoadV = CollectiveLoadTma< + LoadKind::kBwdN, + MainloopPipelineQ, + Element, + SmemLayoutV, + TMA_V + >; + + using LoadDO = CollectiveLoadTma< + LoadKind::kBwdM, + MainloopPipeline, + Element, + SmemLayoutDO, + TMA_DO + >; + + using LoadLSE = CollectiveLoadTma< + LoadKind::kBwdScalar, + MainloopPipeline, + ElementAccumulator, + SmemLayoutLSE, + TMA_LSE + >; + + using LoadODO = CollectiveLoadTma< + LoadKind::kBwdScalar, + MainloopPipeline, + ElementAccumulator, + SmemLayoutLSE, + TMA_ODO + >; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_DO tma_load_do; + + TMA_LSE tma_load_lse; + TMA_ODO tma_load_odo; + + TMA_DQ tma_red_dq; + + float scale_softmax; + float scale_softmax_log2; + }; + + static_assert(size(TiledMmaNM{}) == size(TiledMmaND{})); + static_assert(size(TiledMmaNM{}) == size(TiledMmaMD{})); + + template + static bool can_implement(ProblemShape const& problem_size, Arguments const& args) { + return true + && (get<4>(problem_size) <= get<2>(TileShape{})) + && ((get<4>(problem_size) % Alignment) == 0) + && ((get<2>(problem_size) % Alignment) == 0) + ; + } + + template + static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) { + auto problem_shape_nm = make_shape(get<3>(problem_size), get<2>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size))); + + auto dK = make_stride(get<2>(args.dK), get<3>(args.dK), make_stride(get<0>(args.dK), get<1>(args.dK))); + auto dQ = make_stride(get<2>(args.dQ), get<3>(args.dQ), make_stride(get<0>(args.dQ), get<1>(args.dQ))); + auto params_nm_kq = CollectiveMmaNM::to_underlying_arguments(problem_shape_nm, + typename CollectiveMmaNM::Arguments { + args.ptr_K, dK, + args.ptr_Q, dQ, + }, /*workspace=*/ nullptr); + + auto dV = make_stride(get<2>(args.dV), get<3>(args.dV), make_stride(get<0>(args.dV), get<1>(args.dV))); + auto dDO = make_stride(get<2>(args.dDO), get<3>(args.dDO), make_stride(get<0>(args.dDO), get<1>(args.dDO))); + auto params_nm_vdo = CollectiveMmaNM::to_underlying_arguments(problem_shape_nm, + typename CollectiveMmaNM::Arguments { + args.ptr_V, dV, + args.ptr_dO, dDO, + }, /*workspace=*/ nullptr); + + + TMA_LSE tma_load_lse = make_tma_copy(SM90_TMA_LOAD{}, make_tensor(args.ptr_LSE, select<2,0,1>(problem_size), select<2,0,1>(args.dLSE)), SmemLayoutLSE{}(_,_0{})); + TMA_ODO tma_load_odo = make_tma_copy(SM90_TMA_LOAD{}, make_tensor(args.ptr_sum_OdO, select<2,0,1>(problem_size), select<2,0,1>(args.dSumOdO)), SmemLayoutLSE{}(_,_0{})); + + TMA_DQ tma_red_dq = make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor(args.ptr_dQ, select<2,4,0,1>(problem_size), select<2,3,0,1>(args.dDQ)), SmemLayoutDQ{}(_,_,_0{})); + + return Params{ + params_nm_kq.tma_load_b, + params_nm_kq.tma_load_a, + params_nm_vdo.tma_load_a, + params_nm_vdo.tma_load_b, + tma_load_lse, tma_load_odo, + tma_red_dq, + 1.0f / (float) std::sqrt(get<4>(problem_size)), + (float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))) + }; + } + + template + CUTLASS_DEVICE + auto + get_inner_tile_count(BlkCoord const& blk_coord, ProblemSize const& problem_size) { + return Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size); + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_do.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_odo.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_lse.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + load_kv_maybe_q( + int block_rank_in_cluster, + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipeline& pipeline_inner, PipelineState& smem_pipe_write_inner, + MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_write_outer, + SharedStorage& storage, + LoadWarpBarrier& load_warp_barrier, bool do_barrier) + { + // Load pattern: + // K0 V0 K1 V1 + // Q0 DO0 Q1 DO1 Q2 DO2 ... + // K0 Q0 V0 K1 DO0 V1 ... + int lane_predicate = cute::elect_one_sync(); + + int outer_tile_count = NumMmaWarpGroups; + int inner_tile_count = get_inner_tile_count(blk_coord, problem_size); + + auto outer_tile_iter = cute::make_coord_iterator(outer_tile_count); + auto inner_tile_iter = cute::make_coord_iterator(inner_tile_count); + + uint16_t mcast_mask_b = 0; + + LoadQ load_q{params.tma_load_q, pipeline_inner, storage.smem_q}; + auto load_state_q = load_q.init_state(block_rank_in_cluster, problem_size, TileShapeNM{}, blk_coord, inner_tile_count); + + LoadDO load_do{params.tma_load_do, pipeline_inner, storage.smem_do}; + auto load_state_do = load_do.init_state(block_rank_in_cluster, problem_size, TileShapeNM{}, blk_coord, inner_tile_count); + + LoadK load_k{params.tma_load_k, pipeline_outer, storage.smem_k}; + auto load_state_k = load_k.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + LoadV load_v{params.tma_load_v, pipeline_outer, storage.smem_v}; + auto load_state_v = load_v.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + LoadLSE load_lse{params.tma_load_lse, pipeline_inner, storage.smem_lse}; + auto load_state_lse = load_lse.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + LoadODO load_odo{params.tma_load_odo, pipeline_inner, storage.smem_sumOdO}; + auto load_state_odo = load_odo.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + outer_tile_count *= 2; // K & V + inner_tile_count *= 4; // Q & dO & LSE & sumOdO + + while (inner_tile_count > 0) { + if (Fusion{}.is_contributing(make_coord(*inner_tile_iter, get<1>(blk_coord)), TileShape{}, problem_size)) { + break; + } + inner_tile_count -= 4; + ++inner_tile_iter; + } + + if constexpr (kLoadOuter) { + load_k.template step(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count); + } + + load_q.template step(inner_tile_iter, load_state_q, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + load_lse.template step(inner_tile_iter, load_state_lse, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + + if constexpr (! kLoadOuter) { + if (do_barrier) { + load_warp_barrier.arrive(); + load_warp_barrier.wait(/*phase=*/ 0); + do_barrier = false; + } + } + + if constexpr (kLoadOuter) { + load_v.template step(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count); + load_k.template step(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count); + } + + load_do.template step(inner_tile_iter, load_state_do, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + load_odo.template step(inner_tile_iter, load_state_odo, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + + if constexpr (kLoadOuter) { + load_v.template step(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count); + } + + if constexpr (kLoadOuter) { + while (outer_tile_count > 0) { + load_k.template step(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count); + load_v.template step(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count); + } + } + + CUTLASS_PRAGMA_NO_UNROLL + while (inner_tile_count > 0) { + while (inner_tile_count > 0) { + if (Fusion{}.is_contributing(make_coord(*inner_tile_iter, get<1>(blk_coord)), TileShape{}, problem_size)) { + break; + } + inner_tile_count -= 4; + ++inner_tile_iter; + } + load_q.template step(inner_tile_iter, load_state_q, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + load_lse.template step(inner_tile_iter, load_state_lse, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + + load_do.template step(inner_tile_iter, load_state_do, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + load_odo.template step(inner_tile_iter, load_state_odo, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + } + } + + template + CUTLASS_DEVICE void + load_maybe_q( + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_write_outer, + SharedStorage& storage, + LoadWarpBarrier& load_warp_barrier, bool do_barrier) + { + // Load pattern: + // K0 V0 K1 V1 + // Q0 DO0 Q1 DO1 Q2 DO2 ... + // K0 Q0 V0 K1 DO0 V1 ... + int lane_predicate = cute::elect_one_sync(); + + int outer_tile_count = NumMmaWarpGroups; + + auto outer_tile_iter = cute::make_coord_iterator(outer_tile_count); + + LoadK load_k{params.tma_load_k, pipeline_outer, storage.smem_k}; + auto load_state_k = load_k.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + LoadV load_v{params.tma_load_v, pipeline_outer, storage.smem_v}; + auto load_state_v = load_v.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + outer_tile_count *= 2; // K & V + + load_k.template step(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count); + + if (do_barrier) { + load_warp_barrier.arrive(); + load_warp_barrier.wait(/*phase=*/ 0); + do_barrier = false; + } + + load_v.template step(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count); + + while (outer_tile_count > 0) { + load_k.template step(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count); + load_v.template step(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count); + } + } + + template + CUTLASS_DEVICE void + reduce( + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_read_reducer, + SharedStorage& storage) + { + int lane_predicate = cute::elect_one_sync(); + + Tensor mDQ_full = params.tma_red_dq.get_tma_tensor(select<2,4,0,1>(problem_size)); + Tensor gDQ_full = local_tile(mDQ_full, TileShapeMD{}, make_coord(_, _, _), Step<_1, _1, Underscore>{}); + Tensor gDQ = gDQ_full(_, _, _, _0{}, get<2,0>(blk_coord), get<2,1>(blk_coord)); + Tensor sDQ = make_tensor(make_smem_ptr(storage.smem_dq.data()), SmemLayoutDQ{}); + + auto block_tma = params.tma_red_dq.get_slice(_0{}); + + Tensor tDQsDQ = block_tma.partition_S(sDQ); + Tensor tDQgDQ = block_tma.partition_D(gDQ); + + int inner_tile_count = get_inner_tile_count(blk_coord, problem_size); + int g_index = 0; + + auto smem_pipe_release_reducer = smem_pipe_read_reducer; + bool first = true; + while (inner_tile_count > 0) { + while (inner_tile_count > 0) { + if (Fusion{}.is_contributing(make_coord(g_index, get<1>(blk_coord)), TileShape{}, problem_size)) { + break; + } + inner_tile_count -= 1; + ++g_index; + } + if (inner_tile_count == 0) break; + + pipeline_reducer.consumer_wait(smem_pipe_read_reducer); + if (lane_predicate == 1) { + tma_store_wait<1>(); + } + if (! first) { + pipeline_reducer.consumer_release(smem_pipe_release_reducer); + ++smem_pipe_release_reducer; + } else { + first = false; + } + if (lane_predicate == 1) { + copy(params.tma_red_dq, tDQsDQ(_,_,_,smem_pipe_read_reducer.index()), tDQgDQ(_,_,_,g_index)); + tma_store_arrive(); + } + ++smem_pipe_read_reducer; + --inner_tile_count; + ++g_index; + } + if (lane_predicate) { + tma_store_wait<0>(); + } + pipeline_reducer.consumer_release(smem_pipe_release_reducer); + ++smem_pipe_release_reducer; + } + + template + CUTLASS_DEVICE auto + compute( + BlkCoord const& blk_coord, BlkCoord const& wg_coord, + Params const& params, ProblemShape const& problem_size, + MainloopPipeline& pipeline_inner, PipelineState& smem_pipe_read_inner, + MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_read_outer, + MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_write_reducer, + SharedStorage& storage, + MathWgOrderBarrier& math_wg_order_barrier) + { + TiledMmaND tiled_mma_nd; + + Tensor acc_DV = partition_fragment_C(tiled_mma_nd, take<0,2>(TileShapeND{})); + clear(acc_DV); + + Tensor acc_DK = partition_fragment_C(tiled_mma_nd, take<0,2>(TileShapeND{})); + clear(acc_DK); + + int thread_idx = int(threadIdx.x) % cutlass::NumThreadsPerWarpGroup; + + PipelineState smem_pipe_release_inner = smem_pipe_read_inner; + + pipeline_outer.consumer_wait(smem_pipe_read_outer); + PipelineStateQ smem_pipe_read_k = smem_pipe_read_outer; + ++smem_pipe_read_outer; + pipeline_outer.consumer_wait(smem_pipe_read_outer); + PipelineStateQ smem_pipe_read_v = smem_pipe_read_outer; + + int inner_tile_count = get_inner_tile_count(wg_coord, problem_size); + + TiledMmaNM tiled_mma_nm; + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + auto thr_mma_nm = tiled_mma_nm.get_thread_slice(thread_idx); + Tensor tSsK = thr_mma_nm.partition_A(sK); + Tensor tSsQ = thr_mma_nm.partition_B(sQ); + Tensor tSrK = thr_mma_nm.make_fragment_A(tSsK); + Tensor tSrQ = thr_mma_nm.make_fragment_B(tSsQ); + + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + Tensor sDO = make_tensor(make_smem_ptr(storage.smem_do.data()), SmemLayoutDO{}); + + Tensor tDPsV = thr_mma_nm.partition_A(sV); + Tensor tDPsDO = thr_mma_nm.partition_B(sDO); + Tensor tDPrV = thr_mma_nm.make_fragment_A(tDPsV); + Tensor tDPrDO = thr_mma_nm.make_fragment_B(tDPsDO); + + auto thr_mma_nd = tiled_mma_nd.get_thread_slice(thread_idx); + + Tensor sDOp = make_tensor(make_smem_ptr(storage.smem_dop.data()), SmemLayoutDOp{}); + Tensor tDV_sDO = thr_mma_nd.partition_B(sDOp); + Tensor tDVrDO = thr_mma_nd.make_fragment_B(tDV_sDO); + + + Tensor sQp = make_tensor(make_smem_ptr(storage.smem_qp.data()), SmemLayoutQp{}); + Tensor tDK_sQ = thr_mma_nd.partition_B(sQp); + Tensor tDKrQ = thr_mma_nd.make_fragment_B(tDK_sQ); + + + int wg_idx = __shfl_sync(0xffffffff, get<1>(wg_coord) % NumMmaWarpGroups, 0); + + TiledMmaMD tiled_mma_md; + auto thr_mma_md = tiled_mma_md.get_thread_slice(thread_idx); + Tensor sDS = make_tensor(make_smem_ptr(storage.smem_ds.data()), SmemLayoutDS{}); + Tensor tDQsDS = thr_mma_md.partition_A(sDS); + Tensor tDQrDS_full = thr_mma_md.make_fragment_A(tDQsDS); + Tensor tDQrDS = tDQrDS_full(_,_,_,_); + Tensor sKp = make_tensor(make_smem_ptr(storage.smem_kp.data()), SmemLayoutKp{}); + Tensor tDQsK = thr_mma_md.partition_B(sKp); + Tensor tDQrK = thr_mma_md.make_fragment_B(tDQsK); + + Tensor sLSE = make_tensor(make_smem_ptr(storage.smem_lse.data()), make_shape(get<0>(TileShapeNM{}), get<1>(TileShapeNM{}), Int{}), make_stride(_0{}, _1{}, get<1>(TileShapeNM{}))); + Tensor tSsLSE = thr_mma_nm.partition_C(sLSE); + + Tensor sODO = make_tensor(make_smem_ptr(storage.smem_sumOdO.data()), make_shape(get<0>(TileShapeNM{}), get<1>(TileShapeNM{}), Int{}), make_stride(_0{}, _1{}, get<1>(TileShapeNM{}))); + Tensor tDPsODO = thr_mma_nm.partition_C(sODO); + + Tensor cS = make_identity_tensor(take<0,2>(TileShapeNM{})); + Tensor tScS = thr_mma_nm.partition_C(cS); + int n_block = get<1>(wg_coord); + tScS.data() = tScS.data() + E<0>{} * n_block * get<0>(TileShapeNM{}); + + + // Transpose + Tensor sDSp_full = sDS.compose(make_layout(make_shape(size<1>(sDS), size<0>(sDS), size<2>(sDS)), make_stride(size<0>(sDS), _1{}, size<1>(sDS) * size<0>(sDS)))); + Tensor sDSp = sDSp_full(_,_,_); + Tensor tDPsDS = thr_mma_nm.partition_C(sDSp); + + auto thr_mma_nd_ss = TiledMmaND_SS{}.get_thread_slice(thread_idx); + Tensor tDKsDSp = thr_mma_nd_ss.partition_A(sDSp); + + Tensor tDKrDSp = thr_mma_nd_ss.make_fragment_A(tDKsDSp); + + Tensor sDQ = make_tensor(make_smem_ptr(storage.smem_dq.data()), SmemLayoutDQ{}); + auto tDQsDQ_full = thr_mma_md.partition_C(sDQ); + + + auto smem_pipe_read_k_other = smem_pipe_read_k; + smem_pipe_read_k_other.advance(2); + + int k_index = 0; + + while (inner_tile_count > 0) { + while (inner_tile_count > 0) { + if (Fusion{}.is_contributing(make_coord(k_index, get<1>(blk_coord)), TileShape{}, problem_size)) { + break; + } + inner_tile_count -= 1; + tScS.data() = tScS.data() + E<1>{} * get<1>(TileShapeNM{}); + k_index += 1; + } + if (inner_tile_count == 0) break; + + pipeline_inner.consumer_wait(smem_pipe_read_inner); + PipelineState smem_pipe_read_q = smem_pipe_read_inner; + ++smem_pipe_read_inner; + PipelineState smem_pipe_read_do = smem_pipe_read_inner; + ++smem_pipe_read_inner; + + // GEMM KQ -> S + Tensor acc_S = partition_fragment_C(tiled_mma_nm, take<0,2>(TileShapeNM{})); + + warpgroup_fence_operand(acc_S); + warpgroup_arrive(); + gemm_zero_acc(tiled_mma_nm, tSrK(_,_,_,smem_pipe_read_k.index()), tSrQ(_,_,_,smem_pipe_read_q.index()), acc_S); + warpgroup_commit_batch(); + + pipeline_inner.consumer_wait(smem_pipe_read_do); + + // GEMM VdO -> dP + Tensor acc_DP = partition_fragment_C(tiled_mma_nm, take<0,2>(TileShapeNM{})); + + warpgroup_fence_operand(acc_DP); + warpgroup_arrive(); + gemm_zero_acc(tiled_mma_nm, tDPrV(_,_,_,smem_pipe_read_v.index()), tDPrDO(_,_,_,smem_pipe_read_do.index()), acc_DP); + warpgroup_commit_batch(); + + Tensor reg_LSE = make_fragment_like(acc_S); + for (int i = 0; i < size(reg_LSE); i++) { + reg_LSE(i) = ((ElementAccumulator)std::log2(std::exp(1.0))) * tSsLSE(_,_,_,smem_pipe_read_q.index())(i); + } + + Tensor reg_ODO = make_fragment_like(acc_S); + if constexpr (decltype(get<0>(TileShape{}) != _128{})::value) { + for (int i = 0; i < size(reg_ODO); i++) { + reg_ODO(i) = tDPsODO(_,_,_,smem_pipe_read_do.index())(i); + } + } + + warpgroup_wait<1>(); + warpgroup_fence_operand(acc_S); + + math_wg_order_barrier.wait(); + // Compute S -> P + Fusion{}.before_softmax(acc_S, tScS, problem_size); + auto acc_P = make_fragment_like(acc_S); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_P); i++) { + acc_P(i) = ::exp2f(params.scale_softmax_log2 * acc_S(i) - reg_LSE(i)); + } + math_wg_order_barrier.arrive(); + + if constexpr (decltype(get<0>(TileShape{}) == _128{})::value) { + for (int i = 0; i < size(reg_ODO); i++) { + reg_ODO(i) = tDPsODO(_,_,_,smem_pipe_read_do.index())(i); + } + } + + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_DP); + + // Compute dP P -> dS + auto acc_DS = make_fragment_like(acc_DP); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_DS); i++) { + // We could move the scale out and into the respective epilogues (or a final scaling step) + acc_DS(i) = acc_P(i) * params.scale_softmax * (acc_DP(i) - reg_ODO(i)); + } + + // GEMM PdO -> dV + auto op_P = make_acc_into_op(acc_P, typename TiledMmaND::LayoutA_TV{}); + warpgroup_fence_operand(acc_DV); + warpgroup_fence_operand(op_P); + warpgroup_arrive(); + cute::gemm(tiled_mma_nd, op_P, tDVrDO(_,_,_,smem_pipe_read_do.index()), acc_DV); + warpgroup_commit_batch(); + + // Store dS to smem dS' + if (wg_idx == 0) math_wg_order_barrier.wait(); + + auto recast_bits = [](auto sz, auto t) { + return recast>(t); + }; + auto tDPsDS_v = recast_bits(Int * 2>{}, tDPsDS); + auto acc_DS_v = recast_bits(Int * 2>{}, acc_DS); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_DS_v); i++) { + tDPsDS_v(_,_,_,wg_idx)(i) = acc_DS_v(i); + } + + cutlass::arch::fence_view_async_shared(); + if (wg_idx == 0) math_wg_order_barrier.arrive(); + + // GEMM dS Q -> dK + if (wg_idx == 1) { + + math_wg_order_barrier.wait(); + + // GEMM dS' K -> dQ + Tensor acc_DQ = partition_fragment_C(tiled_mma_md, take<0,2>(TileShapeMD{})); + + warpgroup_fence_operand(acc_DQ); + warpgroup_arrive(); + gemm_zero_acc(tiled_mma_md, tDQrDS(_,_,_,0), tDQrK(_,_,_,smem_pipe_read_k_other.index()), acc_DQ); + cute::gemm(tiled_mma_md, tDQrDS(_,_,_,1), tDQrK(_,_,_,smem_pipe_read_k.index()), acc_DQ); + warpgroup_commit_batch(); + + warpgroup_fence_operand(acc_DK); + warpgroup_arrive(); + cute::gemm(TiledMmaND_SS{}, tDKrDSp(_,_,_,wg_idx), tDKrQ(_,_,_,smem_pipe_read_q.index()), acc_DK); + warpgroup_commit_batch(); + + warpgroup_wait<1>(); + warpgroup_fence_operand(acc_DK); + + warpgroup_wait<1>(); + warpgroup_fence_operand(acc_DQ); + + math_wg_order_barrier.arrive(); + + pipeline_reducer.producer_acquire(smem_pipe_write_reducer); + auto tDQsDQ = tDQsDQ_full(_,_,_,smem_pipe_write_reducer.index()); + + // Store dQ to smem dQ' + // Invoke TMA reduce on dQ' + using Vec = uint_bit_t * 2>; + auto tDQsDQ_v = recast(tDQsDQ); + auto acc_DQ_v = recast(acc_DQ); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_DQ_v); i++) { + tDQsDQ_v(i) = acc_DQ_v(i); + } + + cutlass::arch::fence_view_async_shared(); + + pipeline_reducer.producer_commit(smem_pipe_write_reducer); + ++smem_pipe_write_reducer; + } else { + + warpgroup_fence_operand(acc_DK); + warpgroup_arrive(); + cute::gemm(TiledMmaND_SS{}, tDKrDSp(_,_,_,wg_idx), tDKrQ(_,_,_,smem_pipe_read_q.index()), acc_DK); + warpgroup_commit_batch(); + + warpgroup_wait<1>(); + warpgroup_fence_operand(acc_DK); + + pipeline_reducer.producer_acquire(smem_pipe_write_reducer); + pipeline_reducer.producer_commit(smem_pipe_write_reducer); + ++smem_pipe_write_reducer; + } + + --inner_tile_count; + + pipeline_inner.consumer_release(smem_pipe_release_inner); + ++smem_pipe_release_inner; + pipeline_inner.consumer_release(smem_pipe_release_inner); + ++smem_pipe_release_inner; + + tScS.data() = tScS.data() + E<1>{} * get<1>(TileShapeNM{}); + k_index += 1; + } + + pipeline_outer.consumer_release(smem_pipe_read_k); + pipeline_outer.consumer_release(smem_pipe_read_outer); + pipeline_reducer.producer_tail(smem_pipe_write_reducer); + ++smem_pipe_read_outer; + + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_DK); + warpgroup_fence_operand(acc_DV); + + return make_tuple(acc_DK, acc_DV); + } +}; + +} // namespace cutlass::fmha::collective diff --git a/examples/88_hopper_fmha/collective/fmha_collective_load.hpp b/examples/88_hopper_fmha/collective/fmha_collective_load.hpp new file mode 100644 index 00000000..55029faf --- /dev/null +++ b/examples/88_hopper_fmha/collective/fmha_collective_load.hpp @@ -0,0 +1,140 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +enum class LoadKind { + kQ, kK, kV, + kBwdN, kBwdM, kBwdScalar +}; + +template< + LoadKind kKind, + class Pipeline, + class Element, + class SmemLayout, + class TMA +> +struct CollectiveLoadTma { + + using Params = TMA; + using SharedStorage = cute::array_aligned>; + using PipelineState = typename cutlass::PipelineState; + + Params const& params; + Pipeline& pipeline; + SharedStorage& storage; + + CUTLASS_DEVICE + CollectiveLoadTma(Params const& params, Pipeline& pipeline, SharedStorage& storage) + : params(params), pipeline(pipeline), storage(storage) {} + + template + CUTLASS_DEVICE auto init_g(ProblemSize const& problem_size, TileShape const& tile_shape, + BlockCoord const& blk_coord, int loop_count + ) { + using X = Underscore; + if constexpr (kKind == LoadKind::kK) { + Tensor mK_full = params.get_tma_tensor(make_shape(get<3>(problem_size), get<4>(problem_size), select<0,1>(problem_size))); + Tensor gK_full = local_tile(mK_full, tile_shape, make_coord(_, _, _), Step{}); + Tensor gK = gK_full(_, _, _, _0{}, get<2>(blk_coord)); + return gK; + } else if constexpr (kKind == LoadKind::kQ) { + Tensor mQ_full = params.get_tma_tensor(make_shape(get<2>(problem_size), get<4>(problem_size), select<0,1>(problem_size))); + Tensor gQ_full = local_tile(mQ_full, tile_shape, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor gQ = gQ_full(_, _, _, _0{}, get<2>(blk_coord)); + return make_tensor(gQ.data() + loop_count * get<0>(blk_coord) * stride<2>(gQ), gQ.layout()); + } else if constexpr (kKind == LoadKind::kV) { + Tensor mV_full = params.get_tma_tensor(make_shape(get<4>(problem_size), get<3>(problem_size), select<0,1>(problem_size))); + Tensor gV_full = local_tile(mV_full, tile_shape, make_coord(_, _, _), Step{}); + Tensor gV = gV_full(_, _, _0{}, _, get<2>(blk_coord)); + return gV; + } else if constexpr (kKind == LoadKind::kBwdN) { + Tensor m_full = params.get_tma_tensor(make_shape(get<3>(problem_size), get<4>(problem_size), select<0,1>(problem_size))); + Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor g = g_full(_, _, _, _0{}, get<2>(blk_coord)); + return make_tensor(g.data() + loop_count * get<1>(blk_coord) * stride<2>(g), g.layout()); + } else if constexpr (kKind == LoadKind::kBwdM) { + Tensor m_full = params.get_tma_tensor(make_shape(get<2>(problem_size), get<4>(problem_size), select<0,1>(problem_size))); + Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step{}); + Tensor g = g_full(_, _, _, _0{}, get<2>(blk_coord)); + return g; + } else if constexpr (kKind == LoadKind::kBwdScalar) { + Tensor m_full = params.get_tma_tensor(select<2,0,1>(problem_size)); + Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step{}); + Tensor g = g_full(_, _, get<2,0>(blk_coord), get<2,1>(blk_coord)); + return g; + } + } + + template + CUTLASS_DEVICE auto init_state(ClusterRank const& block_rank_in_cluster, + ProblemSize const& problem_size, TileShape const& tile_shape, + BlockCoord const& block_coord, int loop_count + ) { + Tensor g = init_g(problem_size, tile_shape, block_coord, loop_count); + Tensor s = make_tensor(make_smem_ptr(storage.data()), SmemLayout{}); + + auto block_tma = params.get_slice(block_rank_in_cluster); + Tensor ts = block_tma.partition_D(s); + Tensor tg = block_tma.partition_S(g); + + return make_tuple(tg, ts); + } + + template + CUTLASS_DEVICE void step(TileIterator& tile_iter, State const& state, + PipelineState& smem_pipe_write, + int lane_predicate, int& tile_count, uint16_t mcast_mask = 0 + ) { + if ((lane_predicate == 1) && (tile_count > 0)) { + if constexpr (kAcquireBarrier) pipeline.producer_acquire(smem_pipe_write); + using BarrierType = typename Pipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + if constexpr (kKind == LoadKind::kBwdScalar) { + copy(params.with(*tma_barrier, mcast_mask), get<0>(state)(_,_,*tile_iter), get<1>(state)(_,_,smem_pipe_write.index())); + } else { + copy(params.with(*tma_barrier, mcast_mask), get<0>(state)(_,_,_,*tile_iter), get<1>(state)(_,_,_,smem_pipe_write.index())); + } + if constexpr (kAdvancePipe) ++smem_pipe_write; + if constexpr (kAdvanceIterator) ++tile_iter; + } + --tile_count; + } +}; + +} // namespace cutlass::fmha::collective diff --git a/examples/88_hopper_fmha/collective/fmha_collective_softmax.hpp b/examples/88_hopper_fmha/collective/fmha_collective_softmax.hpp new file mode 100644 index 00000000..9c958da0 --- /dev/null +++ b/examples/88_hopper_fmha/collective/fmha_collective_softmax.hpp @@ -0,0 +1,305 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" + +#include "../collective/fmha_common.hpp" + +namespace cutlass::fmha::collective { + +template< + class ElementAccumulator, + class Fusion, + class Params +> +struct CollectiveSoftmax { + Params const& params; + CUTLASS_DEVICE CollectiveSoftmax(Params const& params) : params(params) {} + + using SumType = float; + using MaxType = ElementAccumulator; + + template + CUTLASS_DEVICE auto init(AccPV const& acc_pv, TiledMmaPV const& tiled_mma_pv) { + Tensor s_max = make_fragment_like(size<0>(layout_acc_mn(tiled_mma_pv, acc_pv.layout()))); + Tensor a_sum = make_fragment_like(s_max); + return make_tuple(s_max, a_sum); + } + + CUTLASS_DEVICE float overload_exp2(float f) { + return ::exp2f(f); + } + + CUTLASS_DEVICE cutlass::half_t overload_exp2(cutlass::half_t f) { + auto a = f.raw(); + decltype(a) d; + asm("ex2.approx.f16 %0, %1;" : "=h"(d) : "h"(a)); + return cutlass::half_t::bitcast(d); + } + + + CUTLASS_DEVICE float overload_max(float a, float b) { + return ::max(a, b); + } + + CUTLASS_DEVICE cutlass::half_t overload_max(cutlass::half_t a, cutlass::half_t b) { + return cutlass::half_t{__hmax_nan(a.to_half(), b.to_half())}; + } + + CUTLASS_DEVICE half overload_to_native(cutlass::half_t f) { + return f.to_half(); + } + + CUTLASS_DEVICE float overload_to_native(float f) { + return f; + } + + template + CUTLASS_DEVICE auto step(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, ProblemShape const& problem_shape) { + Fusion{}.before_softmax(acc_qk, count_qk, problem_shape); + Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout())); + auto reduction_target_qk = reduction_target_n(tiled_mma_qk); + constexpr int red_rank = decltype(rank(reduction_target_qk))::value; + + auto& s_max = get<0>(state); + auto& a_sum = get<1>(state); + + // Linear reduction is faster for the first iteration + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max(i) = acc_qk_mn(i, 0); + } + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < size<1>(acc_qk_mn); j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j)); + } + } + + for_each(make_seq{}, [&](auto r) { + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < shape(reduction_target_qk); j *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max(i) = overload_max(s_max(i), MaxType{__shfl_xor_sync(uint32_t(-1), overload_to_native(s_max(i)), stride(reduction_target_qk) * j)}); + } + } + }); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + MaxType local_max = s_max(i) == static_cast(-INFINITY) ? static_cast(0) : s_max(i); + MaxType scale = static_cast(params.scale_softmax_log2); + MaxType scale_max = scale * local_max; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_qk_mn); j++) { + acc_qk_mn(i, j) = overload_exp2(scale * acc_qk_mn(i, j) - scale_max); + } + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + a_sum(i) = SumType{reduce(acc_qk_mn(i, _), cute::plus{})}; + } + } + + template + CUTLASS_DEVICE auto step_interleave_begin(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv, ProblemShape const& problem_shape) { + + if constexpr (kUseFusion) { + Fusion{}.before_softmax(acc_qk, count_qk, problem_shape); + } + + Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout())); + Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout())); + + static_assert(size<0>(acc_qk_mn) == size<0>(acc_pv_mn)); + auto reduction_target_qk = reduction_target_n(tiled_mma_qk); + constexpr int red_rank = decltype(rank(reduction_target_qk))::value; + + auto& s_max = get<0>(state); + auto& a_sum = get<1>(state); + + Tensor s_max_prev = make_fragment_like(s_max); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max_prev(i) = s_max(i); + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + // Linear reduction is faster here, as well + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_qk_mn); j++) { + s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j)); + } + } + // reduce max + for_each(make_seq{}, [&](auto r) { + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < shape(reduction_target_qk); j *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max(i) = overload_max(s_max(i), __shfl_xor_sync(uint32_t(-1), s_max(i), stride(reduction_target_qk) * j)); + } + } + }); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_pv_mn); i++) { + float s_max_cur = s_max(i) == -INFINITY ? 0.0f : s_max(i); + float scale = ::exp2f((s_max_prev(i) - s_max_cur) * params.scale_softmax_log2); + a_sum(i) *= scale; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_pv_mn); j++) { + acc_pv_mn(i, j) *= scale; + } + } + } + + template + CUTLASS_DEVICE auto step_interleave_step(AccQK_MN& acc_qk_mn, State& state) { + + auto& s_max = get<0>(state); + auto& a_sum = get<1>(state); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(acc_qk_mn); j++) { + float local_max = s_max(j) == -INFINITY ? 0.f : s_max(j); + float scale_max = params.scale_softmax_log2 * local_max; + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<1>(acc_qk_mn); k++) { + acc_qk_mn(j, k) = ::exp2f(params.scale_softmax_log2 * acc_qk_mn(j, k) - scale_max); + a_sum(j) += acc_qk_mn(j, k); + } + } + } + + template + CUTLASS_DEVICE auto step(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv, ProblemShape const& problem_shape) { + + if constexpr (kUseFusion) { + Fusion{}.before_softmax(acc_qk, count_qk, problem_shape); + } + + Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout())); + Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout())); + + static_assert(size<0>(acc_qk_mn) == size<0>(acc_pv_mn)); + auto reduction_target_qk = reduction_target_n(tiled_mma_qk); + constexpr int red_rank = decltype(rank(reduction_target_qk))::value; + + auto& s_max = get<0>(state); + auto& a_sum = get<1>(state); + + Tensor s_max_prev = make_fragment_like(s_max); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max_prev(i) = s_max(i); + + // Linear reduction is faster here, as well + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_qk_mn); j++) { + s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j)); + } + // reduce max + for_each(make_seq{}, [&](auto r) { + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < shape(reduction_target_qk); j *= 2) { + s_max(i) = overload_max(s_max(i), MaxType{__shfl_xor_sync(uint32_t(-1), overload_to_native(s_max(i)), stride(reduction_target_qk) * j)}); + } + }); + + MaxType local_max = s_max(i) == static_cast(-INFINITY) ? static_cast(0) : s_max(i); + MaxType scale = static_cast(params.scale_softmax_log2); + MaxType scale_max = scale * local_max; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_qk_mn); j++) { + acc_qk_mn(i, j) = overload_exp2(scale * acc_qk_mn(i, j) - scale_max); + } + + MaxType s_max_cur = s_max(i) == static_cast(-INFINITY) ? static_cast(0) : s_max(i); + SumType scale_pv = overload_exp2((s_max_prev(i) - s_max_cur) * scale); + a_sum(i) *= scale_pv; + + using ElementPV = typename AccPV::value_type; + ElementPV scale_pv_ele = static_cast(scale_pv); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_pv_mn); j++) { + acc_pv_mn(i, j) *= scale_pv_ele; + } + a_sum(i) += SumType{reduce(acc_qk_mn(i, _), cute::plus{})}; + } + } + + + template + CUTLASS_DEVICE auto tail(State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv) { + auto& s_max = get<0>(state); + auto& a_sum = get<1>(state); + + Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout())); + + auto reduction_target = reduction_target_n(tiled_mma_pv); + constexpr int red_rank = decltype(rank(reduction_target))::value; + for_each(make_seq{}, [&](auto r) { + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < shape(reduction_target); j *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_pv_mn); i++) { + a_sum(i) = a_sum(i) + __shfl_xor_sync(uint32_t(-1), a_sum(i), stride(reduction_target) * j); + } + } + }); + + Tensor acc_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout())); + + Tensor lse = make_fragment_like(a_sum); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_mn); i++) { + float sum = a_sum(i); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : __frcp_rn(sum); + lse(i) = (sum == 0.f || sum != sum) ? INFINITY : s_max(i) * params.scale_softmax + __logf(sum); + float scale = params.rp_dropout * inv_sum; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_mn); j++) { + acc_mn(i, j) *= scale; + } + } + + return lse; + } +}; + +} // namespace cutlass::fmha::collective diff --git a/examples/88_hopper_fmha/collective/fmha_collective_tma.hpp b/examples/88_hopper_fmha/collective/fmha_collective_tma.hpp new file mode 100644 index 00000000..c3dfda63 --- /dev/null +++ b/examples/88_hopper_fmha/collective/fmha_collective_tma.hpp @@ -0,0 +1,526 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "../collective/fmha_common.hpp" +#include "../collective/fmha_collective_load.hpp" +#include "../collective/fmha_collective_softmax.hpp" +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; +using cutlass::fmha::kernel::Tag; +using cutlass::fmha::kernel::find_option_t; + +template< + typename Element_, + typename ElementAccumulator_, + typename TileShape_, // BlockQO, BlockKV, BlockHead + class Fusion, + class... Options +> +struct FmhaMainloopTma { + + using Element = Element_; + using ElementAccumulator = ElementAccumulator_; + using TileShape = TileShape_; + + // Options + using kClusterM = find_option_t, Options...>; + static constexpr int StageCount = find_option_t, Options...>::value; + static constexpr int StageCountQ = find_option_t, Options...>::value; + + using StagesQ = cutlass::gemm::collective::StageCount; + using Stages = cutlass::gemm::collective::StageCount; + using ClusterShape = Shape; + + // 16B alignment lets us use TMA + static constexpr int Alignment = 16 / sizeof(Element); + + using TileShapeQK = TileShape; + using TileShapePV = decltype(select<0,2,1>(TileShapeQK{})); + + using LayoutQKV = cute::tuple>; + using LayoutQ = LayoutQKV; + using LayoutK = LayoutQKV; + using LayoutV = LayoutQKV; + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, LayoutQ, Alignment, + Element, LayoutK, Alignment, + ElementAccumulator, + TileShapeQK, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, LayoutK, Alignment, + Element, decltype(select<1,0,2>(LayoutV{})), Alignment, + ElementAccumulator, + TileShapePV, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + using TiledMmaPV = decltype(convert_to_gmma_rs(typename CollectiveMmaPV::TiledMma{})); + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = typename CollectiveMmaQK::SmemLayoutB; + using SmemLayoutV = typename CollectiveMmaPV::SmemLayoutB; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using MainloopPipelineQ = cutlass::PipelineTmaAsync; + + using PipelineState = typename cutlass::PipelineState; + using PipelineStateQ = typename cutlass::PipelineState; + + using TileShapeOut = TileShapePV; + using TiledMmaOut = TiledMmaPV; + using ElementOut = ElementAccumulator; + + struct SharedStorage { + cute::array_aligned> smem_q; + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + }; + + struct Arguments { + const Element* ptr_Q; + LayoutQ dQ; + const Element* ptr_K; + LayoutK dK; + const Element* ptr_V; + LayoutV dV; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + + float scale_softmax; + float scale_softmax_log2; + float rp_dropout; + }; + + using LoadQ = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kQ, + MainloopPipelineQ, + Element, + SmemLayoutQ, + TMA_Q + >; + + using LoadK = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kK, + MainloopPipeline, + Element, + SmemLayoutK, + TMA_K + >; + + using LoadV = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kV, + MainloopPipeline, + Element, + SmemLayoutV, + TMA_V + >; + + static_assert(size(typename CollectiveMmaQK::TiledMma{}) == size(typename CollectiveMmaPV::TiledMma{})); + + static const int MaxThreadsPerBlock = size(typename CollectiveMmaQK::TiledMma{}); + + template + static bool can_implement(ProblemShape const& problem_size, Arguments const& args) { + return true + && (get<4>(problem_size) <= get<2>(TileShape{})) + && ((get<4>(problem_size) % Alignment) == 0) + && ((get<2>(problem_size) % Alignment) == 0) + ; + } + + template + static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) { + + auto problem_shape_qk = make_shape(get<2>(problem_size), get<3>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size))); + auto params_qk = CollectiveMmaQK::to_underlying_arguments(problem_shape_qk, + typename CollectiveMmaQK::Arguments { + args.ptr_Q, args.dQ, + args.ptr_K, args.dK, + }, /*workspace=*/ nullptr); + + auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk); + auto params_pv = CollectiveMmaPV::to_underlying_arguments(problem_shape_pv, + typename CollectiveMmaPV::Arguments { + args.ptr_K, args.dK, // never used, dummy + args.ptr_V, select<1,0,2>(args.dV), + }, /*workspace=*/ nullptr); + + return Params{ + params_qk.tma_load_a, + params_qk.tma_load_b, + params_pv.tma_load_b, + 1.0f / (float) std::sqrt(get<4>(problem_size)), + (float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))), + 1.0f + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE auto + compute( + int block_rank_in_cluster, + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipeline& pipeline, PipelineState& smem_pipe_read, PipelineState& smem_pipe_write, + MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_read_q, PipelineStateQ& smem_pipe_write_q, + SharedStorage& storage) + { + int warp_idx = cutlass::canonical_warp_idx_sync(); + int thread_idx = threadIdx.x; + + PipelineState smem_pipe_release = smem_pipe_read; + [[maybe_unused]] PipelineStateQ smem_pipe_release_q = smem_pipe_read_q; + + + int fusion_tile_count = Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size); + + LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q}; + auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, 1); + + LoadK load_k{params.tma_load_k, pipeline, storage.smem_k}; + auto load_state_k = load_k.init_state(block_rank_in_cluster, problem_size, TileShapeQK{}, blk_coord, fusion_tile_count); + + LoadV load_v{params.tma_load_v, pipeline, storage.smem_v}; + auto load_state_v = load_v.init_state(block_rank_in_cluster, problem_size, TileShapePV{}, blk_coord, fusion_tile_count); + + // Set predicate for the lowest lane_id in the warp + int lane_predicate = cute::elect_one_sync(); + + // Issue TmaLoads (Prologue fetches) + if (warp_idx == 0) { + auto q_tile_iter = cute::make_coord_iterator(1); + int q_tile_count = 1; + load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count); + } + + // Loop over K elems + auto k_tile_iter = cute::make_coord_iterator(fusion_tile_count); + + int k_tile_count_tma = 2 * fusion_tile_count; + + uint16_t mcast_mask_b = 0; + + if (warp_idx == 0 && lane_predicate == 1) { + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,_0{},Int<0>{})); + } + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < StageCount; i++) { + if (i % 2 == 0) { + load_k.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } else { + load_v.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } + } + } + + TiledMmaQK tiled_mma_qk; + auto thr_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx); + + // Mainloop setup QK + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + + Tensor tSsQ = thr_mma_qk.partition_A(sQ); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tSsK = thr_mma_qk.partition_B(sK); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tSrQ = thr_mma_qk.make_fragment_A(tSsQ); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tSrK = thr_mma_qk.make_fragment_B(tSsK); // (MMA,MMA_M,MMA_N,PIPE) + + // Prepare: MMA PV + TiledMmaPV tiled_mma_pv; + auto thr_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx); + + // Mainloop setup PV + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tOsV = thr_mma_pv.partition_B(sV); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tOrV = thr_mma_pv.make_fragment_B(tOsV); // (MMA,MMA_M,MMA_N,PIPE) + + int k_tile_count = Fusion{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_size); + + pipeline_q.consumer_wait(smem_pipe_read_q); + + // mapping into QK accumulator + Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); + Tensor tPcP = thr_mma_qk.partition_C(cP); + int m_block = get<0>(blk_coord); + tPcP.data() = tPcP.data() + E<0>{} * m_block * get<0>(TileShapeQK{}); + + // Allocate PV acc + Tensor acc_pv = partition_fragment_C(tiled_mma_pv, take<0, 2>(TileShapePV{})); + + cutlass::fmha::collective::CollectiveSoftmax softmax{params}; + auto softmax_state = softmax.init(acc_pv, tiled_mma_pv); + + if (true) + { + --k_tile_count; + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + + pipeline.consumer_wait(smem_pipe_read); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + + ++smem_pipe_read; + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + + softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, problem_size); + + Tensor acc_qk_fixed = make_fragment_like(convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{}))); + + Tensor acc_qk_input = make_tensor(acc_qk_fixed.data(), acc_qk.layout()); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + acc_qk_input(i) = static_cast(acc_qk(i)); + } + + pipeline.consumer_wait(smem_pipe_read); + + // MMA PV + warpgroup_fence_operand(acc_pv); + warpgroup_fence_operand(acc_qk_fixed); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv); + warpgroup_commit_batch(); + + // + // Advance the pipe + // + + // Advance consumer pipeline + ++smem_pipe_read; + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + + pipeline.consumer_wait(smem_pipe_read); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + + ++smem_pipe_read; + + if (warp_idx == 0) { + load_k.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + warpgroup_fence_operand(acc_pv); + + softmax.template step_interleave_begin(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size); + + pipeline.consumer_release(smem_pipe_release); + + ++smem_pipe_release; + + pipeline.consumer_wait(smem_pipe_read); + + // MMA PV + auto layout_qk_input = convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{})); + + Tensor acc_qk_input = make_tensor(acc_qk.data(), layout_qk_input); + + static_assert(decltype(size<1>(layout_qk_input) == _1{})::value); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tOrV); i++) { + Tensor acc_qk_element = make_fragment_like(layout_qk_input(_, _0{}, _0{})); + Tensor acc_qk_element_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_element); + Tensor acc_qk_input_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_input(_, _0{}, i)); + softmax.step_interleave_step(acc_qk_input_mk, softmax_state); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(acc_qk_element_mk); j++) { + acc_qk_element_mk(j) = static_cast(acc_qk_input_mk(j)); + } + warpgroup_arrive(); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(tOrV); j++) { + cute::gemm(tiled_mma_pv, acc_qk_element, tOrV(_,j,i,smem_pipe_read.index()), acc_pv(_,_0{},j)); + } + } + warpgroup_commit_batch(); + + // Wait for the pipeline MMAs to drain + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + ++smem_pipe_read; + + if (warp_idx == 0) { + load_v.template step(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } + + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + k_tile_count += Fusion{}.get_masked_trip_count(blk_coord, TileShape{}, problem_size); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + + pipeline.consumer_wait(smem_pipe_read); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + + ++smem_pipe_read; + + if (warp_idx == 0) { + load_k.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + warpgroup_fence_operand(acc_pv); + + softmax.step_interleave_begin(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size); + + pipeline.consumer_release(smem_pipe_release); + + ++smem_pipe_release; + + pipeline.consumer_wait(smem_pipe_read); + + // MMA PV + auto layout_qk_input = convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{})); + + Tensor acc_qk_input = make_tensor(acc_qk.data(), layout_qk_input); + + static_assert(decltype(size<1>(layout_qk_input) == _1{})::value); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tOrV); i++) { + Tensor acc_qk_element = make_fragment_like(layout_qk_input(_, _0{}, _0{})); + Tensor acc_qk_element_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_element); + Tensor acc_qk_input_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_input(_, _0{}, i)); + softmax.step_interleave_step(acc_qk_input_mk, softmax_state); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(acc_qk_element_mk); j++) { + acc_qk_element_mk(j) = static_cast(acc_qk_input_mk(j)); + } + warpgroup_arrive(); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(tOrV); j++) { + cute::gemm(tiled_mma_pv, acc_qk_element, tOrV(_,j,i,smem_pipe_read.index()), acc_pv(_,_0{},j)); + } + } + warpgroup_commit_batch(); + + // Wait for the pipeline MMAs to drain + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + ++smem_pipe_read; + + if (warp_idx == 0) { + load_v.template step(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } + + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_pv); + + Tensor lse = softmax.tail(softmax_state, acc_pv, tiled_mma_pv); + + return make_tuple(acc_pv, lse); + } +}; + +} // namespace cutlass::fmha::collective + diff --git a/examples/88_hopper_fmha/collective/fmha_collective_tma_warpspecialized.hpp b/examples/88_hopper_fmha/collective/fmha_collective_tma_warpspecialized.hpp new file mode 100644 index 00000000..9fd45baf --- /dev/null +++ b/examples/88_hopper_fmha/collective/fmha_collective_tma_warpspecialized.hpp @@ -0,0 +1,560 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "../collective/fmha_common.hpp" +#include "../collective/fmha_collective_load.hpp" +#include "../collective/fmha_collective_softmax.hpp" +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; +using cutlass::fmha::kernel::Tag; +using cutlass::fmha::kernel::find_option_t; + +template< + class Element_, + class ElementAccumulatorQK_, + class ElementAccumulatorPV_, + class TileShape_, // SeqQ, SeqKV, Head + class LayoutQ_, class LayoutK_, class LayoutV_, // SeqX, Head, (Batches) + class Fusion, + class... Options +> +struct FmhaMainloopTmaWarpSpecialized { + + using Element = Element_; + using ElementAccumulatorQK = ElementAccumulatorQK_; + using ElementAccumulatorPV = ElementAccumulatorPV_; + using TileShape = TileShape_; + + using LayoutQ = LayoutQ_; + using LayoutK = LayoutK_; + using LayoutV = LayoutV_; + + // Options + static constexpr bool kIsPersistent = find_option_t::value; + static constexpr bool kIsMainloopLocked = find_option_t::value; + + static constexpr int NumLoadWarpGroups = 1; + static constexpr int NumMmaWarpGroups = find_option_t, Options...>::value; + static constexpr int StageCount = find_option_t, Options...>::value; + static constexpr int StageCountQ = find_option_t, Options...>::value; + + static const int kOuterLoads = 1; + using StagesQ = cutlass::gemm::collective::StageCount; + using Stages = cutlass::gemm::collective::StageCount; + using ClusterShape = Shape<_1, _1, _1>; + static_assert(StagesQ::value >= NumMmaWarpGroups); + static_assert(Stages::value >= 2); + + // 16B alignment lets us use TMA + static constexpr int Alignment = 16 / sizeof(Element); + + using TileShapeQK = Shape< + decltype(tuple_element_t<0, TileShape>{} / Int{}), + tuple_element_t<1, TileShape>, + tuple_element_t<2, TileShape>>; + + using TileShapePV = decltype(select<0,2,1>(TileShapeQK{})); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, LayoutQ, Alignment, + Element, LayoutK, Alignment, + ElementAccumulatorQK, + TileShapeQK, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, LayoutK, Alignment, + Element, decltype(select<1,0,2>(LayoutV{})), Alignment, + ElementAccumulatorPV, + TileShapePV, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + using TiledMmaPV = decltype(convert_to_gmma_rs(typename CollectiveMmaPV::TiledMma{})); + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = typename CollectiveMmaQK::SmemLayoutB; + using SmemLayoutV = typename CollectiveMmaPV::SmemLayoutB; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using MainloopPipelineQ = cutlass::PipelineTmaAsync; + + using PipelineState = typename cutlass::PipelineState; + using PipelineStateQ = typename cutlass::PipelineState; + + static constexpr int kInnerLoadBytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element); + static constexpr int kOuterLoadBytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element); + + using TileShapeOut = TileShapePV; + using TiledMmaOut = TiledMmaPV; + using ElementOut = ElementAccumulatorPV; + + struct SharedStorage { + cute::array_aligned> smem_q; + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + }; + + struct Arguments { + const Element* ptr_Q; + LayoutQ dQ; + const Element* ptr_K; + LayoutK dK; + const Element* ptr_V; + LayoutV dV; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + + float scale_softmax; + float scale_softmax_log2; + float rp_dropout; + }; + + using LoadQ = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kQ, + MainloopPipelineQ, + Element, + SmemLayoutQ, + TMA_Q + >; + + using LoadK = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kK, + MainloopPipeline, + Element, + SmemLayoutK, + TMA_K + >; + + using LoadV = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kV, + MainloopPipeline, + Element, + SmemLayoutV, + TMA_V + >; + + static_assert(size(typename CollectiveMmaQK::TiledMma{}) == size(typename CollectiveMmaPV::TiledMma{})); + + template + static bool can_implement(ProblemShape const& problem_size, Arguments const& args) { + return true + && (get<4>(problem_size) <= get<2>(TileShape{})) + && ((get<4>(problem_size) % Alignment) == 0) + && ((get<2>(problem_size) % Alignment) == 0) + ; + } + + template + static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) { + + auto problem_shape_qk = make_shape(get<2>(problem_size), get<3>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size))); + auto params_qk = CollectiveMmaQK::to_underlying_arguments(problem_shape_qk, + typename CollectiveMmaQK::Arguments { + args.ptr_Q, args.dQ, + args.ptr_K, args.dK, + }, /*workspace=*/ nullptr); + + auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk); + auto params_pv = CollectiveMmaPV::to_underlying_arguments(problem_shape_pv, + typename CollectiveMmaPV::Arguments { + args.ptr_K, args.dK, // never used, dummy + args.ptr_V, select<1,0,2>(args.dV), + }, /*workspace=*/ nullptr); + + return Params{ + params_qk.tma_load_a, + params_qk.tma_load_b, + params_pv.tma_load_b, + 1.0f / (float) std::sqrt(get<4>(problem_size)), + (float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))), + 1.0f + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + load_kv_maybe_q( + int block_rank_in_cluster, + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipeline& pipeline, PipelineState& smem_pipe_write, + MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_write_q, + SharedStorage& storage, + LoadWarpBarrier& load_warp_barrier, bool do_barrier) + { + int fusion_tile_count = Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size); + + int lane_predicate = cute::elect_one_sync(); + + uint16_t mcast_mask_b = 0; + + if (lane_predicate == 1) { + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,_0{},Int<0>{})); + } + } + } + + auto q_tile_iter = cute::make_coord_iterator(Int{}); + [[maybe_unused]] int q_tile_count = NumMmaWarpGroups; + + auto k_tile_iter = cute::make_coord_iterator(fusion_tile_count); + int k_tile_count = 2 * fusion_tile_count; + + LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q}; + auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, NumMmaWarpGroups); + + LoadK load_k{params.tma_load_k, pipeline, storage.smem_k}; + auto load_state_k = load_k.init_state(block_rank_in_cluster, problem_size, TileShapeQK{}, blk_coord, fusion_tile_count); + + LoadV load_v{params.tma_load_v, pipeline, storage.smem_v}; + auto load_state_v = load_v.init_state(block_rank_in_cluster, problem_size, TileShapePV{}, blk_coord, fusion_tile_count); + + if constexpr (kLoadQ) { + load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count); + } + + load_k.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b); + + if constexpr (kLoadQ) { + load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count); + } + + if constexpr (! kLoadQ) { + if (do_barrier) { + load_warp_barrier.arrive(); + load_warp_barrier.wait(/*phase=*/ 0); + do_barrier = false; + } + } + + load_v.template step(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b); + + if constexpr (kLoadQ) { + while (q_tile_count > 0) { + load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count); + } + } + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + load_k.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b); + load_v.template step(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b); + } + } + + template + CUTLASS_DEVICE void + load_maybe_q( + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_write_q, + SharedStorage& storage, + LoadWarpBarrier& load_warp_barrier, bool do_barrier) + { + int lane_predicate = cute::elect_one_sync(); + + LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q}; + auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, NumMmaWarpGroups); + + auto q_tile_iter = cute::make_coord_iterator(Int{}); + + CUTLASS_PRAGMA_UNROLL + for (int q_tile_count = 0; q_tile_count < NumMmaWarpGroups; q_tile_count++) { + int count = 1; + load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, count); + if (q_tile_count == 0 && do_barrier) { + load_warp_barrier.arrive(); + load_warp_barrier.wait(/*phase=*/ 0); + do_barrier = false; + } + } + } + + template + CUTLASS_DEVICE void + reduce( + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_write_reducer, + SharedStorage& storage) + { /* no-op */ } + + template + CUTLASS_DEVICE auto + compute( + BlkCoord const& blk_coord, BlkCoord const& wg_coord, + Params const& params, ProblemShape const& problem_size, + MainloopPipeline& pipeline, PipelineState& smem_pipe_read, + MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_read_q, + MainloopPipelineReducer&, PipelineStateReducer&, + SharedStorage& storage, + MathWgOrderBarrier& math_wg_order_barrier) + { + int thread_idx = int(threadIdx.x); + + PipelineState smem_pipe_release = smem_pipe_read; + PipelineStateQ smem_pipe_release_q = smem_pipe_read_q; + + TiledMmaQK tiled_mma_qk; + auto thr_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx); + + // Mainloop setup QK + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + + Tensor tSsQ = thr_mma_qk.partition_A(sQ); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tSsK = thr_mma_qk.partition_B(sK); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tSrQ = thr_mma_qk.make_fragment_A(tSsQ); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tSrK = thr_mma_qk.make_fragment_B(tSsK); // (MMA,MMA_M,MMA_N,PIPE) + + // Prepare: MMA PV + TiledMmaPV tiled_mma_pv; + auto thr_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx); + + // Mainloop setup PV + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tOsV = thr_mma_pv.partition_B(sV); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tOrV = thr_mma_pv.make_fragment_B(tOsV); // (MMA,MMA_M,MMA_N,PIPE) + + int k_tile_count = Fusion{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_size); + + pipeline_q.consumer_wait(smem_pipe_read_q); + + // mapping into QK accumulator + Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); + Tensor tPcP = thr_mma_qk.partition_C(cP); + int m_block = get<0>(wg_coord); + tPcP.data() = tPcP.data() + E<0>{} * m_block * get<0>(TileShapeQK{}); + + // Allocate PV acc + Tensor acc_pv = partition_fragment_C(tiled_mma_pv, take<0, 2>(TileShapePV{})); + + cutlass::fmha::collective::CollectiveSoftmax softmax{params}; + auto softmax_state = softmax.init(acc_pv, tiled_mma_pv); + + if (true) + { + --k_tile_count; + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + + pipeline.consumer_wait(smem_pipe_read); + math_wg_order_barrier.wait(); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + math_wg_order_barrier.arrive(); + + ++smem_pipe_read; + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + + softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, problem_size); + + Tensor acc_qk_fixed = make_acc_into_op(acc_qk, typename TiledMmaPV::LayoutA_TV{}); + + pipeline.consumer_wait(smem_pipe_read); + + // MMA PV + warpgroup_fence_operand(acc_pv); + warpgroup_fence_operand(acc_qk_fixed); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv); + warpgroup_commit_batch(); + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + // Advance consumer pipeline + ++smem_pipe_read; + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) + { + --k_tile_count; + + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + + pipeline.consumer_wait(smem_pipe_read); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + + ++smem_pipe_read; + auto tok = pipeline.consumer_try_wait(smem_pipe_read); + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + warpgroup_fence_operand(acc_pv); + + if constexpr (kIsMainloopLocked) math_wg_order_barrier.wait(); + softmax.template step(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size); + if constexpr (kIsMainloopLocked) math_wg_order_barrier.arrive(); + + Tensor acc_qk_fixed = make_acc_into_op(acc_qk, typename TiledMmaPV::LayoutA_TV{}); + + pipeline.consumer_wait(smem_pipe_read, tok); + + // MMA PV + warpgroup_fence_operand(acc_pv); + warpgroup_fence_operand(acc_qk_fixed); + warpgroup_arrive(); + + cute::gemm(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv); + warpgroup_commit_batch(); + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + ++smem_pipe_read; + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + k_tile_count += Fusion{}.get_masked_trip_count(blk_coord, TileShape{}, problem_size); + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) + { + --k_tile_count; + + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + + pipeline.consumer_wait(smem_pipe_read); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + + ++smem_pipe_read; + auto tok = pipeline.consumer_try_wait(smem_pipe_read); + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + warpgroup_fence_operand(acc_pv); + + //if constexpr (kIsPersistent) + // if (k_tile_count == 0) pipeline_q.consumer_release(smem_pipe_release_q); + + if constexpr (kIsMainloopLocked) math_wg_order_barrier.wait(); + softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size); + if constexpr (kIsMainloopLocked) math_wg_order_barrier.arrive(); + + Tensor acc_qk_fixed = make_acc_into_op(acc_qk, typename TiledMmaPV::LayoutA_TV{}); + + pipeline.consumer_wait(smem_pipe_read, tok); + + // MMA PV + warpgroup_fence_operand(acc_pv); + warpgroup_fence_operand(acc_qk_fixed); + warpgroup_arrive(); + + cute::gemm(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv); + warpgroup_commit_batch(); + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + ++smem_pipe_read; + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + if (kIsPersistent) pipeline_q.consumer_release(smem_pipe_release_q); + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_pv); + + if (kIsPersistent) pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + Tensor lse = softmax.tail(softmax_state, acc_pv, tiled_mma_pv); + + return make_tuple(acc_pv, lse); + } +}; + +} // namespace cutlass::fmha::collective diff --git a/examples/88_hopper_fmha/collective/fmha_common.hpp b/examples/88_hopper_fmha/collective/fmha_common.hpp new file mode 100644 index 00000000..40c23f48 --- /dev/null +++ b/examples/88_hopper_fmha/collective/fmha_common.hpp @@ -0,0 +1,245 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/kernel_hardware_info.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template +CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + constexpr int rA = decltype(rank(tA))::value; + constexpr int rB = decltype(rank(tB))::value; + constexpr int rC = decltype(rank(tC))::value; + if constexpr (rA == 2 && rB == 2 && rC == 1) { + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<1>(tA); k_block++) { + cute::gemm(atom, tA(_,k_block), tB(_,k_block), tC); + atom.accumulate_ = GMMA::ScaleOut::One; + } + } else { + static_assert(rA == 3 && rB == 3 && rC == 3); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tA); k_block++) { + cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC); + atom.accumulate_ = GMMA::ScaleOut::One; + } + } +} + +template +CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + atom.accumulate_ = GMMA::ScaleOut::Zero; + gemm_reset_zero_acc(atom, tA, tB, tC); +} + +template +CUTE_DEVICE constexpr typename T::value_type reduce(T const& t, Fn fn) { + if constexpr (decltype(size(t) % _2{} == _0{})::value) { + auto partial = make_tensor(size(t) / _2{}); + CUTE_UNROLL + for (int i = 0; i < size(partial); i++) { + partial(i) = fn(t(i), t(i + size(partial))); + } + return reduce(partial, fn); + } else { + auto result = t(_0{}); + CUTE_UNROLL + for (int i = 1; i < size(t); i++) { + result = fn(result, t(i)); + } + return result; + } +} + +struct fmha_max { + CUTE_DEVICE float operator()(float a, float b) { return ::max(a, b); } +}; + +template +inline auto __device__ constexpr layout_separate(Threshold const& thr, + Source const& src, Reference const& ref) { + auto lt = filter(transform_layout(src, ref, [&](auto const& s, auto const& r) { + if constexpr(decltype(r < thr)::value) { + return s; + } else { + return make_layout(_1{}, _0{}); + } + })); + auto ge = filter(transform_layout(src, ref, [&](auto const& s, auto const& r) { + if constexpr(decltype(r >= thr)::value) { + return s; + } else { + return make_layout(_1{}, _0{}); + } + })); + return make_layout(lt, ge); +} + +template +inline auto __device__ constexpr layout_acc_mn(TiledMma const& tiled_mma, Acc const& acc) { + auto separated = layout_separate(get<0>(typename TiledMma::Shape_MNK{}), + get<0>(acc), stride<1>(typename TiledMma::LayoutC_TV{})); + auto V_M = get<0>(separated); + auto V_N = get<1>(separated); + return make_layout(make_layout(V_M, get<1>(acc)), make_layout(V_N, get<2>(acc))); +} + +template +inline auto __device__ constexpr layout_op_mk_v(TiledMma const& tiled_mma, Acc const& acc) { + return layout_separate(get<0>(typename TiledMma::Shape_MNK{}), + get<0>(acc), stride<1>(typename TiledMma::LayoutA_TV{})); +} + +template +inline auto __device__ constexpr tensor_op_mk_v(TiledMma const& tiled_mma, Acc&& acc) { + return make_tensor(acc.data(), layout_op_mk_v(tiled_mma, acc.layout())); +} + +template +inline auto __device__ constexpr reduction_target_n(TiledMma const& tiled_mma) { + auto separated = layout_separate(get<0>(typename TiledMma::Shape_MNK{}), + make_layout(shape<0>(typename TiledMma::LayoutC_TV{})), + stride<0>(typename TiledMma::LayoutC_TV{})); + return get<1>(separated); +} + + +template class Primitive, cute::GMMA::Major tA, cute::GMMA::Major tB, cute::GMMA::ScaleIn sA, cute::GMMA::ScaleIn sB> +inline auto __device__ constexpr convert_to_gmma_rs(cute::MMA_Atom> const& tiled_mma) { + using Atom = cute::MMA_Atom>; + using ElementA = typename Atom::ValTypeA; + using ElementB = typename Atom::ValTypeB; + using ElementC = typename Atom::ValTypeC; + using Shape_MNK = typename Atom::Shape_MNK; + using RS = decltype(cute::GMMA::rs_op_selector()); + return cute::MMA_Atom{}; +} + +template class Primitive, cute::GMMA::ScaleIn sA, cute::GMMA::ScaleIn sB> +inline auto __device__ constexpr convert_to_gmma_rs(cute::MMA_Atom> const& tiled_mma) { + using Atom = cute::MMA_Atom>; + using ElementA = typename Atom::ValTypeA; + using ElementB = typename Atom::ValTypeB; + using ElementC = typename Atom::ValTypeC; + using Shape_MNK = typename Atom::Shape_MNK; + constexpr auto tA = cute::GMMA::Major::K; + constexpr auto tB = cute::GMMA::Major::K; + using RS = decltype(cute::GMMA::rs_op_selector()); + return cute::MMA_Atom{}; +} + +template +CUTE_DEVICE auto constexpr convert_to_gmma_rs(cute::TiledMMA const& tiled_mma) { + return cute::TiledMMA{}; +} + +template +CUTE_DEVICE auto constexpr convert_c_layout_to_a_layout(CLayout const& c, AValueShape const& a) { + return make_layout( + make_shape(a, shape<1>(c), make_shape(shape<2>(c), size<0>(c) / size(a))), + make_stride(stride<0>(c), stride<1>(c), make_stride(stride<2>(c), size<2>(a) * stride<0,2>(c)))); +} + +template +CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, make_layout(stages))); +} + +template +CUTE_DEVICE auto make_acc_into_op(Accumulator const& acc, OperandLayout_TV const& operand_layout_tv) { + Tensor operand = make_fragment_like(convert_c_layout_to_a_layout(acc.layout(), shape<1>(operand_layout_tv))); + Tensor operand_as_acc = make_tensor(operand.data(), acc.layout()); + + cute::copy(acc, operand_as_acc); + + if constexpr (sizeof(Element) == 1) { + + // 00 11 22 33 00 11 22 33 acc layout + // 00 00 11 11 22 22 33 33 operand layout + // BB AA AA BB AA BB BB AA conflict-free exchange pattern + // 16-bit exchange; so process two at a time potentially + int tid = threadIdx.x % 4; + auto values_u32 = recast(operand); + + CUTE_UNROLL + for (int n = 0; n < size<1>(values_u32); n++) { + CUTE_UNROLL + for (int k = 0; k < size<2>(values_u32); k++) { + CUTE_UNROLL + for (int ii = 0; ii < 8; ii += 4) { + + uint32_t values_tmp_0 = values_u32(ii / 2 + 0, n, k); + uint32_t values_tmp_1 = values_u32(ii / 2 + 1, n, k); + + // step A: + // t 1 v 0 -> t 0 v 1 + // t 2 v 0 -> t 1 v 0 + // t 0 v 1 -> t 2 v 0 + // t 3 v 1 -> t 3 v 1 + + int v_to_send = tid == 1 || tid == 2 ? 0 : 1; + int v_to_recv = v_to_send; + int t_to_recv_from = (0x3021 >> (tid * 4)) & 0xF; + + uint32_t values_tmp_a = v_to_send == 0 ? values_tmp_0 : values_tmp_1; + + values_tmp_a = __shfl_sync(0xFFFFFFFF, values_tmp_a, t_to_recv_from, 4); + + // step B: + // t 0 v 0 -> t 0 v 0 + // t 3 v 0 -> t 1 v 1 + // t 1 v 1 -> t 2 v 1 + // t 2 v 1 -> t 3 v 0 + + v_to_send = 1 - v_to_send; + v_to_recv = 1 - v_to_recv; + t_to_recv_from = (0x2130 >> (tid * 4)) & 0xF; + + uint32_t values_tmp_b = v_to_send == 0 ? values_tmp_0 : values_tmp_1; + + values_tmp_b = __shfl_sync(0xFFFFFFFF, values_tmp_b, t_to_recv_from, 4); + + values_u32(ii / 2 + 0, n, k) = __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x1054 : 0x5410); + values_u32(ii / 2 + 1, n, k) = __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x3276 : 0x7632); + } + } + } + } + + return operand; +} + +} // namespace cutlass::fmha::collective diff --git a/examples/88_hopper_fmha/collective/fmha_epilogue.hpp b/examples/88_hopper_fmha/collective/fmha_epilogue.hpp new file mode 100644 index 00000000..246a625f --- /dev/null +++ b/examples/88_hopper_fmha/collective/fmha_epilogue.hpp @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../collective/fmha_common.hpp" + +namespace cutlass::fmha::collective { + +template +struct FmhaFwdEpilogue { + + static constexpr int Alignment = 16 / sizeof(Element); + + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + using CollectiveEpilogueTMA = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_WG, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + void, cute::tuple>, Alignment, + Element, cute::tuple>, Alignment, + cutlass::epilogue::TmaWarpSpecialized, + DefaultOperation + >::CollectiveOp; + + struct Arguments { + Element* ptr_O; + cute::tuple> dO; + + ElementAccumulator* ptr_LSE; + cute::tuple> dLSE; + }; + + struct Params { + ElementAccumulator* ptr_LSE; + cute::tuple> dLSE; + + typename CollectiveEpilogueTMA::Params epilogue_TMA; + }; + + using TensorStorage = typename CollectiveEpilogueTMA::TensorStorage; + using PipelineStorage = typename CollectiveEpilogueTMA::PipelineStorage; + using LoadPipeline = typename CollectiveEpilogueTMA::LoadPipeline; + static constexpr int TmaTransactionBytes = CollectiveEpilogueTMA::TmaTransactionBytes; + + template + static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace = nullptr) { + auto problem_size_o = make_shape(get<2>(problem_size), get<4>(problem_size), 1, + make_shape(get<0>(problem_size), get<1>(problem_size))); + typename CollectiveEpilogueTMA::Arguments args_tma{{}, args.ptr_O, args.dO, args.ptr_O, args.dO}; + return Params{ + args.ptr_LSE, args.dLSE, + CollectiveEpilogueTMA::to_underlying_arguments(problem_size_o, args_tma, workspace) + }; + } + + template + CUTLASS_DEVICE void operator()( + TileShape const& tile_shape, BlkCoord const& blk_coord, + ResultTuple const& result, TiledMma const& tiled_mma, + ProblemShape const& problem_size, Params const& params, + LoadPipeline epi_load_pipeline, + TensorStorage& epi_tensor_storage) + { + using X = Underscore; + + auto acc = get<0>(result); + auto lse = get<1>(result); + + auto thr_mma = tiled_mma.get_thread_slice(threadIdx.x); + + int seqlen_q = get<2>(problem_size); + int num_batch = get<0>(problem_size); + int num_heads = get<1>(problem_size); + // Epilogue for lse + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), + make_shape(seqlen_q, get<1>(tile_shape), make_shape(num_batch, num_heads)), + make_stride(_1{}, _0{}, get<1>(params.dLSE))); + Tensor gLSE_full = local_tile(mLSE, tile_shape, make_coord(_, _, _), Step<_1, _1, X>{}); + Tensor gLSE = gLSE_full(_, _, get<0>(blk_coord), get<1>(blk_coord), get<2>(blk_coord)); + Tensor tOgLSE = thr_mma.partition_C(gLSE); + Tensor cO = make_identity_tensor(take<0,2>(tile_shape)); + Tensor tOcO = thr_mma.partition_C(cO); + if (get<1>(tOcO(_0{})) == 0) { + auto tOgLSE_mn = make_tensor(tOgLSE.data(), layout_acc_mn(tiled_mma, tOgLSE.layout())); + auto tOcO_mn = make_tensor(tOcO.data(), layout_acc_mn(tiled_mma, tOcO.layout())); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tOgLSE_mn); i++) { + if (get<0>(tOcO_mn(i)) + get<0>(blk_coord) * get<0>(tile_shape) < get<2>(problem_size)) { + tOgLSE_mn(i, _0{}) = lse(i); + } + } + } + auto problem_size_o = make_shape(get<2>(problem_size), get<4>(problem_size), _, + make_shape(get<0>(problem_size), get<1>(problem_size))); + + CollectiveEpilogueTMA epilogue_tma(params.epilogue_TMA, epi_tensor_storage); + + using EpiStorePipeline = typename CollectiveEpilogueTMA::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename CollectiveEpilogueTMA::LoadPipelineState epi_load_pipe_consumer_state; + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + epilogue_tma.store( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + problem_size_o, tile_shape, make_coord(get<0>(blk_coord), _0{}, _, get<2>(blk_coord)), + acc, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup, + epi_tensor_storage + ); + + epilogue_tma.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state_next, + epi_store_pipeline, epi_store_pipe_producer_state_next + ); + } +}; + +} // namespace cutlass::fmha::collective diff --git a/examples/88_hopper_fmha/collective/fmha_epilogue_bwd.hpp b/examples/88_hopper_fmha/collective/fmha_epilogue_bwd.hpp new file mode 100644 index 00000000..67772252 --- /dev/null +++ b/examples/88_hopper_fmha/collective/fmha_epilogue_bwd.hpp @@ -0,0 +1,157 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + +#include "../collective/fmha_epilogue.hpp" + +namespace cutlass::fmha::collective { + +template +struct FmhaBwdEpilogueKV { + + static constexpr int Alignment = 16 / sizeof(Element); + + struct Arguments { + Element* ptr_K; + cute::tuple dK; + + Element* ptr_V; + cute::tuple dV; + }; + + //using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using DefaultOperation = cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute, + cutlass::epilogue::fusion::Sm90AccFetch + >; + using CollectiveEpilogueTMA = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_WG, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + void, cute::tuple>, Alignment, + Element, cute::tuple>, Alignment, + cutlass::epilogue::TmaWarpSpecialized, + DefaultOperation + >::CollectiveOp; + + struct Params { + typename CollectiveEpilogueTMA::Params epilogue_K; + typename CollectiveEpilogueTMA::Params epilogue_V; + }; + + + using TensorStorage = typename CollectiveEpilogueTMA::TensorStorage[2]; + using PipelineStorage = typename CollectiveEpilogueTMA::PipelineStorage; + using LoadPipeline = typename CollectiveEpilogueTMA::LoadPipeline; + static constexpr int TmaTransactionBytes = CollectiveEpilogueTMA::TmaTransactionBytes; + + template + static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace = nullptr) { + auto dK = make_stride(get<2>(args.dK), get<3>(args.dK), + make_stride(get<0>(args.dK), get<1>(args.dK))); + auto dV = make_stride(get<2>(args.dV), get<3>(args.dV), + make_stride(get<0>(args.dV), get<1>(args.dV))); + + auto problem_size_kv = make_shape(get<3>(problem_size), get<4>(problem_size), 1, + make_shape(get<0>(problem_size), get<1>(problem_size))); + typename CollectiveEpilogueTMA::Arguments args_k{{}, args.ptr_K, dK, args.ptr_K, dK}; + typename CollectiveEpilogueTMA::Arguments args_v{{}, args.ptr_V, dV, args.ptr_V, dV}; + return Params{ + CollectiveEpilogueTMA::to_underlying_arguments(problem_size_kv, args_k, nullptr), + CollectiveEpilogueTMA::to_underlying_arguments(problem_size_kv, args_v, nullptr) + }; + } + + template + CUTLASS_DEVICE void operator()( + TileShape const& tile_shape, BlkCoord const& blk_coord, + ResultTuple const& result, TiledMma const& tiled_mma, + ProblemShape const& problem_size, Params const& params, + LoadPipeline epi_load_pipeline, TensorStorage& epi_tensor_storage) + { + auto acc_k = get<0>(result); + auto acc_v = get<1>(result); + + auto problem_size_kv = make_shape(get<3>(problem_size), get<4>(problem_size), _, + make_shape(get<0>(problem_size), get<1>(problem_size))); + + using EpiStorePipeline = typename CollectiveEpilogueTMA::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename CollectiveEpilogueTMA::LoadPipelineState epi_load_pipe_consumer_state; + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CollectiveEpilogueTMA epilogue_k{params.epilogue_K, epi_tensor_storage[0]}; + CollectiveEpilogueTMA epilogue_v{params.epilogue_V, epi_tensor_storage[1]}; + + { + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + epilogue_k.store( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + problem_size_kv, tile_shape, make_coord(get<1>(blk_coord), _0{}, _, get<2>(blk_coord)), + acc_k, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup, + epi_tensor_storage[0] + ); + + } + + { + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + epilogue_v.store( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + problem_size_kv, tile_shape, make_coord(get<1>(blk_coord), _0{}, _, get<2>(blk_coord)), + acc_v, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup, + epi_tensor_storage[1] + ); + + epilogue_k.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state_next, + epi_store_pipeline, epi_store_pipe_producer_state_next + ); + + epilogue_v.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state_next, + epi_store_pipeline, epi_store_pipe_producer_state_next + ); + } + } +}; + +} // namespace cutlass::fmha::collective diff --git a/examples/88_hopper_fmha/collective/fmha_fusion.hpp b/examples/88_hopper_fmha/collective/fmha_fusion.hpp new file mode 100644 index 00000000..ac51f250 --- /dev/null +++ b/examples/88_hopper_fmha/collective/fmha_fusion.hpp @@ -0,0 +1,283 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +struct DefaultFusion { + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return ceil_div(get<3>(problem_size), get<1>(tile_shape)); + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return 0; + } + + template + CUTLASS_DEVICE + void before_softmax( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size + + ) { + return; + } +}; + +struct ResidualFusion : DefaultFusion { + + using Base = DefaultFusion; + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return 1; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return get_trip_count(blk_coord, tile_shape, problem_size) - 1; + } + + template + CUTLASS_DEVICE + void before_softmax( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size + ) { + // This is useful is seqlen_k % kBlockN != 0 since it masks + // the remaining elements out from softmax. + // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar + // issues as they are transparently taken care of by TMA and the + // epilogue, if it is instantiated with predication support. + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (get<1>(pos) >= get<3>(problem_size)) { + acc_qk(i) = -INFINITY; + } + } + } +}; + +struct CausalFusion : DefaultFusion { + + using Base = DefaultFusion; + + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + // See note below on different ways to think about causal attention + // Again, we'd add the offset_q into the max_blocks_q calculation + int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size); + int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return ceil_div(get<0>(tile_shape), get<1>(tile_shape)); + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void before_softmax( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size + ) { + // There are two ways to do causal if N_Q != N_K + // (1) is to assume that the Q is at the beginning of the matrix + // - this is what we demonstrate here + // (2) is that it is at the end of the matrix + // - this is usually what we want for inference settings + // where we only compute the next row and use cache for the rest + // - if you'd like this, you only need to add an offset like so: + // get<0>(pos) + offset_q < get<1>(pos) + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (get<0>(pos) < get<1>(pos)) { + acc_qk(i) = -INFINITY; + } + } + } + +}; + +template +struct FusionBwdAdapter { + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return Base{}.get_trip_count(select<1,0,2>(blk_coord), select<1,0,2>(tile_shape), select<0,1,3,2,4>(problem_size)); + } + + template + CUTLASS_DEVICE + void before_softmax( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size + ) { + auto index_base = index_qk(_0{}); + auto index_shape = shape(index_qk); + auto index_stride = transform_leaf(stride(index_qk), [](auto elem) { + if constexpr (is_scaled_basis::value) { + if constexpr(decltype(elem.mode() == _0{})::value) { + return ScaledBasis(elem.value()); + } else { + return ScaledBasis(elem.value()); + } + } else { + return elem; + } + }); + auto index_qk_bwd = make_tensor(make_inttuple_iter(select<1,0>(index_base)), make_layout(index_shape, index_stride)); + Base{}.before_softmax(acc_qk, index_qk_bwd, problem_size); + } + + template + CUTLASS_DEVICE + bool is_contributing( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return true; + } +}; + +template<> +struct FusionBwdAdapter { + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return get<2>(problem_size) / get<0>(TileShape{}); + } + + template + CUTLASS_DEVICE + void before_softmax( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size + + ) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (get<1>(pos) < get<0>(pos)) { + acc_qk(i) = -INFINITY; + } + } + } + + template + CUTLASS_DEVICE + bool is_contributing( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + int max_q = get<0>(blk_coord) * get<0>(tile_shape) + get<0>(tile_shape); + int min_k = get<1>(blk_coord) * get<1>(tile_shape); + return min_k <= max_q; + } +}; + +} // namespace cutlass::fmha::collective diff --git a/examples/88_hopper_fmha/device/device_universal.hpp b/examples/88_hopper_fmha/device/device_universal.hpp new file mode 100644 index 00000000..beadb688 --- /dev/null +++ b/examples/88_hopper_fmha/device/device_universal.hpp @@ -0,0 +1,278 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class Universal { +public: + using Kernel = Kernel_; + + static int const kThreadCount = Kernel::MaxThreadsPerBlock; + + /// Argument structure: User API + using Arguments = typename Kernel::Arguments; + /// Argument structure: Kernel API + using Params = typename Kernel::Params; + +private: + + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (Kernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return Kernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("Universal::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + Kernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + + // Initialize the Params structure + params_ = Kernel::to_underlying_arguments(args, workspace); + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("Universal()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = Kernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::run()"); + dim3 const block = Kernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = Kernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), + cute::size<1>(typename Kernel::ClusterShape{}), + cute::size<2>(typename Kernel::ClusterShape{})); + void const* kernel = (void const*) device_kernel; + void* kernel_params[] = {¶ms}; + launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + else { + launch_result = Status::kSuccess; + cutlass::arch::synclog_setup(); + device_kernel<<>>(params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return run(args, workspace, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/88_hopper_fmha/device/fmha_device_bwd.hpp b/examples/88_hopper_fmha/device/fmha_device_bwd.hpp new file mode 100644 index 00000000..47c43b71 --- /dev/null +++ b/examples/88_hopper_fmha/device/fmha_device_bwd.hpp @@ -0,0 +1,299 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. +*/ + +// common +#include "cutlass/cutlass.h" + +#include "../device/device_universal.hpp" +#include "../collective/fmha_collective_bwd_tma_warpspecialized.hpp" +#include "../collective/fmha_fusion.hpp" +#include "../collective/fmha_epilogue_bwd.hpp" +#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp" +#include "../kernel/fmha_kernel_bwd_convert.hpp" +#include "../kernel/fmha_kernel_tma_warpspecialized.hpp" +#include "../kernel/fmha_tile_scheduler.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class FmhaBwd { +public: + /// Argument structure: User API + struct Arguments { + cute::tuple problem_size; + + const Element* ptr_Q; + cute::tuple stride_Q; + const Element* ptr_K; + cute::tuple stride_K; + const Element* ptr_V; + cute::tuple stride_V; + + const Element* ptr_O; + cute::tuple stride_O; + const ElementAccumulator* ptr_LSE; + cute::tuple stride_LSE; + + const Element* ptr_dO; + cute::tuple stride_dO; + + Element* ptr_dQ; + cute::tuple stride_dQ; + Element* ptr_dK; + cute::tuple stride_dK; + Element* ptr_dV; + cute::tuple stride_dV; + + cutlass::KernelHardwareInfo hw_info; + }; + + using OperationSumOdO = cutlass::device::Universal>; + using OperationConvert = cutlass::device::Universal>; + + using Mainloop = cutlass::fmha::collective::FmhaBwdMainloopTmaWarpSpecialized< + Element, ElementAccumulator, TileShape, + cutlass::fmha::collective::FusionBwdAdapter, Options...>; + + using Epilogue = cutlass::fmha::collective::FmhaBwdEpilogueKV; + + using Operation = cutlass::device::Universal< + cutlass::fmha::kernel::FmhaKernelTmaWarpSpecialized< + Mainloop, + Epilogue, + cutlass::fmha::kernel::TileSchedulerBwdAdapter, Options...>>; + + struct Params { + OperationSumOdO op_sum_OdO; + Operation op; + OperationConvert op_convert; + ElementAccumulator* dQ_acc; + size_t dQ_acc_size; + }; + +private: + Params params_; + + static typename OperationSumOdO::Arguments to_sum_OdO_arguments(Arguments const& args, ElementAccumulator* dest = nullptr) { + auto [B, H, Q, K, D] = args.problem_size; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + auto stride_sum_OdO = make_stride(H*Q, Q, _1{}); + return typename OperationSumOdO::Arguments { + args.problem_size, + args.ptr_O, args.stride_O, + args.ptr_dO, args.stride_dO, + dest, stride_sum_OdO + }; + } + + static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) { + auto [B, H, Q, K, D] = args.problem_size; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + auto stride_src_dQ = make_stride(B == 1 ? 0 : (H*Q*D), Q*D, D, _1{}); + return typename OperationConvert::Arguments { + args.problem_size, + src, stride_src_dQ, + nullptr, stride_src_dQ, + nullptr, stride_src_dQ, + args.ptr_dQ, args.stride_dQ, + nullptr, args.stride_dK, + nullptr, args.stride_dV + }; + } + + static typename Operation::Arguments to_bwd_arguments( + Arguments const& args, + ElementAccumulator* sum_OdO = nullptr, cute::tuple const& stride_sum_OdO = {}, + ElementAccumulator* dQ_acc = nullptr, cute::tuple const& stride_dQ = {} + ) { + return typename Operation::Arguments{ + args.problem_size, + { args.ptr_Q, args.stride_Q, + args.ptr_K, args.stride_K, + args.ptr_V, args.stride_V, + args.ptr_dO, args.stride_dO, + args.ptr_LSE, args.stride_LSE, + sum_OdO, stride_sum_OdO, + dQ_acc, stride_dQ }, + { args.ptr_dK, args.stride_dK, + args.ptr_dV, args.stride_dV }, + args.hw_info + }; + } + +public: + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + Status status = Status::kSuccess; + + status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = OperationConvert::can_implement(to_convert_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = Operation::can_implement(to_bwd_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + auto [B, H, Q, K, D] = args.problem_size; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + size_t workspace_bytes = 0; + // OdO vector + workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + // FP32 versions of outputs that are churned (start off with Q only) + workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator); + return workspace_bytes; + } + + /// Initializes state from arguments. + Status + initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ=" + << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null")); + + auto [B, H, Q, K, D] = args.problem_size; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_sum_OdO); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_dQ); + params_.dQ_acc = dQ_acc; + params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator); + auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO); + auto args_convert = to_convert_arguments(args, dQ_acc); + params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream); + params_.op_convert.initialize(args_convert, nullptr, stream); + auto args_bwd = to_bwd_arguments(args, sum_OdO, args_sum_OdO.stride_sum_OdO, dQ_acc, args_convert.stride_src_dQ); + params_.op.initialize(args_bwd, nullptr, stream); + + return Status::kSuccess; + } + + /// Initializes state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + auto [B, H, Q, K, D] = args.problem_size; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + char* workspace_chr = reinterpret_cast(workspace); + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_chr); + workspace_chr += B*H*Q * sizeof(ElementAccumulator); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_chr); + return initialize_split(args, dQ_acc, sum_OdO, stream); + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()"); + + Status result = Status::kSuccess; + result = params.op_sum_OdO.run(stream); + if (result != Status::kSuccess) { + return result; + } + + auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream); + if (cuda_result != cudaSuccess) { + return Status::kErrorInternal; + } + result = params.op.run(stream); + if (result != Status::kSuccess) { + return result; + } + + result = params.op_convert.run(stream); + if (result != Status::kSuccess) { + return result; + } + + return Status::kSuccess; + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/88_hopper_fmha/kernel/fmha_kernel_builder.hpp b/examples/88_hopper_fmha/kernel/fmha_kernel_builder.hpp new file mode 100644 index 00000000..bd0e0f9d --- /dev/null +++ b/examples/88_hopper_fmha/kernel/fmha_kernel_builder.hpp @@ -0,0 +1,158 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "../collective/fmha_collective_tma.hpp" +#include "../collective/fmha_collective_tma_warpspecialized.hpp" +#include "../collective/fmha_epilogue.hpp" +#include "../kernel/fmha_kernel_tma.hpp" +#include "../kernel/fmha_kernel_tma_warpspecialized.hpp" +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::kernel { + +template< + class Element_, + class ElementAccumulatorQK_, + class ElementAccumulatorPV_, + class TileShape_, // BlockQO, BlockKV, BlockHead + class LayoutQ_, + class LayoutK_, + class LayoutV_, + class Fusion, + class DispatchPolicy, + class... Options +> +struct FmhaBuilder; + +template< + class Element, + class ElementAccumulator, + class TileShape, // BlockQO, BlockKV, BlockHead + class Fusion, + class... Options +> +struct FmhaBuilder< + Element, + ElementAccumulator, + ElementAccumulator, + TileShape, + cute::tuple>, + cute::tuple>, + cute::tuple>, + Fusion, + cutlass::gemm::KernelTma, + Options... +> { + + using CollectiveMainloop = cutlass::fmha::collective::FmhaMainloopTma; + + using CollectiveEpilogue = cutlass::fmha::collective::FmhaFwdEpilogue< + Element, ElementAccumulator, typename CollectiveMainloop::TileShapePV>; + + using Kernel = cutlass::fmha::kernel::FmhaKernelTma; +}; + +template< + class Element, + class ElementAccumulatorQK, + class ElementAccumulatorPV, + class TileShape, // BlockQO, BlockKV, BlockHead + class LayoutQ, + class LayoutK, + class LayoutV, + class Fusion, + class... Options +> +struct FmhaBuilder< + Element, + ElementAccumulatorQK, + ElementAccumulatorPV, + TileShape, + LayoutQ, + LayoutK, + LayoutV, + Fusion, + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + Options... +> { + + using CollectiveMainloop = cutlass::fmha::collective::FmhaMainloopTmaWarpSpecialized< + Element, ElementAccumulatorQK, ElementAccumulatorPV, + TileShape, LayoutQ, LayoutK, LayoutV, + Fusion, Options...>; + + using CollectiveEpilogue = cutlass::fmha::collective::FmhaFwdEpilogue< + Element, ElementAccumulatorPV, typename CollectiveMainloop::TileShapePV>; + + static constexpr bool kIsPersistent = find_option_t::value; + using TileScheduler = std::conditional_t; + + using Kernel = cutlass::fmha::kernel::FmhaKernelTmaWarpSpecialized; +}; + +template< + class Element, + class ElementAccumulatorQK, + class ElementAccumulatorPV, + class TileShape, // BlockQO, BlockKV, BlockHead + class LayoutQ, + class LayoutK, + class LayoutV, + class Fusion, + class... Options +> +struct FmhaBuilder< + Element, + ElementAccumulatorQK, + ElementAccumulatorPV, + TileShape, + LayoutQ, + LayoutK, + LayoutV, + Fusion, + cutlass::gemm::KernelTmaWarpSpecializedPingpong, + Options... +> { + using Kernel = typename FmhaBuilder< + Element, ElementAccumulatorQK, ElementAccumulatorPV, + TileShape, + LayoutQ, LayoutK, LayoutV, + Fusion, + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + Options..., + Option, + Option + >::Kernel; +}; + +} // namespace cutlass::fmha::kernel diff --git a/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_convert.hpp b/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_convert.hpp new file mode 100644 index 00000000..3ef1946b --- /dev/null +++ b/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_convert.hpp @@ -0,0 +1,143 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdConvert { + + struct Arguments { + tuple problem_size; + + const ElementAccumulator* ptr_src_dQ; + tuple stride_src_dQ; + const ElementAccumulator* ptr_src_dK; + tuple stride_src_dK; + const ElementAccumulator* ptr_src_dV; + tuple stride_src_dV; + + Element* ptr_dest_dQ; + tuple stride_dest_dQ; + Element* ptr_dest_dK; + tuple stride_dest_dK; + Element* ptr_dest_dV; + tuple stride_dest_dV; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm90; + + static const int kBlockSeq = 8; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kNumThreadsD = 16; + static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 4; + + static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq; + + static bool can_implement(Arguments const& args) { + return get<4>(args.problem_size) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(size<0>(params.problem_size), size<1>(params.problem_size), ceil_div(std::max(size<2>(params.problem_size), size<3>(params.problem_size)), kBlockSeq)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsSeq, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + template + CUTLASS_DEVICE void copy(Params const& params, const ElementAccumulator* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) { + auto ptr_src_bh = ptr_src + get<0>(stride_src) * blockIdx.x + get<1>(stride_src) * blockIdx.y; + auto ptr_dest_bh = ptr_dest + get<0>(stride_dest) * blockIdx.x + get<1>(stride_dest) * blockIdx.y; + + for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) { + int idx_s = idx_s_t + kBlockSeq * blockIdx.z; + if (idx_s >= count) continue; + auto ptr_src_bhs = ptr_src_bh + idx_s * get<2>(stride_src); + auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<2>(stride_dest); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<4>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) { + ElementAccumulator value_src[kElementsPerLoad]; + Element value_dest[kElementsPerLoad]; + + using VecSrc = uint_bit_t * kElementsPerLoad>; + using VecDest = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_src) = *reinterpret_cast(&ptr_src_bhs[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + value_dest[v] = value_src[v]; + } + + *reinterpret_cast(&ptr_dest_bhs[idx_d]) = *reinterpret_cast(value_dest); + } + } + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + if (params.ptr_src_dQ != nullptr) { + copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<2>(params.problem_size)); + } + if (params.ptr_src_dK != nullptr) { + copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<3>(params.problem_size)); + } + if (params.ptr_src_dV != nullptr) { + copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<3>(params.problem_size)); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp b/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp new file mode 100644 index 00000000..4e276a62 --- /dev/null +++ b/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -0,0 +1,134 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdSumOdO { + + struct Arguments { + cute::tuple problem_size; + + const Element* ptr_O; + cute::tuple stride_O; + const Element* ptr_dO; + cute::tuple stride_dO; + + ElementAccumulator* ptr_sum_OdO; + cute::tuple stride_sum_OdO; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm90; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kBlockQ = 16; + + static const int kNumThreadsD = 8; + static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 2; + + static const int kIterationsQ = kBlockQ / kNumThreadsQ; + + static bool can_implement(Arguments const& args) { + return get<4>(args.problem_size) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(ceil_div(size<2>(params.problem_size), kBlockQ), size<1>(params.problem_size), size<0>(params.problem_size)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsQ, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + auto ptr_O_bh = params.ptr_O + blockIdx.y * get<1>(params.stride_O) + blockIdx.z * get<0>(params.stride_O); + auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<1>(params.stride_dO) + blockIdx.z * get<0>(params.stride_dO); + auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1>(params.stride_sum_OdO) + blockIdx.z * get<0>(params.stride_sum_OdO); + + CUTLASS_PRAGMA_UNROLL + for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) { + int idx_q = idx_q_t + kBlockQ * blockIdx.x; + if (idx_q >= get<2>(params.problem_size)) continue; + ElementAccumulator acc = 0; + auto ptr_O_bhq = ptr_O_bh + idx_q * get<2>(params.stride_O); + auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<2>(params.stride_dO); + auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<2>(params.stride_sum_OdO); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<4>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) { + Element value_O[kElementsPerLoad]; + Element value_dO[kElementsPerLoad]; + + using Vec = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_O) = *reinterpret_cast(&ptr_O_bhq[idx_d]); + *reinterpret_cast(value_dO) = *reinterpret_cast(&ptr_dO_bhq[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + acc += value_O[v] * value_dO[v]; + } + } + + for (int i = 1; i < kNumThreadsD; i *= 2) { + acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD); + } + + if (threadIdx.x == 0) { + *ptr_sum_OdO_bhq = acc; + } + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp b/examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp new file mode 100644 index 00000000..528e83cb --- /dev/null +++ b/examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/arch/arch.h" + +#include "../kernel/fmha_tile_scheduler.hpp" +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::kernel { + +template< + class CollectiveMainloop, + class CollectiveEpilogue, + class... Options +> +struct FmhaKernelTma { + + // Options + static constexpr int kBlocksPerSM = find_option_t, Options...>::value; + + using Element = typename CollectiveMainloop::Element; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + + using TileScheduler = IndividualTileScheduler; + + using StagesQ = typename CollectiveMainloop::StagesQ; + using Stages = typename CollectiveMainloop::Stages; + + using TileShape = typename CollectiveMainloop::TileShape; + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ; + + using SmemLayoutQ = typename CollectiveMainloop::SmemLayoutQ; + using SmemLayoutK = typename CollectiveMainloop::SmemLayoutK; + + struct SharedStorage { + union { + typename CollectiveMainloop::SharedStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + using PipelineStorageQ = typename MainloopPipelineQ::SharedStorage; + alignas(16) PipelineStorage pipeline_storage; + alignas(16) PipelineStorageQ pipeline_storage_q; + + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + alignas(16) EpiLoadPipelineStorage epi_load; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + using ProblemShape = cute::tuple; + + struct Arguments { + ProblemShape problem_size; + typename CollectiveMainloop::Arguments mainloop; + typename CollectiveEpilogue::Arguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_size; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params tile_scheduler; + }; + + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename cutlass::PipelineState; + using PipelineParamsQ = typename MainloopPipelineQ::Params; + using PipelineStateQ = typename cutlass::PipelineState; + + static const int MinBlocksPerMultiprocessor = kBlocksPerSM; + static const int MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; + using ArchTag = cutlass::arch::Sm90; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static bool can_implement(Arguments const& args) { + return CollectiveMainloop::can_implement(args.problem_size, args.mainloop); + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return Params{ + args.problem_size, + CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_size, args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, TileShape{}) + }; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + TileScheduler tile_scheduler{params.tile_scheduler}; + + // Shared memory. + auto& storage = *reinterpret_cast(smem); + + int thread_idx = int(threadIdx.x); + + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup; + int lane_predicate = cute::elect_one_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + + PipelineParamsQ pipeline_params_q; + pipeline_params_q.transaction_bytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element); // Q + pipeline_params_q.role = MainloopPipelineQ::ThreadCategory::ProducerConsumer; + pipeline_params_q.is_leader = warp_group_thread_idx == 0; + pipeline_params_q.num_consumers = cutlass::NumThreadsPerWarpGroup; + + PipelineParams pipeline_params; + pipeline_params.transaction_bytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element); // KV + pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer; + pipeline_params.is_leader = warp_group_thread_idx == 0; + pipeline_params.num_consumers = cutlass::NumThreadsPerWarpGroup; + + MainloopPipelineQ pipeline_q(storage.pipeline_storage_q, pipeline_params_q, Shape<_1, _1, _1>{}); + MainloopPipeline pipeline(storage.pipeline_storage, pipeline_params, ClusterShape{}); + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::ProducerConsumer; + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(storage.epi_load, epi_load_pipeline_params); + + // State variables used for iterating the circular buffer + // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA + // smem_pipe_write is used by the producer of SMEM data - i.e TMA + PipelineState smem_pipe_read; + PipelineState smem_pipe_write = cutlass::make_producer_start_state(); + + PipelineStateQ smem_pipe_read_q; + PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state(); + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer blocks in the Cluster + // and to finish smem init + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } + else { + __syncthreads(); + } + + auto blk_coord = tile_scheduler.get_block_coord(); + + CollectiveMainloop collective_mainloop; + auto result = collective_mainloop.compute( + block_rank_in_cluster, + blk_coord, params.mainloop, params.problem_size, + pipeline, smem_pipe_read, smem_pipe_write, + pipeline_q, smem_pipe_read_q, smem_pipe_write_q, + storage.mainloop + ); + + CollectiveEpilogue epilogue; + epilogue(typename CollectiveMainloop::TileShapePV{}, blk_coord, + result, typename CollectiveMainloop::TiledMmaPV{}, + params.problem_size, params.epilogue, + epi_load_pipeline, storage.epilogue); + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/examples/88_hopper_fmha/kernel/fmha_kernel_tma_warpspecialized.hpp b/examples/88_hopper_fmha/kernel/fmha_kernel_tma_warpspecialized.hpp new file mode 100644 index 00000000..1e760a3e --- /dev/null +++ b/examples/88_hopper_fmha/kernel/fmha_kernel_tma_warpspecialized.hpp @@ -0,0 +1,418 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/arch/arch.h" + +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template< + class CollectiveMainloop, + class CollectiveEpilogue, + class TileScheduler, + class... Options +> +struct FmhaKernelTmaWarpSpecialized { + + // Options + static constexpr bool kIsEpilogueLocked = find_option_t::value; + static constexpr bool kLoadsQSeparately = find_option_t::value; + + + static const int NumLoadWarpGroups = 1; + static constexpr int NumMmaWarpGroups = CollectiveMainloop::NumMmaWarpGroups; + + using TileShape = typename CollectiveMainloop::TileShape; + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using MainloopPipelineOuter = typename CollectiveMainloop::MainloopPipelineQ; + using MainloopPipelineInner = typename CollectiveMainloop::MainloopPipeline; + using MainloopPipelineReducer = cutlass::PipelineAsync<2>; + + static constexpr uint32_t StagesPerMathWarpGroup = 2; + using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier< + StagesPerMathWarpGroup, NumMmaWarpGroups>; + + struct TensorStorageStruct { + typename CollectiveMainloop::SharedStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue[NumMmaWarpGroups]; + }; + union TensorStorageUnion { + typename CollectiveMainloop::SharedStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue[NumMmaWarpGroups]; + }; + using TensorStorage = std::conditional_t; + + struct SharedStorage { + TensorStorage tensors; + + using PipelineStorageInner = typename MainloopPipelineInner::SharedStorage; + using PipelineStorageOuter = typename MainloopPipelineOuter::SharedStorage; + using PipelineStorageReducer = typename MainloopPipelineReducer::SharedStorage; + + alignas(16) PipelineStorageInner pipeline_storage_inner; + alignas(16) PipelineStorageOuter pipeline_storage_outer; + alignas(16) PipelineStorageReducer pipeline_storage_reducer; + + using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage; + alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; + + alignas(16) cutlass::arch::ClusterBarrier load_warp_barrier; + + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + alignas(16) EpiLoadPipelineStorage epi_load; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + using ProblemShape = cute::tuple; + + struct Arguments { + ProblemShape problem_size; + typename CollectiveMainloop::Arguments mainloop; + typename CollectiveEpilogue::Arguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_size; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params tile_scheduler; + }; + + using PipelineParamsInner = typename MainloopPipelineInner::Params; + using PipelineStateInner = typename cutlass::PipelineState; + using PipelineParamsOuter = typename MainloopPipelineOuter::Params; + using PipelineStateOuter = typename cutlass::PipelineState; + using PipelineParamsReducer = typename MainloopPipelineReducer::Params; + using PipelineStateReducer = typename cutlass::PipelineState; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = (NumMmaWarpGroups + NumLoadWarpGroups) * cutlass::NumThreadsPerWarpGroup; + using ArchTag = cutlass::arch::Sm90; + + static constexpr uint32_t LoadRegisterRequirement = 40 - 2 * 8; + static constexpr uint32_t TotalRegisterSupply = (64*1024 / MaxThreadsPerBlock / MinBlocksPerMultiprocessor / 8) * 8 * MaxThreadsPerBlock / cutlass::NumThreadsPerWarpGroup; + static constexpr uint32_t MmaRegisterRequirement = ((TotalRegisterSupply - LoadRegisterRequirement) / NumMmaWarpGroups / 8) * 8; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static bool can_implement(Arguments const& args) { + return CollectiveMainloop::can_implement(args.problem_size, args.mainloop); + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return Params{ + args.problem_size, + CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_size, args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, TileShape{}) + }; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + + enum class WarpGroupRole { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2, + Consumer2 = 3, + Consumer3 = 4, + }; + enum class ProducerWarpRole { + LoadKV = 1, + Reducer = 0, + MaybeLoadQ = 2, // is kLoadsQSeparately is true, this warp loads Q (otherwise warp 0 does it) + MainloopEpilogue = 3, + }; + + static constexpr ProducerWarpRole WarpRoleLoadQ = kLoadsQSeparately ? ProducerWarpRole::MaybeLoadQ : ProducerWarpRole::LoadKV; + + TileScheduler tile_scheduler{params.tile_scheduler}; + + // Shared memory. + auto& storage = *reinterpret_cast(smem); + + int lane_idx = cutlass::canonical_lane_idx(); + int warp_idx = cutlass::canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % cutlass::NumWarpsPerWarpGroup; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + auto warp_group_role = WarpGroupRole(warp_group_idx); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int consumer_warp_group_idx = warp_group_idx - (int) WarpGroupRole::Consumer0; + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + PipelineParamsOuter pipeline_params_outer; + pipeline_params_outer.transaction_bytes = CollectiveMainloop::kOuterLoadBytes; + pipeline_params_outer.is_leader = lane_predicate && (producer_warp_role == WarpRoleLoadQ); + pipeline_params_outer.num_consumers = cutlass::NumThreadsPerWarpGroup; + + PipelineParamsInner pipeline_params_inner; + pipeline_params_inner.transaction_bytes = CollectiveMainloop::kInnerLoadBytes; + pipeline_params_inner.is_leader = lane_predicate && (producer_warp_role == ProducerWarpRole::LoadKV); + pipeline_params_inner.num_consumers = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; + + PipelineParamsReducer pipeline_params_reducer; + pipeline_params_reducer.producer_arv_count = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; + pipeline_params_reducer.consumer_arv_count = cutlass::NumThreadsPerWarp; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::LoadKV) { + pipeline_params_inner.role = MainloopPipelineInner::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == WarpRoleLoadQ) { + pipeline_params_outer.role = MainloopPipelineOuter::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Reducer) { + pipeline_params_reducer.role = MainloopPipelineReducer::ThreadCategory::Consumer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || + warp_group_role == WarpGroupRole::Consumer1 || + warp_group_role == WarpGroupRole::Consumer2 || + warp_group_role == WarpGroupRole::Consumer3 + ) { + pipeline_params_inner.role = MainloopPipelineInner::ThreadCategory::Consumer; + pipeline_params_outer.role = MainloopPipelineOuter::ThreadCategory::Consumer; + pipeline_params_reducer.role = MainloopPipelineReducer::ThreadCategory::Producer; + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + + MainloopPipelineOuter pipeline_outer(storage.pipeline_storage_outer, pipeline_params_outer, Shape<_1, _1, _1>{}); + MainloopPipelineInner pipeline_inner(storage.pipeline_storage_inner, pipeline_params_inner, ClusterShape{}); + MainloopPipelineReducer pipeline_reducer(storage.pipeline_storage_reducer, pipeline_params_reducer); + + // State variables used for iterating the circular buffer + // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA + // smem_pipe_write is used by the producer of SMEM data - i.e TMA + PipelineStateInner smem_pipe_read_inner; + PipelineStateInner smem_pipe_write_inner = cutlass::make_producer_start_state(); + + PipelineStateOuter smem_pipe_read_outer; + PipelineStateOuter smem_pipe_write_outer = cutlass::make_producer_start_state(); + + PipelineStateReducer smem_pipe_read_reducer; + PipelineStateReducer smem_pipe_write_reducer = cutlass::make_producer_start_state(); + + typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; + // DMA Load WG will not participate in these Ordered Barrier syncs + params_math_wg_order_barrier.group_id = consumer_warp_group_idx; + params_math_wg_order_barrier.group_size = cutlass::NumThreadsPerWarpGroup; // Number of threads / participants in a group + MathWarpGroupOrderBarrier math_wg_order_barrier(storage.math_wg_order, params_math_wg_order_barrier); + + // Epilogue Load pipeline + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(storage.epi_load, epi_load_pipeline_params); + + if constexpr (kLoadsQSeparately) { + if ((warp_idx == 0) && lane_predicate) { + storage.load_warp_barrier.init(2 * cutlass::NumThreadsPerWarp); + } + cutlass::arch::fence_barrier_init(); + } + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer blocks in the Cluster + // and to finish smem init + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } + else { + __syncthreads(); + } + + CollectiveMainloop collective_mainloop; + + if (warp_group_role == WarpGroupRole::Producer) { + cutlass::arch::warpgroup_reg_dealloc(); + if (producer_warp_role == ProducerWarpRole::LoadKV) { + bool do_barrier = kLoadsQSeparately; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + collective_mainloop.template load_kv_maybe_q( + block_rank_in_cluster, + blk_coord, params.mainloop, params.problem_size, + pipeline_inner, smem_pipe_write_inner, + pipeline_outer, smem_pipe_write_outer, + storage.tensors.mainloop, + storage.load_warp_barrier, do_barrier + ); + do_barrier = false; + } + } + else if (kLoadsQSeparately && (producer_warp_role == ProducerWarpRole::MaybeLoadQ)) { + bool do_barrier = true; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + collective_mainloop.load_maybe_q( + blk_coord, params.mainloop, params.problem_size, + pipeline_outer, smem_pipe_write_outer, + storage.tensors.mainloop, + storage.load_warp_barrier, do_barrier + ); + do_barrier = false; + } + } else if (producer_warp_role == ProducerWarpRole::Reducer) { + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + collective_mainloop.reduce( + blk_coord, params.mainloop, params.problem_size, + pipeline_reducer, smem_pipe_read_reducer, + storage.tensors.mainloop + ); + } + } + } + else if ( + warp_group_role == WarpGroupRole::Consumer0 || + warp_group_role == WarpGroupRole::Consumer1 || + warp_group_role == WarpGroupRole::Consumer2 || + warp_group_role == WarpGroupRole::Consumer3 + ) { + cutlass::arch::warpgroup_reg_alloc(); + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto wg_coord = blk_coord; + + constexpr int kOuterLoads = CollectiveMainloop::kOuterLoads; + + if (warp_group_role == WarpGroupRole::Consumer0) { + smem_pipe_read_outer.advance(0 * kOuterLoads); + } + else if (warp_group_role == WarpGroupRole::Consumer1) { + smem_pipe_read_outer.advance(1 * kOuterLoads); + } + else if (warp_group_role == WarpGroupRole::Consumer2) { + smem_pipe_read_outer.advance(2 * kOuterLoads); + } + else if (warp_group_role == WarpGroupRole::Consumer3) { + smem_pipe_read_outer.advance(3 * kOuterLoads); + } + + constexpr int wg_dim = is_constant<0, decltype(get<1>(wg_coord))>::value ? 0 : 1; + auto& wg_block = get(wg_coord); + if (warp_group_role == WarpGroupRole::Consumer0) { + wg_block = NumMmaWarpGroups * wg_block + 0; + } + else if (warp_group_role == WarpGroupRole::Consumer1) { + wg_block = NumMmaWarpGroups * wg_block + 1; + } + else if (warp_group_role == WarpGroupRole::Consumer2) { + wg_block = NumMmaWarpGroups * wg_block + 2; + } + else if (warp_group_role == WarpGroupRole::Consumer3) { + wg_block = NumMmaWarpGroups * wg_block + 3; + } + + auto result = collective_mainloop.compute( + blk_coord, wg_coord, + params.mainloop, params.problem_size, + pipeline_inner, smem_pipe_read_inner, + pipeline_outer, smem_pipe_read_outer, + pipeline_reducer, smem_pipe_write_reducer, + storage.tensors.mainloop, + math_wg_order_barrier + ); + + if (warp_group_role == WarpGroupRole::Consumer0) { + smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 0)); + } + if constexpr (NumMmaWarpGroups >= 2) { + if (warp_group_role == WarpGroupRole::Consumer1) { + smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 1)); + } + } + if constexpr (NumMmaWarpGroups >= 3) { + if (warp_group_role == WarpGroupRole::Consumer2) { + smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 2)); + } + } + if constexpr (NumMmaWarpGroups >= 4) { + if (warp_group_role == WarpGroupRole::Consumer3) { + smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 3)); + } + } + + if constexpr (kIsEpilogueLocked) ; math_wg_order_barrier.wait(); + + CollectiveEpilogue epilogue; + epilogue(typename CollectiveMainloop::TileShapePV{}, wg_coord, + result, typename CollectiveMainloop::TiledMmaPV{}, + params.problem_size, params.epilogue, + epi_load_pipeline, storage.tensors.epilogue[consumer_warp_group_idx]); + + if constexpr (kIsEpilogueLocked) ; math_wg_order_barrier.arrive(); + } + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/include/cute/tensor_predicate.hpp b/examples/88_hopper_fmha/kernel/fmha_options.hpp similarity index 63% rename from include/cute/tensor_predicate.hpp rename to examples/88_hopper_fmha/kernel/fmha_options.hpp index d39f6ada..6746bf79 100644 --- a/include/cute/tensor_predicate.hpp +++ b/examples/88_hopper_fmha/kernel/fmha_options.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -28,51 +28,56 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ + #pragma once -#include // CUTE_HOST_DEVICE -#include // cute::true_type +#include "cutlass/cutlass.h" -namespace cute -{ +namespace cutlass::fmha::kernel { -template -struct ConstantTensor -{ - template - CUTE_HOST_DEVICE constexpr - T const& - operator()(Coords const&...) const { - return val_; - } +template +struct find_option; - T val_; +template +struct find_option { + using option_value = Default; }; -struct TrivialPredTensor -{ - template - CUTE_HOST_DEVICE constexpr - true_type - operator()(Coords const&...) const { - return {}; - } +template +struct find_option : + std::conditional_t< + Option::tag == kTag, + Option, + find_option + > +{}; + +template +using find_option_t = typename find_option::option_value; + +enum class Tag { + kIsPersistent, + kNumMmaWarpGroups, + kLoadsQSeparately, + + kIsMainloopLocked, + kIsEpilogueLocked, + + kStagesQ, + kStagesKV, + + kEpilogueKind, + + kBlocksPerSM, + kClusterM, + + kAccQK }; -template -struct FunctionPredTensor -{ - CUTE_HOST_DEVICE constexpr - FunctionPredTensor(Fn const& fn) : fn_(fn) {} - - template - CUTE_HOST_DEVICE constexpr - auto - operator()(Coords const&... coords) const { - return fn_(coords...); - } - - Fn const& fn_; +template +struct Option { + static constexpr auto tag = kTag; + using option_value = Value; }; -} // end namespace cute +} // namespace cutlass::fmha::kernel diff --git a/examples/88_hopper_fmha/kernel/fmha_tile_scheduler.hpp b/examples/88_hopper_fmha/kernel/fmha_tile_scheduler.hpp new file mode 100644 index 00000000..0b46e27e --- /dev/null +++ b/examples/88_hopper_fmha/kernel/fmha_tile_scheduler.hpp @@ -0,0 +1,204 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct IndividualTileScheduler { + + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + IndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) + { + using namespace cute; + dim3 grid(round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<0>(problem_size), size<1>(problem_size)); + return Params{ grid }; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z)); + } + + CUTLASS_DEVICE + IndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct PersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + FastDivmod divmod_h; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) + { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = cutlass::round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)); + int num_blocks = num_m_blocks * size<0>(problem_size) * size<1>(problem_size); + + return Params { + num_blocks, + { num_m_blocks}, { size<0>(problem_size) }, { size<1>(problem_size) }, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, bidh; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + params.divmod_h(block_decode, bidh, block_decode); + return make_coord(m_block, _0{}, make_coord(bidb, bidh)); + } + + CUTLASS_DEVICE + PersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +template +struct TileSchedulerBwdAdapter { + + using Params = typename Base::Params; + + Base base_; + + CUTLASS_DEVICE + TileSchedulerBwdAdapter(Params const& params) : base_(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) + { + using namespace cute; + return Base::to_underlying_arguments(select<0,1,3,2,4>(problem_size), hw_info, select<1,0,2>(cluster_shape), select<1,0,2>(tile_shape)); + } + + static dim3 get_grid_shape(Params const& params) { + return Base::get_grid_shape(params); + } + + CUTLASS_DEVICE + bool is_valid() { + return base_.is_valid(); + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return select<1,0,2>(base_.get_block_coord()); + } + + CUTLASS_DEVICE + TileSchedulerBwdAdapter& operator++() { + ++base_; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/examples/88_hopper_fmha/reference/fmha_bwd_reference.hpp b/examples/88_hopper_fmha/reference/fmha_bwd_reference.hpp new file mode 100644 index 00000000..503fb3aa --- /dev/null +++ b/examples/88_hopper_fmha/reference/fmha_bwd_reference.hpp @@ -0,0 +1,357 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + class TensorDQ, /* class TensorDK, class TensorDV, */ + class Fusion +> +void __global__ fmha_bwd_reference_dQ_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + TensorDQ mDQ, /* TensorDK mDK, TensorDV mDV, */ + Fusion fusion +) { + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + for (int idx_L = blockIdx.y; idx_L < size<2>(mDQ); idx_L += gridDim.y) { + for (int idx_Q = blockIdx.x; idx_Q < size<0>(mDQ); idx_Q += gridDim.x) { + + for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) { + ElementAccumulator acc_qk = 0; + ElementAccumulator acc_dov = 0; + ElementAccumulator acc_doo = 0; + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); + acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); + acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); + } + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_K] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); + } + + __syncthreads(); + + for (int idx_D = threadIdx.x; idx_D < size<1>(mDQ); idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) { + acc += mS[idx_K] * mK(idx_K, idx_D, idx_L); + } + mDQ(idx_Q, idx_D, idx_L) = acc; + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /* class TensorDQ, */ class TensorDK, /* class TensorDV, */ + class Fusion +> +void __global__ fmha_bwd_reference_dK_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /* TensorDQ mDQ, */ TensorDK mDK, /* TensorDV mDV, */ + Fusion fusion +) { + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + for (int idx_L = blockIdx.y; idx_L < size<2>(mDK); idx_L += gridDim.y) { + for (int idx_K = blockIdx.x; idx_K < size<0>(mDK); idx_K += gridDim.x) { + + for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) { + ElementAccumulator acc_qk = 0; + ElementAccumulator acc_dov = 0; + ElementAccumulator acc_doo = 0; + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); + acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); + acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); + } + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_Q] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); + } + + __syncthreads(); + + for (int idx_D = threadIdx.x; idx_D < size<1>(mDK); idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) { + acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L); + } + mDK(idx_K, idx_D, idx_L) = acc; + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /* class TensorDQ, class TensorDK, */ class TensorDV, + class Fusion +> +void __global__ fmha_bwd_reference_dV_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /* TensorDQ mDQ, TensorDK mDK, */ TensorDV mDV, + Fusion fusion +) { + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + for (int idx_L = blockIdx.y; idx_L < size<2>(mDV); idx_L += gridDim.y) { + for (int idx_K = blockIdx.x; idx_K < size<0>(mDV); idx_K += gridDim.x) { + + for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) { + ElementAccumulator acc_qk = 0; + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); + } + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_Q] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L))); + } + + __syncthreads(); + + for (int idx_D = threadIdx.x; idx_D < size<1>(mDV); idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) { + acc += mS[idx_Q] * mDO(idx_Q, idx_D, idx_L); + } + mDV(idx_K, idx_D, idx_L) = acc; + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /**/ class TensorDQ, /** / class TensorDK, / ** / class TensorDV, / **/ + class Fusion +> +void fmha_bwd_reference_dQ( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /**/ TensorDQ mDQ, /** / TensorDK mDK, / ** / TensorDV mDV, / **/ + Fusion fusion +) { + using namespace cute; + + dim3 grid(size<0>(mDQ), size<2>(mDQ), 1); + dim3 block(256); + int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type); + + if (shared_mem >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem); + auto result = cudaFuncSetAttribute( + fmha_bwd_reference_dQ_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return; + } + } + + fmha_bwd_reference_dQ_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /** / class TensorDQ, / **/ class TensorDK, /** / class TensorDV, / **/ + class Fusion +> +void fmha_bwd_reference_dK( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /** / TensorDQ mDQ, / **/ TensorDK mDK, /** / TensorDV mDV, / **/ + Fusion fusion +) { + using namespace cute; + + dim3 grid(size<0>(mDK), size<2>(mDK), 1); + dim3 block(256); + int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type); + + if (shared_mem >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem); + auto result = cudaFuncSetAttribute( + fmha_bwd_reference_dK_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return; + } + } + + fmha_bwd_reference_dK_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /** / class TensorDQ, / ** / class TensorDK, / **/ class TensorDV, /**/ + class Fusion +> +void fmha_bwd_reference_dV( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /** / TensorDQ mDQ, / ** / TensorDK mDK, / **/ TensorDV mDV, /**/ + Fusion fusion +) { + using namespace cute; + + dim3 grid(size<0>(mDV), size<2>(mDV), 1); + dim3 block(256); + int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type); + + if (shared_mem >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem); + auto result = cudaFuncSetAttribute( + fmha_bwd_reference_dV_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return; + } + } + + fmha_bwd_reference_dV_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + class TensorDQ, class TensorDK, class TensorDV, + class Fusion +> +void fmha_bwd_reference( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + TensorDQ mDQ, TensorDK mDK, TensorDV mDV, + Fusion fusion +) { + fmha_bwd_reference_dQ(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion); + fmha_bwd_reference_dK(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion); + fmha_bwd_reference_dV(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/88_hopper_fmha/reference/fmha_reference.hpp b/examples/88_hopper_fmha/reference/fmha_reference.hpp new file mode 100644 index 00000000..94f44371 --- /dev/null +++ b/examples/88_hopper_fmha/reference/fmha_reference.hpp @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, + class TensorK, + class TensorV, + class TensorO, + class TensorLSE, + class Fusion +> +void __global__ fmha_reference_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, + Fusion fusion +) { + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + ElementAccumulator softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + auto id = make_identity_tensor(make_shape(1, 1)); + for (int idx_L = blockIdx.y; idx_L < size<2>(mO); idx_L += gridDim.y) { + for (int idx_Q = blockIdx.x; idx_Q < size<0>(mO); idx_Q += gridDim.x) { + for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_D = 0; idx_D < size<1>(mK); idx_D++) { + acc += mQ(idx_Q, idx_D, idx_L) * mK(idx_K, idx_D, idx_L); + } + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc; + fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + mS[idx_K] = static_cast(frag(0) * softmax_scale); + } + + __syncthreads(); + + ElementAccumulator maxS = -std::numeric_limits::infinity(); + for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) { + maxS = std::max(maxS, mS[idx_K]); + } + if (maxS == -std::numeric_limits::infinity()) maxS = 0; + + __syncthreads(); + + for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) { + mS[idx_K] = static_cast(exp(mS[idx_K] - maxS)); + } + + __syncthreads(); + + ElementAccumulator sum = 0; + for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) { + sum += mS[idx_K]; + } + + Element scale = static_cast(1.0 / sum); + + for (int idx_D = threadIdx.x; idx_D < size<1>(mO); idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) { + acc += mS[idx_K] * mV(idx_K, idx_D, idx_L) * scale; + } + mO(idx_Q, idx_D, idx_L) = static_cast(acc); + } + + if (threadIdx.x == 0) { + mLSE(idx_Q, idx_L) = log(sum) + maxS; + } + + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, + class TensorK, + class TensorV, + class TensorO, + class TensorLSE, + class Fusion +> +void fmha_reference( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, + Fusion fusion +) { + using namespace cute; + + dim3 grid(size<0>(mO), size<2>(mO), 1); + dim3 block(256); + int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type); + + if (shared_mem >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem); + auto result = cudaFuncSetAttribute( + fmha_reference_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return; + } + } + + fmha_reference_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/88_hopper_fmha/reference/reference_abs_error.hpp b/examples/88_hopper_fmha/reference/reference_abs_error.hpp new file mode 100644 index 00000000..5eb2413a --- /dev/null +++ b/examples/88_hopper_fmha/reference/reference_abs_error.hpp @@ -0,0 +1,129 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#pragma once + +#include +#include "cutlass/util/device_memory.h" + +template +__global__ void reference_abs_diff_kernel( + Element* data, Element* data_ref, size_t count, + double* max_diff, double* sum_diff, + bool print_diff +) { + double thread_max_diff = 0; + double thread_sum_diff = 0; + + __shared__ double block_max_diff; + __shared__ double block_sum_diff; + + for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) { + double diff = fabs(data[i] - data_ref[i]); + if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... %f vs %f\n", static_cast(i), diff, (double)data[i], (double)data_ref[i]); + thread_max_diff = fmax(diff, thread_max_diff); + thread_sum_diff += diff; + } + + for (int i = 0; i < blockDim.x; i++) { + if (i == threadIdx.x) { + if (i == 0) { + block_max_diff = thread_max_diff; + block_sum_diff = thread_sum_diff; + } else { + block_max_diff = fmax(block_max_diff, thread_max_diff); + block_sum_diff += thread_sum_diff; + } + } + __syncthreads(); + } + + if (threadIdx.x == 0) { + atomicAdd(sum_diff, block_sum_diff); + + for (;;) { + unsigned long long prev = *reinterpret_cast(max_diff); + double prev_diff = reinterpret_cast(prev); + double new_max_diff = fmax(block_max_diff, prev_diff); + unsigned long long found = atomicCAS(reinterpret_cast(max_diff), prev, reinterpret_cast(new_max_diff)); + if (found == prev) break; + } + } +} + +template +void reference_abs_diff( + cutlass::DeviceAllocation const& data, + cutlass::DeviceAllocation const& data_ref, + double& max_diff, double& mean_diff +) { + static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1; + + cutlass::DeviceAllocation result; + result.reset(2); + assert(data.size() == data_ref.size()); + + cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double)); + if (err != cudaSuccess) { + std::cerr << "Memset failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + dim3 block(256, 1, 1); + dim3 grid(1024, 1, 1); + reference_abs_diff_kernel<<>>( + data.get(), data_ref.get(), data.size(), + result.get(), result.get() + 1, kPrintDiff); + + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << "Difference kernel failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + double result_host[2]; + err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault); + if (err != cudaSuccess) { + std::cerr << "Copy failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + max_diff = result_host[0]; + mean_diff = result_host[1] / static_cast(data.size()); +} diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index bfee2c3c..978105ea 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -163,6 +163,7 @@ foreach(EXAMPLE 82_blackwell_distributed_gemm 83_blackwell_sparse_gemm 84_blackwell_narrow_precision_sparse_gemm + 88_hopper_fmha ) add_subdirectory(${EXAMPLE}) diff --git a/examples/cute/tutorial/CMakeLists.txt b/examples/cute/tutorial/CMakeLists.txt index 3c9e93c4..7c604a51 100644 --- a/examples/cute/tutorial/CMakeLists.txt +++ b/examples/cute/tutorial/CMakeLists.txt @@ -55,3 +55,7 @@ cutlass_example_add_executable( tiled_copy.cu ) +cutlass_example_add_executable( + cute_tutorial_tiled_copy_if + tiled_copy_if.cu +) diff --git a/examples/cute/tutorial/tiled_copy_if.cu b/examples/cute/tutorial/tiled_copy_if.cu new file mode 100644 index 00000000..17d7de1a --- /dev/null +++ b/examples/cute/tutorial/tiled_copy_if.cu @@ -0,0 +1,297 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include + +#include + +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/helper_cuda.hpp" + +// This example extends `tiled_copy` using predicate tensors to guard memory accesses performed +// by `cute::copy_if()`. This enables tensors to have shapes that are not integer multiples of +// block sizes. +// +// This is accomplished by instantiating a tensor of coordinates which correspond to tensor elements +// to be accessed and then computing a predicate tensor which masks accesses. The example demonstrates +// how constructing of an identity tensor containing coordinates and a predicate tensor containing +// mask bits can be implemented using the same CuTe operations used to tile the tensors in +// Global Memory. +// +// This example implements two variants: +// - copy_if_kernel() uses `cute::local_partition()` to construct each thread's slice +// - copy_if_kernel_vectorized() uses `make_tiled_copy() to implement vectorized memory accesses. +// +// The tensor shapes and strides must be divisible by the shape of the vector access. +// + +/// Simple copy kernel. +// +// Uses local_partition() to partition a tile among threads arranged as (THR_M, THR_N). +template +__global__ void copy_if_kernel(TensorS S, TensorD D, BlockShape block_shape, ThreadLayout) +{ + using namespace cute; + + // Construct a coordinate tensor whose elements are the coordinates used to access tensors S and D. + auto shape_S = shape(S); + Tensor C = make_identity_tensor(shape_S); + // Construct a predicate tensor which compares the coordinates with the original shape + Tensor P = cute::lazy::transform(C, [&](auto c) { return elem_less(c, shape_S); }); + + // Tile the input tensor into blocks + auto block_coord = make_coord(blockIdx.x, blockIdx.y); + Tensor tile_S = local_tile(S, block_shape, block_coord); // (BlockShape_M, BlockShape_N) + Tensor tile_D = local_tile(D, block_shape, block_coord); // (BlockShape_M, BlockShape_N) + Tensor tile_P = local_tile(P, block_shape, block_coord); // (BlockShape_M, BlockShape_N) + + // Construct a partitioning of the tile among threads with the given thread arrangement. + + // Concept: Tensor ThrLayout ThrIndex + Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{}, threadIdx.x); + Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{}, threadIdx.x); + Tensor thr_tile_P = local_partition(tile_P, ThreadLayout{}, threadIdx.x); + + // Copy from GMEM to GMEM using `thr_tile_P` to guard accesses. + copy_if(thr_tile_P, thr_tile_S, thr_tile_D); +} + +/// Vectorized copy kernel. +/// +/// Uses `make_tiled_copy()` to perform a copy using vector instructions. This operation +/// has the precondition that pointers are aligned to the vector size. +/// +template +__global__ void copy_if_kernel_vectorized(TensorS S, TensorD D, BlockShape block_shape, Tiled_Copy tiled_copy) +{ + using namespace cute; + + // Construct a coordinate tensor whose elements are the coordinates used to access tensors S and D. + auto shape_S = shape(S); + Tensor C = make_identity_tensor(shape_S); + // Construct a predicate tensor which compares the coordinates with the original shape + Tensor P = cute::lazy::transform(C, [&](auto c) { return elem_less(c, shape_S); }); + + // Tile the input tensor into blocks + auto block_coord = make_coord(blockIdx.x, blockIdx.y); + Tensor tile_S = local_tile(S, block_shape, block_coord); // (BlockShape_M, BlockShape_N) + Tensor tile_D = local_tile(D, block_shape, block_coord); // (BlockShape_M, BlockShape_N) + Tensor tile_P = local_tile(P, block_shape, block_coord); // (BlockShape_M, BlockShape_N) + + // + // Construct a Tensor corresponding to each thread's slice. + // + ThrCopy thr_copy = tiled_copy.get_thread_slice(threadIdx.x); + Tensor thr_tile_S = thr_copy.partition_S(tile_S); // (CPY, CPY_M, CPY_N) + Tensor thr_tile_D = thr_copy.partition_D(tile_D); // (CPY, CPY_M, CPY_N) + Tensor thr_tile_P = thr_copy.partition_S(tile_P); // (CPY, CPY_M, CPY_N) + +#if 0 + // Copy from GMEM to GMEM + copy_if(tiled_copy, thr_tile_P, thr_tile_S, thr_tile_D); +#else + // make_fragment_like() constructs a tensor in RMEM with the same shape as thr_tile_S. + Tensor frag = make_fragment_like(thr_tile_S); + + // Copy from GMEM to RMEM and from RMEM to GMEM + copy_if(tiled_copy, thr_tile_P, thr_tile_S, frag); + copy_if(tiled_copy, thr_tile_P, frag, thr_tile_D); +#endif +} + +/// Main function +int main(int argc, char** argv) +{ + // + // Given a 2D shape, perform an efficient copy + // + + using namespace cute; + using Element = float; + + // Define a tensor shape with dynamic extents (m, n) + auto tensor_shape = make_shape(528, 300); + + thrust::host_vector h_S(size(tensor_shape)); + thrust::host_vector h_D(size(tensor_shape)); + + // + // Initialize + // + + for (size_t i = 0; i < h_S.size(); ++i) { + h_S[i] = static_cast(i); + h_D[i] = Element{}; + } + + thrust::device_vector d_S = h_S; + thrust::device_vector d_D = h_D; + thrust::device_vector d_Zero = h_D; + + // + // Make tensors + // + + Tensor tensor_S = make_tensor(make_gmem_ptr(d_S.data().get()), make_layout(tensor_shape)); + Tensor tensor_D = make_tensor(make_gmem_ptr(d_D.data().get()), make_layout(tensor_shape)); + + // + // Partition + // + + // Define a statically sized block (M, N). + // + // Note, by convention, capital letters are used to represent static modes. + auto block_shape = make_shape(Int<128>{}, Int<64>{}); + + // Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static tile + // shape, and modes (m', n') correspond to the number of tiles. + // + // These will be used to determine the CUDA kernel grid dimensinos. + Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape); // ((M, N), m', n') + + // Describes the layout of threads which is then replicated to tile 'block_shape.' + Layout thr_layout = make_layout(make_shape(Int<32>{}, Int< 8>{})); // (ThrM, ThrN) + + // + // Determine grid and block dimensions + // + + dim3 gridDim (size<1>(tiled_tensor_D), size<2>(tiled_tensor_D)); // Grid shape corresponds to modes m' and n' + dim3 blockDim(size(thr_layout)); + + // + // Launch the kernel + // + + // copy_if() + copy_if_kernel<<< gridDim, blockDim >>>( + tensor_S, + tensor_D, + block_shape, + thr_layout); + + cudaError result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result) << std::endl; + return -1; + } + + h_D = d_D; + + // + // Verification + // + + auto verify = [](thrust::host_vector const &S, thrust::host_vector const &D){ + + int32_t errors = 0; + int32_t const kErrorLimit = 10; + + if (S.size() != D.size()) { + return 1; + } + + for (size_t i = 0; i < D.size(); ++i) { + if (S[i] != D[i]) { + std::cerr << "Error. S[" << i << "]: " << S[i] << ", D[" << i << "]: " << D[i] << std::endl; + + if (++errors >= kErrorLimit) { + std::cerr << "Aborting on " << kErrorLimit << "nth error." << std::endl; + return errors; + } + } + } + + return errors; + }; + + if (verify(h_D, h_S)) { + return -1; + } else { + std::cout << "Success." << std::endl; + } + + thrust::copy(d_Zero.begin(), d_Zero.end(), d_D.begin()); + + // Construct a TiledCopy with a specific access pattern. + // This version uses a + // (1) Layout-of-Threads to describe the number and arrangement of threads (e.g. row-major, col-major, etc), + // (2) Layout-of-Values that each thread will access. + + // Value arrangement per thread + Layout val_layout = make_layout(make_shape(Int<4>{}, Int<1>{})); // (4,1) -> val_idx + + // Define `AccessType` which controls the size of the actual memory access instruction. + using CopyOp = UniversalCopy>; // A very specific access width copy instruction + //using CopyOp = UniversalCopy>; // A more generic type that supports many copy strategies + //using CopyOp = AutoVectorizingCopy; // An adaptable-width instruction that assumes maximal alignment of inputs + + // A Copy_Atom corresponds to one CopyOperation applied to Tensors of type Element. + using Atom = Copy_Atom; + + // Construct tiled copy, a tiling of copy atoms. + // + // Note, this assumes the vector and thread layouts are aligned with contigous data + // in GMEM. Alternative thread layouts are possible but may result in uncoalesced + // reads. Alternative value layouts are also possible, though incompatible layouts + // will result in compile time errors. + TiledCopy tiled_copy = make_tiled_copy(Atom{}, // Access strategy + thr_layout, // thread layout (e.g. 32x4 Col-Major) + val_layout); // value layout (e.g. 4x1) + + // copy_if() with vectorization + copy_if_kernel_vectorized<<< gridDim, blockDim >>>( + tensor_S, + tensor_D, + block_shape, + tiled_copy); + + result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result) << std::endl; + return -1; + } + + h_D = d_D; + + if (verify(h_D, h_S)) { + return -1; + } else { + std::cout << "Success." << std::endl; + } + return 0; +} + diff --git a/examples/python/CuTeDSL/ampere/smem_allocator.py b/examples/python/CuTeDSL/ampere/smem_allocator.py new file mode 100644 index 00000000..f9f5c1e0 --- /dev/null +++ b/examples/python/CuTeDSL/ampere/smem_allocator.py @@ -0,0 +1,200 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import cutlass.cute as cute +import cutlass +import torch +import numpy as np +from cutlass.cute.runtime import from_dlpack + +""" +A Shared Memory Allocator Example on NVIDIA Ampere architecture using CuTe DSL. + +This example demonstrates how to allocate and manage shared memory in JIT kernels by using the SmemAllocator in CuTe DSL. +It shows various ways to allocate different data structures in shared memory: + +1. Struct allocation with natural and strict alignment +2. Raw memory block allocation with custom alignment +3. Array allocation with automatic alignment +4. Tensor allocation with layout specification + +The example includes: +- Shared storage struct with mixed alignment requirements +- Memory allocation patterns for different data types +- Tensor operations on allocated memory + +To run this example: + +.. code-block:: bash + + python examples/ampere/smem_allocator.py + +The example will allocate shared memory, perform tensor operations, and verify the results. +""" + + +@cute.struct +class complex: + real: cutlass.Float32 + imag: cutlass.Float32 + + +# SharedStorage size is 512, alignment is 128 +@cute.struct +class SharedStorage: + # struct elements with natural alignment + a: cute.struct.MemRange[cutlass.Float32, 32] # array + b: cutlass.Int64 # saclar + c: complex # nested struct + # struct elements with strict alignment + x: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, 32], + 128, + ] + y: cute.struct.Align[cutlass.Int32, 8] + z: cute.struct.Align[complex, 16] + + +@cute.kernel +def kernel( + const_a: cutlass.Constexpr, + dst_a: cute.Tensor, + const_b: cutlass.Constexpr, + dst_b: cute.Tensor, + const_c: cutlass.Constexpr, + dst_c: cute.Tensor, +): + # Note: SMEM_SIZE bytes (specified in kernel().launch(smem=...)) can be reserved for developer to utilize + # Note: alignment of inital allocator base ptr is 1024 + allocator = cutlass.utils.SmemAllocator() + # base ptr of allocator points at: SMEM_ADDR_START (the starting address of available shared memory) + + # -- Allocate a struct -- + # Note: when specified alignment, max(alignment, alignof(struct)) will be applied + # reserves the section of struct in smem, elements in the struct can be accessed by ptr + struct_in_smem = allocator.allocate(SharedStorage) + # base ptr of allocator now points at: SMEM_ADDR_AFTER_STRUCT = SMEM_ADDR_START + aligned_size(struct) + + # -- Allocate a block of memory -- + # reserves a section of 64 bytes in smem, align to 128 bytes, returns the section base ptr + section_in_smem = allocator.allocate(64, byte_alignment=128) + # base ptr of allocator now points at: SMEM_ADDR_AFTER_SECTION = SMEM_ADDR_AFTER_STRUCT + aligned_size(section) + + # -- Allocate an array -- + # reserves an int64 array of size 14 in smem, returns the array base ptr + array_in_smem = allocator.allocate_array(element_type=cutlass.Int64, num_elems=14) + # base ptr of allocator now points at: SMEM_ADDR_AFTER_ARRAY = SMEM_ADDR_AFTER_SECTION + aligned_size(array) + + # -- Allocate a tensor -- + # Note: use cute.ComposedLayout or cute.Layout to specify layout of tensor + # Note: iterator swizzle with swizzle layout is currently not supported + layout = cute.make_layout((16, 2)) + tensor_in_smem = allocator.allocate_tensor( + element_type=cutlass.Float32, layout=layout, byte_alignment=32, swizzle=None + ) + # base ptr of allocator now points at: SMEM_ADDR_AFTER_TENSOR = SMEM_ADDR_AFTER_ARRAY + aligned_size(tensor) + + # ptr> + # ptr> + # ptr> + print(struct_in_smem.a.data_ptr()) + print(struct_in_smem.b) + print(struct_in_smem.c.real) + # ptr> + print(section_in_smem) + # ptr> + print(array_in_smem) + # tensor> o (16,4):(1,16)> + print(tensor_in_smem) + + # fill MemRange tensor in struct and copy to dst + a_tensor = struct_in_smem.a.get_tensor(cute.make_layout((8, 4))) + a_tensor.fill(const_a) + cute.printf("cute.struct.MemRange: {}", a_tensor) + dst_a.store(a_tensor.load()) + + # convert block of smem to fill tensor and copy to dst + layout = cute.make_layout((8, 2)) + sec_ptr = cute.recast_ptr(section_in_smem, dtype=cutlass.Float32) + sec_tensor = cute.make_tensor(sec_ptr, layout) + sec_tensor.fill(const_b) + cute.printf("block of memory: {}", sec_tensor) + dst_b.store(sec_tensor.load()) + + # fill allocated tensor in smem and copy to dst + tensor_in_smem.fill(const_c) + cute.printf("tensor in smem: {}", tensor_in_smem) + dst_c.store(tensor_in_smem.load()) + + +@cute.jit +def run_allocation_kernel( + const_a: cutlass.Constexpr, + dst_a: cute.Tensor, + const_b: cutlass.Constexpr, + dst_b: cute.Tensor, + const_c: cutlass.Constexpr, + dst_c: cute.Tensor, +): + # additional size for the example, 64(section) + 112(array) + 128(tensor) < 384 + addtional_bytes = 384 + # Note: launch shared memory size is: SMEM_SIZE = 512 + 384 = 896 bytes + kernel(const_a, dst_a, const_b, dst_b, const_c, dst_c).launch( + grid=(1, 1, 1), + block=(1, 1, 1), + smem=SharedStorage.size_in_bytes() + addtional_bytes, + ) + + +def veify_allocation_kernel(const_a, const_b, const_c): + dst_a = torch.zeros((8, 4), dtype=torch.float32, device="cuda") + dst_b = torch.zeros((8, 2), dtype=torch.float32, device="cuda") + dst_c = torch.zeros((16, 2), dtype=torch.float32, device="cuda") + + run_allocation_kernel( + const_a, + from_dlpack(dst_a), + const_b, + from_dlpack(dst_b), + const_c, + from_dlpack(dst_c), + ) + + np.testing.assert_equal(const_a, dst_a.detach().cpu().numpy()[0]) + np.testing.assert_equal(const_b, dst_b.detach().cpu().numpy()[0]) + np.testing.assert_equal(const_c, dst_c.detach().cpu().numpy()[0]) + + +if __name__ == "__main__": + # prepare cuda context + cutlass.cuda.initialize_cuda_context() + # An example for shared memory allocation + const_a = 0.5 + const_b = 1.0 + const_c = 2.0 + veify_allocation_kernel(const_a, const_b, const_c) diff --git a/examples/python/CuTeDSL/cute/ffi/CMakeLists.txt b/examples/python/CuTeDSL/cute/ffi/CMakeLists.txt new file mode 100644 index 00000000..6c02b198 --- /dev/null +++ b/examples/python/CuTeDSL/cute/ffi/CMakeLists.txt @@ -0,0 +1,51 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cmake_minimum_required(VERSION 3.15) +project(tensor) + +# Find Python +find_package(Python COMPONENTS Interpreter Development REQUIRED) + +# Get Python site-packages directory using Python +execute_process( + COMMAND ${Python_EXECUTABLE} -c "import site; print(site.getsitepackages()[0])" + OUTPUT_VARIABLE Python_SITE_PACKAGES + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +message(STATUS "Python site-packages directory: ${Python_SITE_PACKAGES}") + +# Add nanobind path to CMAKE_PREFIX_PATH +list(APPEND CMAKE_PREFIX_PATH ${Python_SITE_PACKAGES}/nanobind/cmake) + +# Find nanobind +find_package(nanobind REQUIRED) + +# Add the module +nanobind_add_module(tensor tensor.cpp) diff --git a/examples/python/CuTeDSL/cute/ffi/jit_argument.py b/examples/python/CuTeDSL/cute/ffi/jit_argument.py new file mode 100644 index 00000000..acdb42ef --- /dev/null +++ b/examples/python/CuTeDSL/cute/ffi/jit_argument.py @@ -0,0 +1,305 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Example of accessing POD (Plain Old Data) from C or other languages via LLVM operations. + +This example demonstrates a basic approach to building customized interfaces as C-structures between user code +and JIT compiled functions. It provides a minimal-cost solution for calling JIT functions +and can be used to build AOT (Ahead-of-Time) launchers for JIT compiled functions. + +The C-structure is defined as: + +.. code-block:: c + + struct Tensor { + void *ptr; // Pointer to tensor data + int32_t shape[3]; // Tensor dimensions + int32_t strides[3]; // Memory strides for each dimension + }; + +The example defines Tensor and TensorValue classes that wrap C structs for view of a tensor with its data pointer, +shape, and strides, enabling efficient data passing between different language boundaries. + +.. note:: + Future development may include automated code generation flows. +""" + +import cutlass +import cutlass.cute as cute + +from cutlass._mlir import ir +from cutlass._mlir.dialects import llvm +import cutlass._mlir.extras.types as T + + +class ExampleTensorValue(ir.Value): + """A wrapper class for tensor values in MLIR. + + This class extends ir.Value to provide convenient access to tensor data pointer, + shape, and strides through MLIR operations. + + :type: ir.Value + """ + + def __init__(self, v): + """Initialize a new TensorValue. + + :param v: The underlying MLIR value to wrap + :type v: ir.Value + """ + super().__init__(v) + + @property + def data_ptr(self, *, loc=None, ip=None): + """Get the data pointer from the tensor value. + + Extracts the data pointer (first field) from the LLVM struct value. + + :param loc: Optional location information for MLIR operations + :type loc: Optional[ir.Location] + :param ip: Optional insertion point for MLIR operations + :type ip: Optional[ir.InsertionPoint] + :return: An integer value representing the data pointer + :rtype: ir.Value + """ + # Extract the data pointer from the LLVM struct value + # The data pointer is the first field (index 0) in the struct + + # Use llvm.extractvalue to get the pointer field from the struct + ptr_val = llvm.extractvalue( + llvm.PointerType.get(), + self, + [0], # Extract the first field (index 0) + loc=loc, + ip=ip, + ) + + return cute.make_ptr(cutlass.Float32, ptr_val) + + @property + def shape(self): + """Get the shape of the tensor. + + Extracts the shape (second field) from the LLVM struct value. + + :return: A tuple of integers representing the tensor dimensions + :rtype: tuple[ir.Value, ...] + """ + i32_type = ir.IntegerType.get_signless(32) + # Extract the shape field from the LLVM struct value + # The shape is the second field (index 1) in the struct + shape_val = llvm.extractvalue( + llvm.StructType.get_literal([i32_type] * 3), + self, + [1], # Extract the second field (index 1) + ) + + # Extract each dimension from the shape struct + return tuple(llvm.extractvalue(i32_type, shape_val, [i]) for i in range(3)) + + @property + def stride(self): + """Get the strides of the tensor. + + Extracts the strides (third field) from the LLVM struct value. + + :return: A tuple of integers representing the tensor strides + :rtype: tuple[ir.Value, ...] + """ + i32_type = ir.IntegerType.get_signless(32) + # Extract the strides field from the LLVM struct value + # The strides are the third field (index 2) in the struct + strides_val = llvm.extractvalue( + llvm.StructType.get_literal([i32_type] * 3), + self, + [2], # Extract the third field (index 2) + ) + + # Extract each dimension from the strides struct + return tuple(llvm.extractvalue(i32_type, strides_val, [i]) for i in range(3)) + + +class ExampleTensor: + """A class representing a tensor with its data pointer, shape, and strides. + + This class provides a Python interface to create and manipulate tensor structures + that can be passed to CUTE JIT compiled functions. + + :ivar _c_struct_p: The C struct pointer for the tensor + :ivar _rank: The number of dimensions in the tensor + """ + + def __init__(self, c_struct_p, rank): + """Initialize a new Tensor. + + :param c_struct_p: The C struct pointer for the tensor + :type c_struct_p: int + :param rank: The number of dimensions in the tensor + :type rank: int + """ + self._c_struct_p = c_struct_p + self._rank = rank + + def __get_mlir_types__(self): + """Get the MLIR types for this tensor. + + Creates an LLVM structure type representing a C-structure with: + + .. code-block:: c + + struct Tensor { + void *ptr; + int32_t shape[3]; + int32_t strides[3]; + }; + + :return: A list containing the MLIR struct type + :rtype: list[llvm.StructType] + + Create an LLVM structure type that represents a C-structure like: + """ + + # Get the number of dimensions from the shape + ndim = self._rank + + # Create the pointer type (void*) + ptr_type = llvm.PointerType.get() + + # Create array types for shape and strides (int32_t[ndim]) + int32_type = ir.IntegerType.get_signless(32) + shape_type = llvm.StructType.get_literal([int32_type] * ndim) + strides_type = llvm.StructType.get_literal([int32_type] * ndim) + + # Create the structure type + struct_type = llvm.StructType.get_literal([ptr_type, shape_type, strides_type]) + + return [struct_type] + + def __new_from_mlir_values__(self, values): + """Create a new TensorValue from MLIR values. + + :param values: A list of MLIR values + :type values: list[ir.Value] + :return: A new TensorValue instance + :rtype: TensorValue + """ + return ExampleTensorValue(values[0]) + + def __c_pointers__(self): + """Get the C pointers for this tensor. + + :return: A list containing the C struct pointer + :rtype: list[int] + """ + return [self._c_struct_p] + + +@cute.jit +def foo(tensor): + """Example JIT function that prints tensor information. + + :param tensor: A Tensor instance to print information about + :type tensor: Tensor + """ + cute.printf("data_ptr: {}", tensor.data_ptr) + cute.printf("shape: {}", tensor.shape) + cute.printf("stride: {}", tensor.stride) + + mA = cute.make_tensor( + tensor.data_ptr, cute.make_layout(tensor.shape, stride=tensor.stride) + ) + cute.print_tensor(mA) + + +import sys +import os +import subprocess +import shutil +import tempfile +import torch + + +def run_test(tmpdir=None): + # Skip cleanup if user provides tmpdir + cleanup = tmpdir is None + # Initialize temporary build directory + tmpdir = tmpdir or tempfile.mkdtemp() + + try: + current_dir = os.path.dirname(os.path.abspath(__file__)) + + subprocess.run(["cmake", "-B", tmpdir, current_dir], check=True) + subprocess.run(["cmake", "--build", tmpdir], check=True) + + sys.path.append(tmpdir) + + from tensor import make_tensor, pycapsule_get_pointer + + # Mock test tensor and corresponding C structure for this example + # In production, this may come from external library + x = torch.arange(2 * 8 * 4).to(torch.float32).reshape(2, 8, 4) + c_struct = make_tensor(x.data_ptr(), x.shape, x.stride()) + c_struct_p = pycapsule_get_pointer(c_struct) + + # Initialize tensor wrapper and compile test function + tensor = ExampleTensor(c_struct_p, len(x.shape)) + compiled_func = cute.compile(foo, tensor) + + # Benchmark pointer access performance + from time import time + + start = time() + # Measure performance of critical path pointer access + # get C pointers is on critical path to call JIT compiled function + for _ in range(1000): + tensor.__c_pointers__() + end = time() + print(f"__c_pointers__: {(end - start) * 1000} us") + + # Execute compiled function + compiled_func(tensor) + except Exception as e: + print(e) + finally: + if cleanup: + # Clean up the temporary directory + shutil.rmtree(tmpdir) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Set temporary directory for building C modules" + ) + parser.add_argument( + "--tmp-dir", type=str, help="Temporary directory path for building C modules" + ) + args = parser.parse_args() + + run_test(args.tmp_dir) diff --git a/examples/python/CuTeDSL/cute/ffi/tensor.cpp b/examples/python/CuTeDSL/cute/ffi/tensor.cpp new file mode 100644 index 00000000..6aed2f2d --- /dev/null +++ b/examples/python/CuTeDSL/cute/ffi/tensor.cpp @@ -0,0 +1,82 @@ +// Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: + +// 1. Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. + +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. + +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. + +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. + +#include +#include +#include +#include + +namespace nb = nanobind; + +// Forward declaration of the MockTensor struct for testing only +struct MockTensor { + void *ptr; + struct { + int32_t shape[3]; + } shape; + + struct { + int32_t strides[3]; + } strides; +}; + +NB_MODULE(tensor, m) { + // create a tensor for testing + m.def("make_tensor", [](int64_t ptr, std::vector shape, + std::vector strides) { + auto *tensor = new MockTensor(); + tensor->ptr = reinterpret_cast(ptr); + + assert(shape.size() == 3 && "shape must have 3 elements"); + assert(strides.size() == 3 && "strides must have 3 elements"); + + for (size_t i = 0; i < shape.size(); i++) { + tensor->shape.shape[i] = shape[i]; + tensor->strides.strides[i] = strides[i]; + } + + return nb::steal(PyCapsule_New(tensor, "tensor", [](PyObject *capsule) { + auto n = PyCapsule_GetName(capsule); + if (void *p = PyCapsule_GetPointer(capsule, n)) { + delete reinterpret_cast(p); + } + })); + }); + + m.def( + "pycapsule_get_pointer", + [](nb::object &capsule) { + void *ptr = PyCapsule_GetPointer(capsule.ptr(), "tensor"); + if (!ptr) { + throw std::runtime_error("Invalid tensor capsule"); + } + return reinterpret_cast(ptr); + }, + "Get pointer from PyCapsule"); +} diff --git a/examples/python/CuTeDSL/hopper/dense_gemm.py b/examples/python/CuTeDSL/hopper/dense_gemm.py new file mode 100644 index 00000000..dc9a0604 --- /dev/null +++ b/examples/python/CuTeDSL/hopper/dense_gemm.py @@ -0,0 +1,1486 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Tuple, Type +import math +import cuda.bindings.driver as cuda + +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.torch as cutlass_torch +from cutlass.cute.runtime import from_dlpack +import cutlass.utils.hopper_helpers as sm90_utils + +""" +A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture +using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Hopper's WGMMA for matrix multiply-accumulate (MMA) operations + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Supports multi-stage pipeline to overlap computation and memory access + +This GEMM works as follows: +1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. Perform matrix multiply-accumulate (MMA) operations using WGMMA instruction. +3. Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM) with TMA operations. + +Hopper WGMMA instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Perform MMA operation and store the result in Accumulator(register) + +To run this example: + +.. code-block:: bash + + python examples/hopper/dense_gemm.py \ + --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \ + --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \ + --c_dtype Float16 --acc_dtype Float32 \ + --a_major k --b_major k --c_major n + +The above example command compute batched gemm with M=8192, N=8192, K=8192, +batch_count=1. The Hopper WGMMA tile shape is 128x256x64 and the cluster shape +is (1,1). The input, mma accumulator and output data type are set as fp16, fp32 +and fp16, respectively. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/hopper/dense_gemm.py \ + --mnkl 8192,8192,8192,1 --tile_shape_mnk 128,256,64 \ + --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \ + --c_dtype Float16 --acc_dtype Float32 \ + --a_major k --b_major k --c_major n + +Constraints: +* Supported input data types: fp16, fp8 (e4m3fn, e5m2) +* For fp16 types, A and B must have the same data type +* For fp8 types, A and B can have different types (e4m3fn or e5m2) but both must be 8-bit +* Fp8 types only support k-major layout +* Only fp32 accumulation is supported in this example +* CTA tile shape M must be 64/128 +* CTA tile shape N must be 64/128/256 +* CTA tile shape K must be 64 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 4 +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 8, 16 for Float16, and Float8, respectively. +* OOB tiles are not allowed when TMA store is disabled +""" + + +# ///////////////////////////////////////////////////////////////////////////// +# Helpers to parse args +# ///////////////////////////////////////////////////////////////////////////// +def parse_comma_separated_ints(s: str): + try: + return tuple([int(x.strip()) for x in s.split(",")]) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Example of MxNxKxL GEMM on Hopper.") + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(4096, 4096, 4096, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--tile_shape_mnk", + type=parse_comma_separated_ints, + choices=[(128, 128, 64), (128, 256, 64), (128, 64, 64), (64, 64, 64)], + default=(128, 128, 64), + help="Cta tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + choices=[(1, 1), (2, 1), (1, 2), (2, 2)], + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument( + "--a_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + ) + parser.add_argument( + "--b_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + ) + parser.add_argument( + "--c_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + ) + parser.add_argument( + "--acc_dtype", + type=cutlass.dtype, + default=cutlass.Float32, + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + if len(args.tile_shape_mnk) != 3: + parser.error("--tile_shape_mnk must contain exactly 3 values") + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + return args + + +# ///////////////////////////////////////////////////////////////////////////// +# Host setup and device kernel launch +# ///////////////////////////////////////////////////////////////////////////// + + +class HopperWgmmaGemmKernel: + """ + This class implements batched matrix multiplication (C = A x B) with support for various data types + and architectural features specific to Hopper GPUs. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param tile_shape_mnk: Shape of the CTA tile (M,N,K) + :type tile_shape_mnk: Tuple[int, int, int] + :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing + :type cluster_shape_mnk: Tuple[int, int, int] + + :note: Data type requirements: + - For 16-bit types: A and B must have the same data type + - For 8-bit types: A and B can have different types (Float8E4M3FN/Float8E5M2) as long as both are 8-bit + - Float8 types only support k-major layout + + :note: Supported data types: + - Float16 + - Float8E4M3FN/Float8E5M2 + + :note: Supported accumulation types: + - Float32 (for all floating point inputs) + + :note: Constraints: + - CTA tile M must be 64/128 + - CTA tile N must be 64/128/256 + - CTA tile K must be 64 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 4 + + Example: + >>> gemm = HopperWgmmaGemmKernel( + ... acc_dtype=cutlass.Float32, + ... tile_shape_mnk=(128, 256, 64), + ... cluster_shape_mnk=(1, 1, 1) + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, stream) + """ + + def __init__( + self, + acc_dtype: type[cutlass.Numeric], + tile_shape_mnk: tuple[int, int, int], + cluster_shape_mnk: tuple[int, int, int], + ): + """ + Initializes the configuration for a Hopper dense GEMM kernel. + + This configuration includes data types for operands, tile shape, cluster configuration, + and thread layout. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param tile_shape_mnk: Shape of the CTA tile (M,N,K) + :type tile_shape_mnk: Tuple[int, int, int] + :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing + :type cluster_shape_mnk: Tuple[int, int, int] + """ + + self.acc_dtype = acc_dtype + + self.cluster_shape_mnk = cluster_shape_mnk + self.mma_inst_shape_mn = None + self.tile_shape_mnk = tuple(tile_shape_mnk) + # For large tile size, using two warp groups is preferred because using only one warp + # group may result in register spill + self.atom_layout_mnk = ( + (2, 1, 1) + if tile_shape_mnk[0] > 64 and tile_shape_mnk[1] > 128 + else (1, 1, 1) + ) + self.num_mcast_ctas_a = None + self.num_mcast_ctas_b = None + self.is_a_mcast = False + self.is_b_mcast = False + + self.occupancy = 1 + self.mma_warp_groups = math.prod(self.atom_layout_mnk) + self.num_threads_per_warp_group = 128 + self.threads_per_cta = self.mma_warp_groups * self.num_threads_per_warp_group + self.smem_capacity = sm90_utils.SMEM_CAPACITY["sm90"] + + self.ab_stage = None + self.epi_stage = None + + self.a_smem_layout_staged = None + self.b_smem_layout_staged = None + self.epi_smem_layout_staged = None + self.epi_tile = None + + self.shared_storage = None + self.buffer_align_bytes = 1024 + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + """ + + # check the cta tile shape + if self.tile_shape_mnk[0] not in [64, 128]: + raise ValueError("CTA tile shape M must be 64/128") + if self.tile_shape_mnk[1] not in [64, 128, 256]: + raise ValueError("CTA tile shape N must be 64/128/256") + if self.tile_shape_mnk[2] not in [64]: + raise ValueError("CTA tile shape K must be 64") + + self.cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk) + self.num_mcast_ctas_a = self.cluster_shape_mnk[1] + self.num_mcast_ctas_b = self.cluster_shape_mnk[0] + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + is_cooperative = self.atom_layout_mnk == (2, 1, 1) + self.epi_tile = self._sm90_compute_tile_shape_or_override( + self.tile_shape_mnk, self.c_dtype, is_cooperative=is_cooperative + ) + + # Compute stage before compute smem layout + self.ab_stage, self.epi_stage = self._compute_stages( + self.tile_shape_mnk, + self.a_dtype, + self.b_dtype, + self.smem_capacity, + self.occupancy, + ) + + ( + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.epi_smem_layout_staged, + ) = self._make_smem_layouts( + self.tile_shape_mnk, + self.epi_tile, + self.a_dtype, + self.a_layout, + self.b_dtype, + self.b_layout, + self.ab_stage, + self.c_dtype, + self.c_layout, + self.epi_stage, + ) + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + stream: cuda.CUstream, + ): + """Execute the GEMM operation in steps: + - Setup static attributes + - Setup TMA load/store atoms and tensors + - Compute grid size + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + """ + + # setup static attributes before smem/grid/tma computation + self.a_dtype = a.element_type + self.b_dtype = b.element_type + self.c_dtype = c.element_type + self.a_layout = utils.LayoutEnum.from_tensor(a) + self.b_layout = utils.LayoutEnum.from_tensor(b) + self.c_layout = utils.LayoutEnum.from_tensor(c) + + if cutlass.const_expr( + self.a_dtype.width == 16 and self.a_dtype != self.b_dtype + ): + raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}") + if cutlass.const_expr(self.a_dtype.width != self.b_dtype.width): + raise TypeError( + f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}" + ) + if cutlass.const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8): + raise TypeError(f"a_dtype should be float16 or float8") + + self._setup_attributes() + + tiled_mma = sm90_utils.make_trivial_tiled_mma( + self.a_dtype, + self.b_dtype, + self.a_layout.sm90_mma_major_mode(), + self.b_layout.sm90_mma_major_mode(), + self.acc_dtype, + self.atom_layout_mnk, + tiler_mn=(64, self.tile_shape_mnk[1]), + ) + + tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors( + a, + self.a_smem_layout_staged, + (self.tile_shape_mnk[0], self.tile_shape_mnk[2]), + self.cluster_shape_mnk[1], + ) + + tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors( + b, + self.b_smem_layout_staged, + (self.tile_shape_mnk[1], self.tile_shape_mnk[2]), + self.cluster_shape_mnk[0], + ) + + tma_atom_c, tma_tensor_c = self._make_tma_store_atoms_and_tensors( + c, + self.epi_smem_layout_staged, + self.epi_tile, + ) + + grid = self._compute_grid(c, self.tile_shape_mnk, self.cluster_shape_mnk) + + @cute.struct + class SharedStorage: + mainloop_pipeline_array_ptr: cute.struct.MemRange[ + cutlass.Int64, self.ab_stage * 2 + ] + sa: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + sb: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + tiled_mma, + self.cta_layout_mnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.epi_smem_layout_staged, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + smem=self.shared_storage.size_in_bytes(), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + tiled_mma: cute.TiledMma, + cta_layout_mnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + epi_smem_layout_staged: cute.ComposedLayout, + ): + """ + GPU device kernel performing the batched GEMM computation. + + :param tma_atom_a: TMA copy atom for A tensor + :type tma_atom_a: cute.CopyAtom + :param mA_mkl: Input tensor A + :type mA_mkl: cute.Tensor + :param tma_atom_b: TMA copy atom for B tensor + :type tma_atom_b: cute.CopyAtom + :param mB_nkl: Input tensor B + :type mB_nkl: cute.Tensor + :param tma_atom_c: TMA copy atom for C tensor + :type tma_atom_c: cute.CopyAtom + :param mC_mnl: Output tensor C + :type mC_mnl: cute.Tensor + :param tiled_mma: Tiled MMA object + :type tiled_mma: cute.TiledMma + :param cta_layout_mnk: CTA layout + :type cta_layout_mnk: cute.Layout + :param a_smem_layout_staged: Shared memory layout for A + :type a_smem_layout_staged: cute.ComposedLayout + :param b_smem_layout_staged: Shared memory layout for B + :type b_smem_layout_staged: cute.ComposedLayout + :param epi_smem_layout_staged: Shared memory layout for epilogue + :type epi_smem_layout_staged: cute.ComposedLayout + """ + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # ///////////////////////////////////////////////////////////////////////////// + # Prefetch Tma desc + # ///////////////////////////////////////////////////////////////////////////// + if warp_idx == 0: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_a) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_b) + + # /////////////////////////////////////////////////////////////////////////////// + # Get cta/warp/thread idx + # /////////////////////////////////////////////////////////////////////////////// + bidx, bidy, bidz = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + + cidx, cidy, _ = cute.arch.cluster_idx() + cdimx, cdimy, _ = cute.arch.cluster_dim() + cluster_id = cidx + cdimx * cidy + + # CTA Swizzle to promote L2 data reuse + group_size_m = 8 + s_shape = ( + (group_size_m, cdimx // group_size_m), + cdimy, + ) + s_stride = ((1, cdimy * group_size_m), group_size_m) + s_layout = cute.make_layout(s_shape, stride=s_stride) + num_reg_cids = cute.size(s_shape) + cid_m, cid_n = s_layout.get_flat_coord(cluster_id % num_reg_cids) + + # Deal with the tail part + if cluster_id >= num_reg_cids: + tail_size_m = cdimx % group_size_m + tail_layout = cute.make_layout( + (tail_size_m, cdimy), stride=(1, tail_size_m) + ) + tail_cid = cluster_id - num_reg_cids + tail_cid_m, tail_cid_n = tail_layout.get_flat_coord(tail_cid) + cid_m = cute.size(s_shape, mode=[0]) + tail_cid_m + cid_n = tail_cid_n + + # Get the pid from cluster id + bidx_in_cluster = cute.arch.block_in_cluster_idx() + pid_m = cid_m * self.cluster_shape_mnk[0] + bidx_in_cluster[0] + pid_n = cid_n * self.cluster_shape_mnk[1] + bidx_in_cluster[1] + + tile_coord_mnkl = (pid_m, pid_n, None, bidz) + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster) + + # /////////////////////////////////////////////////////////////////////////////// + # Get mcast mask + # /////////////////////////////////////////////////////////////////////////////// + a_mcast_mask = cute.make_layout_image_mask( + cta_layout_mnk, cluster_coord_mnk, mode=1 + ) + b_mcast_mask = cute.make_layout_image_mask( + cta_layout_mnk, cluster_coord_mnk, mode=0 + ) + + a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0 + b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0 + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0)) + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0)) + tma_copy_bytes = cute.size_in_bytes( + self.a_dtype, a_smem_layout + ) + cute.size_in_bytes(self.b_dtype, b_smem_layout) + + # ///////////////////////////////////////////////////////////////////////////// + # Alloc and init AB full/empty + ACC full mbar (pipeline) + # ///////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + # mbar arrays + mainloop_pipeline_array_ptr = storage.mainloop_pipeline_array_ptr.data_ptr() + + # Threads/warps participating in this pipeline + mainloop_pipeline_producer_group = utils.CooperativeGroup(utils.Agent.Thread) + # Set the consumer arrive count to the number of mcast size + consumer_arrive_cnt = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + mainloop_pipeline_consumer_group = utils.CooperativeGroup( + utils.Agent.Thread, consumer_arrive_cnt + ) + + mainloop_pipeline = utils.PipelineTmaAsync.create( + barrier_storage=mainloop_pipeline_array_ptr, + num_stages=self.ab_stage, + producer_group=mainloop_pipeline_producer_group, + consumer_group=mainloop_pipeline_consumer_group, + tx_count=tma_copy_bytes, + cta_layout_vmnk=cta_layout_mnk, + ) + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mnk) > 1: + cute.arch.cluster_arrive_relaxed() + + # /////////////////////////////////////////////////////////////////////////////// + # Generate smem tensor A/B + # /////////////////////////////////////////////////////////////////////////////// + sa = storage.sa.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + sb = storage.sb.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + sc_ptr = cute.recast_ptr( + sa.iterator, epi_smem_layout_staged.inner, dtype=self.c_dtype + ) + sc = cute.make_tensor(sc_ptr, epi_smem_layout_staged.outer) + + # /////////////////////////////////////////////////////////////////////////////// + # Local_tile partition global tensors + # /////////////////////////////////////////////////////////////////////////////// + # (bM, bK, loopK) + gA_mkl = cute.local_tile( + mA_mkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, None, 1) + ) + # (bN, bK, loopK) + gB_nkl = cute.local_tile( + mB_nkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1) + ) + # (bM, bN) + gC_mnl = cute.local_tile( + mC_mnl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, 1, None) + ) + + # ////////////////////////////////////////////////////////////////////////////// + # Partition global tensor for TiledMMA_A/B/C + # ////////////////////////////////////////////////////////////////////////////// + warp_group_idx = cute.arch.make_warp_uniform( + tidx // self.num_threads_per_warp_group + ) + warp_group_thread_layout = cute.make_layout( + self.mma_warp_groups, stride=self.num_threads_per_warp_group + ) + thr_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)) + + tCgC = thr_mma.partition_C(gC_mnl) + + # ////////////////////////////////////////////////////////////////////////////// + # Partition shared tensor for TMA load A/B + # ////////////////////////////////////////////////////////////////////////////// + # TMA load A partition_S/D + a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape) + a_cta_crd = cluster_coord_mnk[1] + sa_for_tma_partition = cute.group_modes(sa, 0, 2) + gA_for_tma_partition = cute.group_modes(gA_mkl, 0, 2) + tAsA, tAgA_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_a, + a_cta_crd, + a_cta_layout, + sa_for_tma_partition, + gA_for_tma_partition, + ) + + # TMA load B partition_S/D + b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape) + b_cta_crd = cluster_coord_mnk[0] + sb_for_tma_partition = cute.group_modes(sb, 0, 2) + gB_for_tma_partition = cute.group_modes(gB_nkl, 0, 2) + tBsB, tBgB_nkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_b, + b_cta_crd, + b_cta_layout, + sb_for_tma_partition, + gB_for_tma_partition, + ) + + # ////////////////////////////////////////////////////////////////////////////// + # Make frangments + # ////////////////////////////////////////////////////////////////////////////// + tCsA = thr_mma.partition_A(sa) + tCsB = thr_mma.partition_B(sb) + tCrA = tiled_mma.make_fragment_A(tCsA) + tCrB = tiled_mma.make_fragment_B(tCsB) + + acc_shape = tCgC.shape + accumulators = cute.make_fragment(acc_shape, self.acc_dtype) + + # /////////////////////////////////////////////////////////////////////////////// + # Cluster wait + # /////////////////////////////////////////////////////////////////////////////// + # cluster wait for barrier init + if cute.size(self.cluster_shape_mnk) > 1: + cute.arch.cluster_wait() + else: + cute.arch.sync_threads() + # ///////////////////////////////////////////////////////////////////////////// + # Prefetch + # ///////////////////////////////////////////////////////////////////////////// + k_tile_cnt = cute.size(gA_mkl, mode=[2]) + prefetch_k_tile_cnt = cutlass.max(cutlass.min(self.ab_stage, k_tile_cnt), 0) + + mainloop_producer_state = utils.make_pipeline_state( + utils.PipelineUserType.Producer, self.ab_stage + ) + if warp_idx == 0: + # ///////////////////////////////////////////////////////////////////////////// + # Prefetch TMA load + # ///////////////////////////////////////////////////////////////////////////// + for prefetch_idx in cutlass.range_dynamic(prefetch_k_tile_cnt, unroll=1): + # ///////////////////////////////////////////////////////////////////////////// + # Wait for A/B buffers to be empty before loading into them + # Also sets the transaction barrier for the A/B buffers + # ///////////////////////////////////////////////////////////////////////////// + mainloop_pipeline.producer_acquire(mainloop_producer_state) + # ///////////////////////////////////////////////////////////////////////////// + # Slice to global/shared memref to current k_tile + # ///////////////////////////////////////////////////////////////////////////// + tAgA_k = tAgA_mkl[(None, mainloop_producer_state.count)] + tAsA_pipe = tAsA[(None, mainloop_producer_state.index)] + + tBgB_k = tBgB_nkl[(None, mainloop_producer_state.count)] + tBsB_pipe = tBsB[(None, mainloop_producer_state.index)] + + # ///////////////////////////////////////////////////////////////////////////// + # TMA load A/B + # ///////////////////////////////////////////////////////////////////////////// + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier( + mainloop_producer_state + ), + mcast_mask=a_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier( + mainloop_producer_state + ), + mcast_mask=b_mcast_mask, + ) + # Mainloop pipeline's producer commit is a NOP + mainloop_pipeline.producer_commit(mainloop_producer_state) + mainloop_producer_state.advance() + + # ///////////////////////////////////////////////////////////////////////////// + # Prologue MMAs + # ///////////////////////////////////////////////////////////////////////////// + k_pipe_mmas = 1 + + mainloop_consumer_read_state = utils.make_pipeline_state( + utils.PipelineUserType.Consumer, self.ab_stage + ) + mainloop_consumer_release_state = utils.make_pipeline_state( + utils.PipelineUserType.Consumer, self.ab_stage + ) + + peek_ab_full_status = cutlass.Boolean(1) + if mainloop_consumer_read_state.count < k_tile_cnt: + peek_ab_full_status = mainloop_pipeline.consumer_try_wait( + mainloop_consumer_read_state + ) + + tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False) + num_k_blocks = cute.size(tCrA, mode=[2]) + for k_tile in cutlass.range_dynamic(k_pipe_mmas, unroll=1): + # Wait for A/B buffer to be ready + mainloop_pipeline.consumer_wait( + mainloop_consumer_read_state, peek_ab_full_status + ) + + cute.nvgpu.warpgroup.fence() + for k_block_idx in range(num_k_blocks): + k_block_coord = ( + None, + None, + k_block_idx, + mainloop_consumer_read_state.index, + ) + tCrA_1phase = tCrA[k_block_coord] + tCrB_1phase = tCrB[k_block_coord] + + cute.gemm( + tiled_mma, + accumulators, + tCrA_1phase, + tCrB_1phase, + accumulators, + ) + tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) + + cute.nvgpu.warpgroup.commit_group() + mainloop_consumer_read_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if mainloop_consumer_read_state.count < k_tile_cnt: + peek_ab_full_status = mainloop_pipeline.consumer_try_wait( + mainloop_consumer_read_state + ) + + # ///////////////////////////////////////////////////////////////////////////// + # MAINLOOP + # ///////////////////////////////////////////////////////////////////////////// + for k_tile in cutlass.range_dynamic(k_pipe_mmas, k_tile_cnt, 1, unroll=1): + # ///////////////////////////////////////////////////////////////////////////// + # Wait for TMA copies to complete + # ///////////////////////////////////////////////////////////////////////////// + mainloop_pipeline.consumer_wait( + mainloop_consumer_read_state, peek_ab_full_status + ) + # ///////////////////////////////////////////////////////////////////////////// + # WGMMA + # ///////////////////////////////////////////////////////////////////////////// + cute.nvgpu.warpgroup.fence() + for k_block_idx in range(num_k_blocks): + k_block_coord = ( + None, + None, + k_block_idx, + mainloop_consumer_read_state.index, + ) + tCrA_1phase = tCrA[k_block_coord] + tCrB_1phase = tCrB[k_block_coord] + + cute.gemm( + tiled_mma, + accumulators, + tCrA_1phase, + tCrB_1phase, + accumulators, + ) + + cute.nvgpu.warpgroup.commit_group() + # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete + cute.nvgpu.warpgroup.wait_group(k_pipe_mmas) + + mainloop_pipeline.consumer_release(mainloop_consumer_release_state) + + mainloop_consumer_read_state.advance() + mainloop_consumer_release_state.advance() + + peek_ab_full_status = cutlass.Boolean(1) + if mainloop_consumer_read_state.count < k_tile_cnt: + peek_ab_full_status = mainloop_pipeline.consumer_try_wait( + mainloop_consumer_read_state + ) + # ///////////////////////////////////////////////////////////////////////////// + # TMA load + # ///////////////////////////////////////////////////////////////////////////// + if warp_idx == 0 and mainloop_producer_state.count < k_tile_cnt: + # ///////////////////////////////////////////////////////////////////////////// + # Wait for A/B buffers to be empty before loading into them + # Also sets the transaction barrier for the A/B buffers + # ///////////////////////////////////////////////////////////////////////////// + mainloop_pipeline.producer_acquire(mainloop_producer_state) + + # ///////////////////////////////////////////////////////////////////////////// + # Slice to global/shared memref to current k_tile + # ///////////////////////////////////////////////////////////////////////////// + tAgA_k = tAgA_mkl[(None, mainloop_producer_state.count)] + tAsA_pipe = tAsA[(None, mainloop_producer_state.index)] + + tBgB_k = tBgB_nkl[(None, mainloop_producer_state.count)] + tBsB_pipe = tBsB[(None, mainloop_producer_state.index)] + + # ///////////////////////////////////////////////////////////////////////////// + # TMA load A/B + # ///////////////////////////////////////////////////////////////////////////// + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier( + mainloop_producer_state + ), + mcast_mask=a_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier( + mainloop_producer_state + ), + mcast_mask=b_mcast_mask, + ) + # Mainloop pipeline's producer commit is a NOP + mainloop_pipeline.producer_commit(mainloop_producer_state) + mainloop_producer_state.advance() + + # ///////////////////////////////////////////////////////////////////////////// + # EPILOG + # ///////////////////////////////////////////////////////////////////////////// + cute.nvgpu.warpgroup.wait_group(0) + + if cute.size(self.cluster_shape_mnk) > 1: + # Wait for all threads in the cluster to finish, avoid early release of smem + cute.arch.cluster_arrive() + cute.arch.cluster_wait() + else: + # For cluster that has a single thread block, it might have more than one warp groups. + # Wait for all warp groups in the thread block to finish, because smem for tensor A in + # the mainloop is reused in the epilogue. + cute.arch.sync_threads() + + copy_atom_r2s = sm90_utils.sm90_get_smem_store_op( + self.c_layout, + elem_ty_d=self.c_dtype, + elem_ty_acc=self.acc_dtype, + ) + + copy_atom_C = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp( + self.c_layout.is_m_major_c(), + 4, + ), + self.c_dtype, + ) + + tiled_copy_C_Atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma) + + tiled_copy_r2s = cute.make_tiled_copy_S( + copy_atom_r2s, + tiled_copy_C_Atom, + ) + + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sD = thr_copy_r2s.partition_D(sc) + # (R2S, R2S_M, R2S_N) + tRS_rAcc = tiled_copy_r2s.retile(accumulators) + + # Allocate D registers. + rD_shape = cute.shape(thr_copy_r2s.partition_S(sc)) + tRS_rD_layout = cute.make_layout(rD_shape[:3]) + tRS_rD = cute.make_fragment_like(tRS_rD_layout, self.acc_dtype) + size_tRS_rD = cute.size(tRS_rD) + + sepi_for_tma_partition = cute.group_modes(sc, 0, 2) + tcgc_for_tma_partition = cute.zipped_divide(gC_mnl, self.epi_tile) + + bSG_sD, bSG_gD = cute.nvgpu.cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sepi_for_tma_partition, + tcgc_for_tma_partition, + ) + + epi_tile_num = cute.size(tcgc_for_tma_partition, mode=[1]) + epi_tile_shape = tcgc_for_tma_partition.shape[1] + + for epi_idx in cutlass.range_dynamic(epi_tile_num, unroll=epi_tile_num): + # Copy from accumulators to D registers + for epi_v in range(size_tRS_rD): + tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v] + + # Type conversion + tRS_rD_out = cute.make_fragment_like(tRS_rD_layout, self.c_dtype) + acc_vec = tRS_rD.load() + tRS_rD_out.store(acc_vec.to(self.c_dtype)) + + # Copy from D registers to shared memory + epi_buffer = epi_idx % cute.size(tRS_sD, mode=[3]) + cute.copy( + tiled_copy_r2s, tRS_rD_out, tRS_sD[(None, None, None, epi_buffer)] + ) + + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + # barrier for sync + cute.arch.barrier() + + # Get the global memory coordinate for the current epi tile. + epi_tile_layout = cute.make_layout( + epi_tile_shape, stride=(epi_tile_shape[1], 1) + ) + gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) + # Copy from shared memory to global memory + if warp_idx == 0: + cute.copy( + tma_atom_c, + bSG_sD[(None, epi_buffer)], + bSG_gD[(None, gmem_coord)], + ) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(self.epi_stage - 1, read=True) + + cute.arch.barrier() + + return + + @staticmethod + def _compute_stages( + tile_shape_mnk: tuple[int, int, int], + a_dtype: type[cutlass.Numeric], + b_dtype: type[cutlass.Numeric], + smem_capacity: int, + occupancy: int, + ) -> tuple[int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type tile_shape_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (A/B operand stages, epilogue stages) + :rtype: tuple[int, int] + """ + + epi_stage = 4 + # epi_smem will reuse smem ab. + epi_bytes = 0 + + a_shape = cute.slice_(tile_shape_mnk, (None, 0, None)) + b_shape = cute.slice_(tile_shape_mnk, (0, None, None)) + ab_bytes_per_stage = ( + cute.size(a_shape) * a_dtype.width // 8 + + cute.size(b_shape) * b_dtype.width // 8 + ) + mbar_helpers_bytes = 1024 + + ab_stage = ( + (smem_capacity - occupancy * 1024) // occupancy + - mbar_helpers_bytes + - epi_bytes + ) // ab_bytes_per_stage + return ab_stage, epi_stage + + @staticmethod + def _sm90_compute_tile_shape_or_override( + tile_shape_mnk: tuple[int, int, int], + element_type: type[cutlass.Numeric], + is_cooperative: bool = False, + epi_tile_override: tuple[int, int] | None = None, + ) -> tuple[int, int]: + """Compute the epilogue tile shape or use override if provided. + + :param tile_shape_mnk: CTA tile shape (M,N,K) + :type tile_shape_mnk: Tuple[int, int, int] + :param element_type: Data type of elements + :type element_type: type[cutlass.Numeric] + :param is_cooperative: Whether to use cooperative approach + :type is_cooperative: bool + :param epi_tile_override: Optional override for epilogue tile shape + :type epi_tile_override: Tuple[int, int] or None + + :return: Computed epilogue tile shape + :rtype: Tuple[int, int] + """ + if epi_tile_override is not None: + return epi_tile_override + if is_cooperative: + tile_m = min(128, cute.size(tile_shape_mnk, mode=[0])) + tile_n = min(32, cute.size(tile_shape_mnk, mode=[1])) + return (tile_m, tile_n) + else: + n_perf = 64 if element_type.width == 8 else 32 + tile_m = min(64, cute.size(tile_shape_mnk, mode=[0])) + tile_n = min(n_perf, cute.size(tile_shape_mnk, mode=[1])) + return (tile_m, tile_n) + + @staticmethod + def _make_smem_layouts( + tile_shape_mnk: tuple[int, int, int], + epi_tile: tuple[int, int], + a_dtype: type[cutlass.Numeric], + a_layout: utils.LayoutEnum, + b_dtype: type[cutlass.Numeric], + b_layout: utils.LayoutEnum, + ab_stage: int, + c_dtype: type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + epi_stage: int, + ) -> tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]: + """Create shared memory layouts for A, B, and C tensors. + + :param tile_shape_mnk: CTA tile shape (M,N,K) + :type tile_shape_mnk: Tuple[int, int, int] + :param epi_tile: Epilogue tile shape + :type epi_tile: Tuple[int, int] + :param a_dtype: Data type for matrix A + :type a_dtype: type[cutlass.Numeric] + :param a_layout: Layout enum for matrix A + :type a_layout: utils.LayoutEnum + :param b_dtype: Data type for matrix B + :type b_dtype: type[cutlass.Numeric] + :param b_layout: Layout enum for matrix B + :type b_layout: utils.LayoutEnum + :param ab_stage: Number of stages for A/B tensors + :type ab_stage: int + :param c_dtype: Data type for output matrix C + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum for the output matrix C + :type c_layout: utils.LayoutEnum + :param epi_stage: Number of epilogue stages + :type epi_stage: int + + :return: Tuple of shared memory layouts for A, B, and C + :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout] + """ + a_smem_shape = cute.slice_(tile_shape_mnk, (None, 0, None)) + + a_is_k_major = ( + a_layout.sm90_mma_major_mode() == cute.nvgpu.warpgroup.OperandMajorMode.K + ) + b_is_k_major = ( + b_layout.sm90_mma_major_mode() == cute.nvgpu.warpgroup.OperandMajorMode.K + ) + a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0] + a_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom( + a_layout, + a_dtype, + a_major_mode_size, + ), + a_dtype, + ) + a_smem_layout_staged = cute.tile_to_shape( + a_smem_layout_atom, + cute.append(a_smem_shape, ab_stage), + order=(0, 1, 2) if a_is_k_major else (1, 0, 2), + ) + + b_smem_shape = cute.slice_(tile_shape_mnk, (0, None, None)) + + b_major_mode_size = tile_shape_mnk[2 if b_is_k_major else 1] + b_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom( + b_layout, + b_dtype, + b_major_mode_size, + ), + b_dtype, + ) + b_smem_layout_staged = cute.tile_to_shape( + b_smem_layout_atom, + cute.append(b_smem_shape, ab_stage), + order=(0, 1, 2) if b_is_k_major else (1, 0, 2), + ) + + c_smem_shape = epi_tile + c_major_mode_size = epi_tile[1] if c_layout.is_n_major_c() else epi_tile[0] + c_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom( + c_layout, + c_dtype, + c_major_mode_size, + ), + c_dtype, + ) + epi_smem_layout_staged = cute.tile_to_shape( + c_smem_layout_atom, + cute.append(c_smem_shape, epi_stage), + order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2), + ) + + return a_smem_layout_staged, b_smem_layout_staged, epi_smem_layout_staged + + @staticmethod + def _compute_grid( + c: cute.Tensor, + tile_shape_mnk: tuple[int, int, int], + cluster_shape_mnk: tuple[int, int, int], + ) -> tuple[int, int, int]: + """Compute grid shape for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mnk: Shape of each cluster in M, N, K dimensions. + :type cluster_shape_mnk: tuple[int, int, int] + + :return: Grid shape for kernel launch. + :rtype: tuple[int, int, int] + """ + + c_shape = (tile_shape_mnk[0], tile_shape_mnk[1]) + gc = cute.zipped_divide(c, tiler=c_shape) + clusters = cute.ceil_div(cute.get(gc.layout, mode=[1]).shape, cluster_shape_mnk) + grid = tuple(x * y for x, y in zip(clusters, cluster_shape_mnk)) + return grid + + @staticmethod + def _make_tma_store_atoms_and_tensors( + tensor_c: cute.Tensor, + epi_smem_layout_staged: cute.ComposedLayout, + epi_tile: tuple[int, int], + ) -> tuple[cute.CopyAtom, cute.Tensor]: + """Create TMA atoms and tensors for C tensor storage. + + :param tensor_c: Output tensor C + :type tensor_c: cute.Tensor + :param epi_smem_layout_staged: Shared memory layout for epilogue + :type epi_smem_layout_staged: cute.ComposedLayout + :param epi_tile: Epilogue tile shape + :type epi_tile: Tuple[int, int] + + :return: TMA atom and tensor for C + :rtype: Tuple[cute.CopyAtom, cute.Tensor] + """ + epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0)) + c_cta_v_layout = cute.composition( + cute.make_identity_layout(tensor_c.shape), epi_tile + ) + tma_atom_c, tma_tensor_c = cute.nvgpu.cpasync.make_tma_tile_atom( + cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(), + tensor_c, + epi_smem_layout, + c_cta_v_layout, + ) + + return tma_atom_c, tma_tensor_c + + @staticmethod + def _make_tma_atoms_and_tensors( + tensor: cute.Tensor, + smem_layout_staged: cute.ComposedLayout, + smem_tile: tuple[int, int], + mcast_dim: int, + ) -> tuple[cute.CopyAtom, cute.Tensor]: + """Create TMA atoms and tensors for input tensors. + + :param tensor: Input tensor (A or B) + :type tensor: cute.Tensor + :param smem_layout_staged: Shared memory layout for the tensor + :type smem_layout_staged: cute.ComposedLayout + :param smem_tile: Shared memory tile shape + :type smem_tile: Tuple[int, int] + :param mcast_dim: Multicast dimension + :type mcast_dim: int + + :return: TMA atom and tensor + :rtype: Tuple[cute.CopyAtom, cute.Tensor] + """ + op = ( + cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() + if mcast_dim == 1 + else cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp() + ) + + smem_layout = cute.slice_(smem_layout_staged, (None, None, 0)) + tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tma_tile_atom( + op, + tensor, + smem_layout, + smem_tile, + num_multicast=mcast_dim, + ) + return tma_atom, tma_tensor + + @staticmethod + def is_valid_dtypes( + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + ) -> bool: + """ + Check if the dtypes are valid + + :param a_dtype: The data type of tensor A + :type a_dtype: Type[cutlass.Numeric] + :param b_dtype: The data type of tensor B + :type b_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: major mode of tensor A + :type a_major: str + :param b_major: major mode of tensor B + :type b_major: str + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + # tested a_dtype + if a_dtype not in { + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + # tested b_dtype + if b_dtype not in { + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + # tested acc_dtype + if acc_dtype != cutlass.Float32: + is_valid = False + # tested c_dtype + if c_dtype not in { + cutlass.Float32, + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + # make sure a_dtype == b_dtype for Float16 + if a_dtype.width == 16 and a_dtype != b_dtype: + is_valid = False + # make sure a_dtype.width == b_dtype.width (i.e, Float8E4M3FN or Float8E5M2) + if a_dtype.width != b_dtype.width: + is_valid = False + + # for Float8 types, this implementation only supports k-major layout + if (a_dtype.width == 8 and a_major != "k") or ( + b_dtype.width == 8 and b_major != "k" + ): + is_valid = False + + return is_valid + + +def run_dense_gemm( + mnkl: Tuple[int, int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + tolerance: float, +): + """ + Prepare A/B/C tensors, launch GPU kernel, and reference checking. + """ + + print(f"Running Hopper Dense GEMM with:") + print(f"mnkl: {mnkl}") + print( + f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}" + ) + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Tile Shape: {tile_shape_mnk}, Cluster Shape: {cluster_shape_mn}") + print(f"Tolerance: {tolerance}") + + # Unpack parameters + m, n, k, l = mnkl + cluster_shape_mnk = (*cluster_shape_mn, 1) + + # Skip unsupported types + if not HopperWgmmaGemmKernel.is_valid_dtypes( + a_dtype, b_dtype, acc_dtype, c_dtype, a_major, b_major + ): + raise TypeError( + f"Skipping due to unsupported combination of types and majors: {a_dtype}, {b_dtype}, {acc_dtype}, {c_dtype}, {a_major=}, {b_major=}" + ) + + # Prepare pytorch tensors: A, B (random from 0 to 2) and C (all zero) + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # Create and permute tensor A/B/C + def create_and_permute_tensor( + l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True + ): + # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) + # else : (l, mode0, mode1) -> (mode0, mode1, l) + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + is_unsigned = dtype in {cutlass.Uint8} + # Temporarily use uint8 as torch does not support fp8 type + torch_dtype = ( + cutlass_torch.dtype(dtype) + if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} + else torch.uint8 + ) + + # Create dtype torch tensor (cpu) + torch_tensor_cpu = cutlass.torch.create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=cutlass.torch.RandomInitConfig( + min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2 + ), + ) + # Create dtype torch tensor (gpu) + torch_tensor = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + cute_tensor = from_dlpack(torch_tensor, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic( + leading_dim=(0 if is_mode0_major else 1) + ) + cute_tensor = cutlass.torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor + + a, mA, a_torch = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype) + b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype) + c, mC, c_torch = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype) + + gemm = HopperWgmmaGemmKernel(acc_dtype, tile_shape_mnk, cluster_shape_mnk) + + torch_stream = torch.cuda.Stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + # compile gemm kernel + compiled_gemm = cute.compile(gemm, mA, mB, mC, stream) + # execution + compiled_gemm(mA, mB, mC, stream) + + torch.cuda.synchronize() + + # Ref check + ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu() + + if c_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2): + # m major: (l, n, m) -> (m, n, l) + # k major: (l, m, n) -> (m, n, l) + permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) + shape = (l, m, n) if c_major == "n" else (l, n, m) + f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch.uint8, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.SKIP, + ).cuda() + # Create dtype cute tensor (gpu) + ref_c_tensor = from_dlpack( + f8_torch_tensor, assumed_align=16 + ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) + ref_c_tensor.element_type = c_dtype + ref_c_tensor = cutlass_torch.convert_cute_tensor( + ref, + ref_c_tensor, + c_dtype, + is_dynamic_layout=True, + ) + ref_c = f8_torch_tensor.cpu() + else: + ref_c = ref.to(cutlass_torch.dtype(c_dtype)) + + torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03) + + +if __name__ == "__main__": + args = parse_arguments() + run_dense_gemm( + args.mnkl, + args.a_dtype, + args.b_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.tile_shape_mnk, + args.cluster_shape_mn, + args.tolerance, + ) + print("PASS") diff --git a/examples/python/CuTeDSL/notebooks/hello_world.ipynb b/examples/python/CuTeDSL/notebooks/hello_world.ipynb index 47719ae6..e722d828 100644 --- a/examples/python/CuTeDSL/notebooks/hello_world.ipynb +++ b/examples/python/CuTeDSL/notebooks/hello_world.ipynb @@ -83,11 +83,6 @@ "\n", " # Print hello world from host code\n", " cute.printf(\"hello world\")\n", - " \n", - " # Initialize CUDA context for launching a kernel with error checking\n", - " # We make context initialization explicit to allow users to control the context creation \n", - " # and avoid potential issues with multiple contexts\n", - " cutlass.cuda.initialize_cuda_context()\n", "\n", " # Launch kernel\n", " kernel().launch(\n", @@ -129,6 +124,11 @@ } ], "source": [ + "# Initialize CUDA context for launching a kernel with error checking\n", + "# We make context initialization explicit to allow users to control the context creation \n", + "# and avoid potential issues with multiple contexts\n", + "cutlass.cuda.initialize_cuda_context()\n", + "\n", "# Method 1: Just-In-Time (JIT) compilation - compiles and runs the code immediately\n", "print(\"Running hello_world()...\")\n", "hello_world()\n", @@ -136,6 +136,7 @@ "# Method 2: Compile first (useful if you want to run the same code multiple times)\n", "print(\"Compiling...\")\n", "hello_world_compiled = cute.compile(hello_world)\n", + "\n", "# Run the pre-compiled version\n", "print(\"Running compiled version...\")\n", "hello_world_compiled()" diff --git a/include/cute/algorithm/axpby.hpp b/include/cute/algorithm/axpby.hpp index 8c54f46d..60d5b468 100644 --- a/include/cute/algorithm/axpby.hpp +++ b/include/cute/algorithm/axpby.hpp @@ -33,7 +33,6 @@ #include #include -#include namespace cute { @@ -45,7 +44,7 @@ template + class PrdTensor = constant_fn> CUTE_HOST_DEVICE void axpby(Alpha const& alpha, @@ -64,7 +63,7 @@ template + class PrdTensor = constant_fn> CUTE_HOST_DEVICE void axpby(Alpha const& alpha, diff --git a/include/cute/algorithm/cooperative_copy.hpp b/include/cute/algorithm/cooperative_copy.hpp index 8a23bf60..2653916e 100644 --- a/include/cute/algorithm/cooperative_copy.hpp +++ b/include/cute/algorithm/cooperative_copy.hpp @@ -36,7 +36,6 @@ #include // cute::Swizzle #include // cute::get_nonswizzle_portion #include // cute::Tensor -#include #include #include diff --git a/include/cute/algorithm/copy.hpp b/include/cute/algorithm/copy.hpp index 9700b3f2..d05b170f 100644 --- a/include/cute/algorithm/copy.hpp +++ b/include/cute/algorithm/copy.hpp @@ -32,7 +32,6 @@ #include // CUTE_HOST_DEVICE #include // cute::Tensor -#include // cute::TrivialPredTensor #include // cute::Copy_Atom namespace cute @@ -66,10 +65,45 @@ copy_if(PrdTensor const& pred, // copy_if -- Predicated CopyAtom // +// Predicate Tensor is an Actual Tensor +template +CUTE_HOST_DEVICE +void +copy_if(Copy_Atom const& copy_atom, + Tensor const& prd, // ([V],Rest...) + Tensor const& src, // ( V, Rest...) + Tensor & dst) // ( V, Rest...) +{ + if constexpr (PrdLayout::rank == SrcLayout::rank - 1) { + // Back-compat ONLY -- Delete? + copy_if(copy_atom, make_tensor(prd.data(), prepend(prd.layout(), Layout<_1,_0>{})), src, dst); + } else { + static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch."); + static_assert(SrcLayout::rank == PrdLayout::rank, "CopyAtom rank-mismatch."); + + if constexpr (SrcLayout::rank == 1) { // Dispatch the copy + copy_atom.call(prd, src, dst); + } else { // Loop over all but the first mode + constexpr int R = SrcLayout::rank; + Tensor prd_v = group_modes<1,R>(prd); + Tensor src_v = group_modes<1,R>(src); + Tensor dst_v = group_modes<1,R>(dst); + CUTE_UNROLL + for (int i = 0; i < size<1>(dst_v); ++i) { + copy_atom.call(prd_v(_,i), src_v(_,i), dst_v(_,i)); + } + } + } +} + template +[[deprecated("Use a bool-tensor or transform-tensor as predication.")]] CUTE_HOST_DEVICE void copy_if(Copy_Atom const& copy_atom, @@ -77,33 +111,14 @@ copy_if(Copy_Atom const& copy_atom, Tensor const& src, // (V,Rest...) Tensor & dst) // (V,Rest...) { - static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch."); - auto has_with_bool = cute::is_valid([](auto t)->void_t().with(true))>{}, copy_atom); - - if constexpr (SrcLayout::rank == 1) { // Dispatch the copy - if constexpr (has_with_bool) { - copy_atom.with(pred()).call(src, dst); - } else { - if (pred()) { copy_atom.call(src, dst); } - } - } else { // Loop over all but the first mode - constexpr int R = SrcLayout::rank; - Tensor src_v = group_modes<1,R>(src); - Tensor dst_v = group_modes<1,R>(dst); - CUTE_UNROLL - for (int i = 0; i < size<1>(dst_v); ++i) { - if constexpr (has_with_bool) { - copy_atom.with(pred(i)).call(src_v(_,i), dst_v(_,i)); - } else { - if (pred(i)) { copy_atom.call(src_v(_,i), dst_v(_,i)); } - } - } - } + Tensor tpred = cute::lazy::transform(make_tensor(counting_iterator{}, replace<0>(shape(dst), _1{})), pred); + return copy_if(copy_atom, tpred, src, dst); } // // copy_if -- AutoCopyAsync // + template @@ -159,7 +174,7 @@ copy(AutoCopyAsync const& cpy, Tensor const& src, // (V,Rest...) Tensor & dst) // (V,Rest...) { - copy_if(cpy, TrivialPredTensor{}, src, dst); + copy_if(cpy, constant_fn{}, src, dst); } // @@ -202,7 +217,7 @@ copy(Copy_Atom const& copy_atom, Tensor dst_c = dst_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest) Tensor src_c = src_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest) - CUTE_STATIC_ASSERT_V(size<1>(src_c) == size<1>(dst_c)); + CUTE_STATIC_ASSERT_V( size<1>(src_c) == size<1>(dst_c)); CUTE_STATIC_ASSERT_V(shape<0>(dst_c) == shape<0>(dst)); CUTE_STATIC_ASSERT_V(shape<0>(src_c) == shape<0>(src)); @@ -224,7 +239,7 @@ copy(Copy_Atom const& copy_atom, //////////////////////////////////////////////////////// // Specialization for AutoVectorizingCopyAssumedAlignment -template CUTE_HOST_DEVICE @@ -234,23 +249,30 @@ copy(AutoVectorizingCopyWithAssumedAlignment const&, Tensor & dst) { constexpr int common_elem = CUTE_STATIC_V(max_common_vector(src, dst)); - constexpr int align_bits = CUTE_STATIC_V(gcd(max_alignment(src), max_alignment(dst), Int{})); - static_assert(is_integral{} * sizeof_bits_v)>::value, "Error: Attempting a subbit copy!"); - constexpr int vec_bits = gcd(common_elem * sizeof_bits_v, align_bits); + static_assert(is_integral{} * sizeof_bits_v)>::value, "Error: Attempting a subbit write!"); - if constexpr (common_elem > 1 && ((vec_bits % 8) == 0)) { - // If more than one element vectorizes to 8bits or more, then recast and copy - using VecType = uint_bit_t; - // Preserve volatility - using SrcVecType = conditional_t, VecType const volatile, VecType const>; - using DstVecType = conditional_t, VecType volatile, VecType >; + if constexpr (common_elem > 1) + { + constexpr int align_bits = CUTE_STATIC_V(gcd(max_alignment(src), max_alignment(dst), Int{})); + constexpr int vec_bits = gcd(common_elem * sizeof_bits_v, align_bits); - // Recast - Tensor src_v = recast(src); - Tensor dst_v = recast(dst); - return copy_if(TrivialPredTensor{}, src_v, dst_v); + if constexpr ((vec_bits % 8) == 0) + { + // If more than one element vectorizes to 8bits or more, then recast and copy + using VecType = uint_bit_t; + // Preserve volatility + using SrcVecType = conditional_t, VecType const volatile, VecType const>; + using DstVecType = conditional_t, VecType volatile, VecType >; + + // Recast + Tensor src_v = recast(src); + Tensor dst_v = recast(dst); + return copy_if(constant_fn{}, src_v, dst_v); + } else { + return copy_if(constant_fn{}, src, dst); + } } else { - return copy_if(TrivialPredTensor{}, src, dst); + return copy_if(constant_fn{}, src, dst); } } @@ -277,7 +299,7 @@ copy(AutoFilter const& copy_op, Tensor src_n = zipped_divide(src, dst_null); CUTE_STATIC_ASSERT_V(cosize<0>(dst_n.layout()) == Int<1>{}, "Nullspace definition error"); - CUTE_STATIC_ASSERT_V(cosize<0>(src_n.layout()) == Int<1>{}, "Error: Ambiguous scatter detected in copy"); + CUTE_STATIC_ASSERT_V(cosize<0>(src_n.layout()) == Int<1>{}, "Error: Ambiguous race-condition detected."); copy(copy_op.base, src_n(Int<0>{},_), dst_n(Int<0>{},_)); } else { @@ -335,6 +357,18 @@ copy(Copy_Atom, Args...> con return copy(AutoVectorizingCopyWithAssumedAlignment{}, src, dst); } +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom>, Args...> const&, + Tensor const& src, + Tensor & dst) +{ + return copy(AutoVectorizingCopyWithAssumedAlignment{}, src, dst); +} + #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) template , CA_Args...> const& atom, - Tensor const& src, - Tensor & dst) + Tensor const& src, + Tensor & dst) { return copy(static_cast const&>(atom), src, dst); } diff --git a/include/cute/algorithm/prefetch.hpp b/include/cute/algorithm/prefetch.hpp index 16bbec51..265da128 100644 --- a/include/cute/algorithm/prefetch.hpp +++ b/include/cute/algorithm/prefetch.hpp @@ -90,18 +90,19 @@ constexpr bool has_prefetch> = true; } // end namespace detail -template CUTE_HOST_DEVICE void -prefetch(Copy_Atom, CA_Args...> const& atom, - Tensor const& src) +prefetch(Copy_Atom, CopyType> const& atom, + Tensor const& src) { if constexpr (detail::has_prefetch) { using Prefetch_Traits = Copy_Traits; - using Prefetch_Atom = Copy_Atom; + using Prefetch_Atom = Copy_Atom; Prefetch_Atom prefetch_atom{atom}; - auto& dst = const_cast&>(src); // dst is ignored for prefetch atoms + //auto& dst = const_cast&>(src); // dst is ignored for prefetch atoms + Tensor dst = make_tensor(make_smem_ptr(nullptr), shape(src)); return copy(prefetch_atom, src, dst); } else { return prefetch(src); diff --git a/include/cute/algorithm/tensor_algorithms.hpp b/include/cute/algorithm/tensor_algorithms.hpp index 6359a55e..f47becb1 100644 --- a/include/cute/algorithm/tensor_algorithms.hpp +++ b/include/cute/algorithm/tensor_algorithms.hpp @@ -163,4 +163,16 @@ transform(Tensor const& tensor_in1, return transform(tensor_in1, tensor_in2, tensor_out, op); } +namespace lazy { + +template +CUTE_HOST_DEVICE constexpr +auto +transform(cute::Tensor const& t, Fn const& fn) +{ + return cute::make_tensor(cute::make_transform_iter(fn, t.data()), t.layout()); +} + +} // end namespace lazy + } // end namespace cute diff --git a/include/cute/algorithm/tensor_reduce.hpp b/include/cute/algorithm/tensor_reduce.hpp new file mode 100644 index 00000000..a6f13735 --- /dev/null +++ b/include/cute/algorithm/tensor_reduce.hpp @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include + +namespace cute +{ + +// Reduce @src tensor using binary reduction operator @op and initial value @init and return a scalar. +template +CUTE_HOST_DEVICE constexpr +T +reduce(Tensor const& src, T init, BinaryOp op = {}) +{ + for (auto i = 0; i < size(src); ++i) { + init = op(init, src(i)); + } + return init; +} + +// Reduce @src tensor RedMode using binary reduction operator @op and store the result in @dst tensor +// for each index in @dst/BatchMode. +// @pre @src tensor has rank 2 +// @pre size of @src batch mode is equal to size of @dst batch mode +template +CUTE_HOST_DEVICE constexpr +void +batch_reduce(Tensor const& src, // (RedMode, BatchMode) + Tensor & dst, // (BatchMode) + BinaryOp op = {}) +{ + // Precondition + CUTE_STATIC_ASSERT_V(rank(src) == Int<2>{}); + assert(size<1>(src) == size(dst)); + + for (int i = 0; i < size(dst); ++i) { + dst(i) = reduce(src(_,i), dst(i), op); + } +} + + +// Reduce @src tensor along selected modes specified in @target_profile using binary reduction operator @op +// and store the result in @dst tensor. @target_profile is a tuple where '_' indicates modes to keep and +// integers indicates modes to reduce. +// @pre @target_profile is compatible with @src layout +template +CUTE_HOST_DEVICE constexpr +void +logical_reduce(Tensor const& src, + Tensor & dst, + TargetProfile const& target_profile, + BinaryOp op = {}) +{ + // Precondition + assert(compatible(target_profile, shape(src))); + + auto diced_layout = dice(target_profile, src.layout()); + auto sliced_layout = slice(target_profile, src.layout()); + + auto red_mode = conditional_return{}>(Layout<_1,_0>{}, diced_layout); + auto batch_mode = conditional_return{}>(Layout<_1,_0>{}, sliced_layout); + + auto src_tensor = make_tensor(src.data(), make_layout(red_mode, batch_mode)); + + batch_reduce(src_tensor, dst, op); +} + +} // end namespace cute diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index 2d4eac73..5c455cc3 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -123,6 +123,56 @@ struct Copy_Atom, CopyInternalType> { return call(src, dst); } + + // Check and call instruction, or recurse + template + CUTE_HOST_DEVICE + void + call(Tensor const& prd, + Tensor const& src, + Tensor & dst) const + { + static_assert(PLayout::rank == 1, "Expected rank-1 prd tensor"); + static_assert(SLayout::rank == 1, "Expected rank-1 src tensor"); + static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor"); + + if constexpr (is_constant::value || + is_constant::value) { + // Dispatch to unpack to execute instruction + Traits const& traits = static_cast(*this); + auto has_with_bool = cute::is_valid([](auto t)->void_t{}, traits); + if constexpr (has_with_bool) { + copy_unpack(traits.with(prd(Int<0>{})), src, dst); + } else { + if (prd(Int<0>{})) { copy_unpack(traits, src, dst); } + } + } else if constexpr (is_tuple::value && + is_tuple::value && + is_tuple::value) { + // If the size of the src/dst doesn't match the instruction, + // recurse this rank-1 layout by peeling off the mode + // ((A,B,C,...)) -> (A,B,C,...) + return copy_if(*this, tensor<0>(prd), tensor<0>(src), tensor<0>(dst)); + } else { + static_assert(dependent_false, + "CopyAtom: Src/Dst partitioning does not match the instruction requirement."); + } + } + + // Accept mutable temporaries + template + CUTE_HOST_DEVICE + void + call(Tensor const& prd, + Tensor const& src, + Tensor && dst) const + { + return call(prd, src, dst); + } }; // @@ -733,13 +783,13 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and #include #include #include -#include +#include // Config #if (__CUDACC_VER_MAJOR__ >= 12) # define CUTE_COPY_ATOM_TMA_SM90_ENABLED -# define CUTE_COPY_ATOM_TMA_SM100_ENABLED +# define CUTE_COPY_ATOM_TMA_SM100_ENABLED #endif diff --git a/include/cute/pointer_base.hpp b/include/cute/pointer_base.hpp index 701fe135..740cc1b7 100644 --- a/include/cute/pointer_base.hpp +++ b/include/cute/pointer_base.hpp @@ -235,6 +235,63 @@ raw_pointer_cast(counting_iterator const& x) { return x.n_; } +// +// transform_iterator +// + +template +struct transform_iter +{ + using iterator = Iter; + // using reference = typename iterator_traits::reference; + // using element_type = typename iterator_traits::element_type; + // using value_type = typename iterator_traits::value_type; + + Fn fn_; + iterator ptr_; + + CUTE_HOST_DEVICE constexpr + transform_iter(Fn fn, iterator ptr = {}) : fn_(fn), ptr_(ptr) {} + + CUTE_HOST_DEVICE constexpr + decltype(auto) operator*() const { return fn_(*ptr_); } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator[](Index const& i) const { return fn_(ptr_[i]); } + + template + CUTE_HOST_DEVICE constexpr + auto operator+(Index const& i) const { return transform_iter{fn_, ptr_+i}; } + + template + CUTE_HOST_DEVICE constexpr + friend bool operator==(transform_iter const& x, transform_iter const& y) { return x.ptr_ == y.ptr_; } + template + CUTE_HOST_DEVICE constexpr + friend bool operator!=(transform_iter const& x, transform_iter const& y) { return x.ptr_ != y.ptr_; } + template + CUTE_HOST_DEVICE constexpr + friend bool operator< (transform_iter const& x, transform_iter const& y) { return x.ptr_ < y.ptr_; } + template + CUTE_HOST_DEVICE constexpr + friend bool operator<=(transform_iter const& x, transform_iter const& y) { return x.ptr_ <= y.ptr_; } + template + CUTE_HOST_DEVICE constexpr + friend bool operator> (transform_iter const& x, transform_iter const& y) { return x.ptr_ > y.ptr_; } + template + CUTE_HOST_DEVICE constexpr + friend bool operator>=(transform_iter const& x, transform_iter const& y) { return x.ptr_ >= y.ptr_; } +}; + +template +CUTE_HOST_DEVICE constexpr +auto +make_transform_iter(Fn const& fn, Iterator const& ptr) +{ + return transform_iter(fn,ptr); +} + // // Display utilities // @@ -251,12 +308,24 @@ CUTE_HOST_DEVICE void print(counting_iterator ptr) printf("counting_iter("); print(ptr.n_); printf(")"); } +template +CUTE_HOST_DEVICE void print(transform_iter ptr) +{ + printf("trans_"); print(ptr.ptr_); +} + #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, counting_iterator ptr) { return os << "counting_iter(" << ptr.n_ << ")"; } + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, transform_iter ptr) +{ + return os << "trans_" << ptr.ptr_; +} #endif // !defined(__CUDACC_RTC__) } // end namespace cute diff --git a/include/cutlass/arch/grid_dependency_control.h b/include/cutlass/arch/grid_dependency_control.h index e7defb5d..f1e0200f 100644 --- a/include/cutlass/arch/grid_dependency_control.h +++ b/include/cutlass/arch/grid_dependency_control.h @@ -41,7 +41,8 @@ #include "cutlass/gemm/dispatch_policy.hpp" #ifndef CUTLASS_GDC_ENABLED - #if (defined(CUTLASS_ENABLE_GDC_FOR_SM90) && \ + #if (CUDA_BARRIER_ENABLED && \ + defined(CUTLASS_ENABLE_GDC_FOR_SM90) && \ __CUDACC_VER_MAJOR__ >= 12 && \ defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) #define CUTLASS_GDC_ENABLED diff --git a/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp b/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp index 278f69f9..d3c541c3 100644 --- a/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp +++ b/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp @@ -43,7 +43,6 @@ #include "cute/arch/cluster_sm90.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" #include "cutlass/trace.h" diff --git a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp index be490f97..11eefed9 100644 --- a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp @@ -32,7 +32,6 @@ #include "cutlass/cutlass.h" -#include "cute/tensor_predicate.hpp" #include "cute/arch/cluster_sm90.hpp" #include "cute/arch/copy_sm90.hpp" #include "cute/atom/mma_atom.hpp" @@ -103,7 +102,7 @@ struct CollectiveConv< using PipelineParams = typename MainloopPipeline::Params; using PipelineState = typename cutlass::PipelineState; - + using ProblemShape = ConvProblemShape; static_assert(rank(SmemLayoutA{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)"); @@ -332,7 +331,7 @@ public: TmaTransactionBytes }; } - + template static bool can_implement( @@ -409,7 +408,7 @@ public: if constexpr (ConvOp == conv::Operator::kWgrad) { #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) std::ostringstream os; -#endif +#endif const auto & input_shape = problem_shape.shape_A; const auto & input_stride = problem_shape.stride_A; @@ -431,11 +430,11 @@ public: << "\n input_shape: " << input_shape << "\n input_stride: " << input_stride << "\n"; -#endif +#endif CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed input strides.\n"); #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) CUTLASS_TRACE_HOST(os.str()); -#endif +#endif return false; } @@ -464,7 +463,7 @@ public: CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n"); #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) CUTLASS_TRACE_HOST(os.str()); -#endif +#endif return false; } } @@ -516,8 +515,8 @@ public: /// gA_mk - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k) /// gB_nk - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k) /// The rest of the tensors can be specified as needed by this collective. - /// The dimensions of gA_mk and gA_nk do not contain L to maintain consistency with - /// StrideA and StrideB set up for TMA + /// The dimensions of gA_mk and gA_nk do not contain L to maintain consistency with + /// StrideA and StrideB set up for TMA template CUTLASS_DEVICE auto load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params){ diff --git a/include/cutlass/conv/device/conv_universal_adapter.hpp b/include/cutlass/conv/device/conv_universal_adapter.hpp index 504575ad..d60469f4 100644 --- a/include/cutlass/conv/device/conv_universal_adapter.hpp +++ b/include/cutlass/conv/device/conv_universal_adapter.hpp @@ -303,6 +303,16 @@ public: dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}), cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}), cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{})); + // Dynamic cluster support + [[maybe_unused]] dim3 fallback_cluster = dim3{0,0,0}; + if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 100 || + ConvKernel::ArchTag::kMinComputeCapability == 101) { + if constexpr (!cute::is_static_v) { + fallback_cluster = params.hw_info.cluster_shape_fallback; + cluster = params.hw_info.cluster_shape; + } + } + void* kernel_params[] = {¶ms}; if constexpr (kEnableCudaHostAdapter) { // @@ -313,6 +323,7 @@ public: launch_result = cuda_adapter->launch(grid, cluster, + fallback_cluster, block, smem_size, stream, @@ -338,6 +349,20 @@ public: grid, cluster, block, smem_size, stream, kernel, kernel_params); } } + else { + if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 100 || + ConvKernel::ArchTag::kMinComputeCapability == 101) { + launch_result = ClusterLauncher::launch_with_fallback_cluster( + grid, + cluster, + fallback_cluster, + block, + smem_size, + stream, + kernel, + kernel_params); + } + } } } else { diff --git a/include/cutlass/detail/blockwise_scale_layout.hpp b/include/cutlass/detail/blockwise_scale_layout.hpp index 2d545bbd..c05498c5 100644 --- a/include/cutlass/detail/blockwise_scale_layout.hpp +++ b/include/cutlass/detail/blockwise_scale_layout.hpp @@ -48,7 +48,7 @@ namespace cutlass::detail{ using namespace cute; template -struct Sm100BlockwiseScaleConfig { +struct Sm1xxBlockwiseScaleConfig { using ShapeSFA = Shape, int32_t>, Shape, int32_t>, int32_t>; using ShapeSFB = Shape, int32_t>, Shape, int32_t>, int32_t>; @@ -271,7 +271,18 @@ struct RuntimeBlockwiseScaleConfig { // Sm90 only supports MN major for SFA and SFB for now template -using Sm90BlockwiseScaleConfig = Sm100BlockwiseScaleConfig; +using Sm90BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig; + +template +using Sm100BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig; + +template +using Sm120BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig; + +template +constexpr auto sm90_trivial_blockwise_scale_config(MmaTileShape_MNK) { + return Sm90BlockwiseScaleConfig(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{}; +} template constexpr auto sm100_trivial_blockwise_scale_config(MmaTileShape_MNK) { @@ -279,8 +290,8 @@ constexpr auto sm100_trivial_blockwise_scale_config(MmaTileShape_MNK) { } template -constexpr auto sm90_trivial_blockwise_scale_config(MmaTileShape_MNK) { - return Sm90BlockwiseScaleConfig(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{}; +constexpr auto sm120_trivial_blockwise_scale_config(MmaTileShape_MNK) { + return Sm120BlockwiseScaleConfig(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{}; } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/detail/layout.hpp b/include/cutlass/detail/layout.hpp index 562adc65..e1c1bd6c 100644 --- a/include/cutlass/detail/layout.hpp +++ b/include/cutlass/detail/layout.hpp @@ -371,11 +371,14 @@ template < constexpr int get_input_alignment_bits() { if constexpr (IsF8F6F4SubBytes && sizeof_bits::value == 4) { + // 16U4 format: The inner tensor size dimension should be multiple of 64B. return 64 * 8; } else if constexpr (IsF8F6F4SubBytes && sizeof_bits::value == 6) { + // 16U6 format : The inner tensor size dimension must be a multiple of 96B. return 96 * 8; } + // TMA 16B alignment requirement return 128; } @@ -383,12 +386,11 @@ get_input_alignment_bits() { template constexpr int get_output_alignment_bits() { - if constexpr (sizeof_bits::value == 6) { - // U6 format : The inner tensor size dimension must be a multiple of 96B. + // 16U6 format : The inner tensor size dimension must be a multiple of 96B. return 96 * 8; } - + // TMA 16B alignment requirement return 128; } diff --git a/include/cutlass/epilogue/collective/builders/sm100_builder.inl b/include/cutlass/epilogue/collective/builders/sm100_builder.inl index 176b1f25..b7102165 100644 --- a/include/cutlass/epilogue/collective/builders/sm100_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm100_builder.inl @@ -981,6 +981,9 @@ private: static constexpr bool Is2SmMma = is_base_of_v; static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule"); static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch"); + // C/D should meet TMA alignment requirement if not void + static_assert(detail::is_aligned(), + "C/D Should meet TMA alignment requirement\n"); static constexpr bool DisableDestination = cute::is_void_v; using ElementD = cute::conditional_t,ElementD_>; // prevents void ref breakages diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index ef0d7c4b..a94d9457 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -293,6 +293,9 @@ template < class DispatchPolicy > struct Sm90TmaBuilderImpl { + // C/D should meet TMA alignment requirement if not void + static_assert(detail::is_aligned(), + "C/D Should meet TMA alignment requirement\n"); // Passing void D disables destination store + smem allocation using ElementD = cute::conditional_t, fusion::get_element_aux_t, ElementD_>; diff --git a/include/cutlass/epilogue/collective/builders/sm90_common.inl b/include/cutlass/epilogue/collective/builders/sm90_common.inl index c0a90396..4c259aa8 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_common.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_common.inl @@ -91,6 +91,14 @@ sm90_get_smem_load_op_for_source() { } } +// C/D should meet TMA alignment requirement if not void +template +constexpr bool +is_aligned() { + return (cute::is_void_v || (cute::sizeof_bits_v * AlignmentC) % cutlass::detail::get_output_alignment_bits() == 0) && + (cute::is_void_v || (cute::sizeof_bits_v * AlignmentD) % cutlass::detail::get_output_alignment_bits() == 0); +} + /////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::epilogue::collective::detail diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index 2c72c301..94e43baf 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -217,6 +217,16 @@ struct IsThreadEpilogueOpWithActivation +struct IsThreadEpilogueOpWithPerChannelScaled { + static constexpr bool value = false; +}; + +template +struct IsThreadEpilogueOpWithPerChannelScaled > { + static constexpr bool value = ThreadEpilogueOp::IsPerRowScaleSupported || ThreadEpilogueOp::IsPerColScaleSupported; +}; + template struct IsThreadEpilogueOpWithElementwiseArguments : cute::false_type {}; diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp index e32cdfa4..f5e8fb50 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp @@ -57,7 +57,7 @@ struct IsDefaultFusionOp { }; template< - class ElementD, class ElementCompute, + class ElementD, class ElementCompute, class ElementC, FloatRoundStyle RoundStyle > struct IsDefaultFusionOp< @@ -69,7 +69,7 @@ struct IsDefaultFusionOp< template< class ElementOutput, int Count, class ElementAccumulator, - class ElementCompute, epilogue::thread::ScaleType::Kind Scale, + class ElementCompute, epilogue::thread::ScaleType::Kind Scale, FloatRoundStyle Round, class ElementSource > struct IsDefaultFusionOp< @@ -133,7 +133,7 @@ public: constexpr static int ThreadCount = 128; constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount; constexpr static bool isEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias::value; - + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; constexpr static uint32_t TmaTransactionBytes = 0; @@ -240,7 +240,7 @@ public: Tensor tTR_rAcc = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) Tensor tTR_rC = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) - + Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l) @@ -250,7 +250,7 @@ public: Tensor tTR_rD_frag = make_tensor(shape(tTR_rAcc)); Tensor tTR_rD_src = recast>(coalesce(tTR_rD_frag)); Tensor tR2G_rD_dst = recast>(coalesce(tTR_gD)); - + Tensor tTR_cD_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclD.compose(Int{}))); Tensor tDpD = make_tensor(shape(tR2G_rD_dst)); @@ -325,7 +325,7 @@ public: copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); } // source is not needed, avoid load - else + else { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTR_rAcc); i++) { @@ -382,7 +382,7 @@ public: auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N) Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N) - + Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) @@ -498,7 +498,7 @@ public: // Constructor and Data Members // CUTLASS_DEVICE - CollectiveEpilogue(Params const& params_, SharedStorage& shared_tensors) + CollectiveEpilogue(Params const& params_, SharedStorage& shared_tensors) : fusion_callbacks(params_.thread, shared_tensors.thread) , smem_buffer_ptr(shared_tensors.buffer.data()) , params(params_) {}; @@ -506,7 +506,7 @@ public: protected: FusionCallbacks fusion_callbacks; uint8_t* smem_buffer_ptr; - Params const& params; + Params const& params; public: @@ -543,7 +543,7 @@ public: can_implement( [[maybe_unused]] ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { - + bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); if (!fusion_implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); @@ -636,7 +636,7 @@ public: Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ - problem_shape_mnkl, + problem_shape_mnkl, cta_tile_mnk, cta_coord_mnkl, int(0), @@ -693,20 +693,17 @@ public: } Tensor tTR_cCD_mn = tTR_cCD(_,_,_,epi_m,epi_n); + Tensor tTR_pCD_mn = cute::lazy::transform(tTR_cCD_mn, [&] (auto const& c) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(c, problem_shape_mnl); }); cst_callbacks.begin_loop(epi_m, epi_n); if constexpr (not cute::is_void_v) { if (is_C_load_needed) { using CVecType = uint_bit_t>; - Tensor tTR_cC_frag = tensor<1>(zipped_divide(coalesce(tTR_cCD_mn), mclC.compose(Int{}))); - - auto pred_fn_C = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE { - return elem_less(tTR_cC_frag(coords...), problem_shape_mnl); - }; Tensor tTR_gC_frg = recast(coalesce(tTR_gC(_,_,_,epi_m,epi_n))); Tensor tTR_rC_frg = recast(coalesce(tCrC)); - copy_if(pred_fn_C, tTR_gC_frg, tTR_rC_frg); + Tensor tTR_pC_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclC.compose(Int{}))); + copy_if(tTR_pC_frg, tTR_gC_frg, tTR_rC_frg); } } @@ -717,7 +714,7 @@ public: Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc)); copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc); - + // After the last tmem load, signal that tmem buffer is consumed and empty if (do_acc_release) { cutlass::arch::fence_view_async_tmem_load(); @@ -737,16 +734,11 @@ public: cst_callbacks.end_loop(epi_m, epi_n); - - Tensor tTR_cD_frag = tensor<1>(zipped_divide(coalesce(tTR_cCD_mn), mclD.compose(Int{}))); - auto pred_fn_D = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE { - return elem_less(tTR_cD_frag(coords...), problem_shape_mnl); - }; - using VecType = uint_bit_t>; Tensor tTR_gD_frg = recast(coalesce(tTR_gD(_,_,_,epi_m,epi_n))); Tensor tTR_rD_frg = recast(coalesce(tTR_rD)); - copy_if(pred_fn_D, tTR_rD_frg, tTR_gD_frg); + Tensor tTR_pD_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclD.compose(Int{}))); + copy_if(tTR_pD_frg, tTR_rD_frg, tTR_gD_frg); } // for epi_m } // for epi_n diff --git a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp index 81b08fcb..c2b8d84d 100644 --- a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp +++ b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp @@ -59,7 +59,7 @@ template < > class Epilogue { static_assert(cute::is_same_v || - cute::is_same_v, + cute::is_same_v, "Could not find an epilogue specialization."); }; @@ -141,7 +141,7 @@ public: ElementScalar const* beta_ptr = nullptr; ElementBias const* bias_ptr = nullptr; StrideBias dBias{}; - }; + }; template struct ThreadEpilogueOpArguments< @@ -202,7 +202,7 @@ public: to_underlying_arguments( [[maybe_unused]] ProblemShape const& _, Arguments const& args, - [[maybe_unused]] void* workspace) { + [[maybe_unused]] void* workspace) { typename ThreadEpilogueOp::Params thread_op_args; thread_op_args.alpha = args.thread.alpha; thread_op_args.beta = args.thread.beta; @@ -317,7 +317,7 @@ public: Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - + // Construct a tensor in SMEM that we can partition for rearranging data SharedStorage& storage = *reinterpret_cast(smem_buf); Tensor sAcc = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) @@ -389,10 +389,10 @@ public: Tensor tSR_gBias_flt = filter_zeros(tSR_gBias); Tensor tSR_rBias_flt = filter_zeros(tSR_rBias); Tensor tSR_cD_flt = filter_zeros(tSR_cD, tSR_gBias.stride()); + Tensor tSR_pD_flt = cute::lazy::transform(tSR_cD_flt, [&](auto const& c){ return elem_less(c, take<0,2>(residue_mnk)); }); // Step 0. Copy Bias from GMEM to fragment - auto pred_fn = [&] (auto const&... coords) { return elem_less(tSR_cD_flt(coords...), take<0, 2>(residue_mnk)); }; - copy_if(pred_fn, tSR_gBias_flt, tSR_rBias_flt); + copy_if(tSR_pD_flt, tSR_gBias_flt, tSR_rBias_flt); } } diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index 0e715a78..6aec0e83 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -560,18 +560,18 @@ struct Sm90TreeVisitor< Tensor tC_rAux_vec = recast(tC_rAux); Tensor tC_gAux_vec = recast(tC_gAux); Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int{}))); - auto predicate_fn = [&] (auto&&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tC_cAux_vec(coords...), residue_tC_cAux); }; - copy_if(predicate_fn, tC_rAux_vec, tC_gAux_vec); + Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); + copy_if(tC_pAux_vec, tC_rAux_vec, tC_gAux_vec); } // sub-byte vectorization, must serialize threads else { // Assumes no inter-warp sharing of bytes (most copy layouts should satisfy this) int lane_idx = canonical_lane_idx(); - auto predicate_fn = [&] (auto&&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tC_cAux(coords...), residue_tC_cAux); }; + Tensor tC_pAux = cute::lazy::transform(tC_cAux, [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); CUTLASS_PRAGMA_NO_UNROLL for (int i = 0; i < NumThreadsPerWarp; ++i) { if (lane_idx == i) { - copy_if(predicate_fn, tC_rAux, tC_gAux); + copy_if(tC_pAux, tC_rAux, tC_gAux); } __syncwarp(); } @@ -719,12 +719,12 @@ struct Sm90AuxLoad< Tensor tC_gAux_vec = recast(tC_gAux); Tensor tC_rAux_vec = recast(tC_rAux); Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int{}))); - auto predicate_fn = [&] (auto&&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tC_cAux_vec(coords...), residue_tC_cAux); }; - copy_if(predicate_fn, tC_gAux_vec, tC_rAux_vec); + Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); + copy_if(tC_pAux_vec, tC_gAux_vec, tC_rAux_vec); } else { - auto predicate_fn = [&] (auto&&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tC_cAux(coords...), residue_tC_cAux); }; - copy_if(predicate_fn, tC_gAux, tC_rAux); + Tensor tC_pAux = cute::lazy::transform(tC_cAux, [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); + copy_if(tC_pAux, tC_gAux, tC_rAux); } } } @@ -738,8 +738,8 @@ struct Sm90AuxLoad< } } - auto predicate_fn = [&] (auto&&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tC_cAux(_,_,_,epi_m,epi_n)(coords...), residue_tC_cAux); }; - copy_if(predicate_fn, tC_gAux(_,_,_,epi_m,epi_n), tC_rAux); + Tensor tC_pAux = cute::lazy::transform(tC_cAux(_,_,_,epi_m,epi_n), [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); + copy_if(tC_pAux, tC_gAux(_,_,_,epi_m,epi_n), tC_rAux); } } diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index cd470f84..72afd1e5 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -449,7 +449,7 @@ template < bool EnableNullptr > struct Sm90AuxLoad< - 0, EpilogueTile, Element, LayoutOrStrideMNL, + 0, EpilogueTile, Element, LayoutOrStrideMNL, SmemLayoutAtom, CopyOpS2R, Alignment, EnableNullptr > { using ElementAux = Element; @@ -496,7 +496,7 @@ struct Sm90AuxLoad< CUTLASS_HOST_DEVICE Sm90AuxLoad(Params const& params, SharedStorage const& shared_storage) : params_ptr(¶ms) { } - + Params const* params_ptr; CUTLASS_DEVICE bool @@ -533,7 +533,7 @@ struct Sm90AuxLoad< tC_cAux(cute::forward(tC_cAux)), problem_shape_mnl(problem_shape_mnl), params_ptr(params_ptr) {} - + GTensorG2R tC_gAux; RTensor tC_rAux; CTensorG2R tC_cAux; @@ -551,17 +551,13 @@ struct Sm90AuxLoad< constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){}; constexpr int V = cute::min(Alignment, size(MCL)); - Tensor tC_cAux_mn = tC_cAux(_,_,_,epi_m,epi_n); - Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux_mn), MCL.compose(Int{}))); - Tensor tC_gAux_vec = recast>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); Tensor tC_rAux_vec = recast>(coalesce(tC_rAux)); - auto pred_fn = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE { - return elem_less(tC_cAux_vec(coords...), problem_shape_mnl); - }; + Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux(_,_,_,epi_m,epi_n)), MCL.compose(Int{}))); + Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, problem_shape_mnl); }); - copy_if(pred_fn, tC_gAux_vec, tC_rAux_vec); + copy_if(tC_pAux_vec, tC_gAux_vec, tC_rAux_vec); } template @@ -647,7 +643,7 @@ struct Sm90ScalarBroadcast { can_implement(ProblemShape const& problem_shape, Arguments const& args) { return true; } - + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -674,11 +670,11 @@ struct Sm90ScalarBroadcast { // This must be called after update_scalar is called CUTLASS_DEVICE bool is_zero() const { - if (get<2>(params_ptr->dScalar[0]) == 0) { + if (get<2>(params_ptr->dScalar[0]) == 0) { // Only 1 batch return scalar == Element(0); } - else { + else { // multiple batch if (valid_scalar == false) { // for stridedBatch kernel, if ptr has a valid address, we need to enable the epi_load warps. @@ -761,7 +757,7 @@ private: if (params_ptr->scalar_ptrs[0] != nullptr) { scalar = params_ptr->scalar_ptrs[0][l_offset]; - } + } else { // batch stride is ignored for nullptr fallback scalar = params_ptr->scalars[0]; @@ -774,7 +770,7 @@ private: if (params_ptr->scalar_ptrs[i] != nullptr) { int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); - } + } else { // batch stride is ignored for nullptr fallback scalar = reduction_fn(scalar, params_ptr->scalars[i]); @@ -826,7 +822,7 @@ struct Sm90ScalarBroadcastPtrArray { can_implement(ProblemShape const& problem_shape, Arguments const& args) { return true; } - + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -946,7 +942,7 @@ private: if (params_ptr->scalar_ptrs[i] != nullptr) { int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); - } + } else { // batch stride is ignored for nullptr fallback scalar = reduction_fn(scalar, params_ptr->scalars[i]); @@ -992,7 +988,7 @@ struct Sm90RowBroadcast { static_assert(is_static_v(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{} || IsDynamicBroadcast); - struct SharedStorage { + struct SharedStorage { array_aligned(CtaTileShapeMNK{})> smem; }; @@ -1078,8 +1074,8 @@ struct Sm90RowBroadcast { struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { CUTLASS_DEVICE ConsumerStoreCallbacks( - GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, - GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, Residue residue_cRow_, Params const& params_) : tGS_gRow(tGS_gRow_) @@ -1098,8 +1094,8 @@ struct Sm90RowBroadcast { Tiled_G2S tiled_G2S; SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Residue residue_cRow; // (m, n) Params const& params; @@ -1113,7 +1109,7 @@ struct Sm90RowBroadcast { for (int i = 0; i < size(tGS_gRow_flt); ++i) { if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { - continue; // OOB of SMEM, + continue; // OOB of SMEM, } if (not is_nullptr && elem_less(tGS_cRow_flt(i), residue_cRow)) { tGS_sRow_flt(i) = tGS_gRow_flt(i); // issue async gmem to smem load @@ -1201,18 +1197,18 @@ struct Sm90RowBroadcast { } Tensor mRow = make_tensor(make_gmem_ptr(ptr_row), make_layout(layout_M,layout_N,layout_L)); Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) - Tensor sRow = make_tensor(make_smem_ptr(smem), + Tensor sRow = make_tensor(make_smem_ptr(smem), make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) //// G2S: Gmem to Smem auto tiled_g2s = make_tiled_copy(Copy_Atom{}, - Layout< Shape<_1, ThreadCount>, - Stride<_0, _1>>{}, - Layout<_1>{}); + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); Tensor tGS_gRow = thr_g2s.partition_S(gRow); Tensor tGS_sRow = thr_g2s.partition_D(sRow); - //// G2S: Coord + //// G2S: Coord Tensor tGS_cRow = thr_g2s.partition_S(args.cD); //// S2R: Smem to Reg @@ -1220,11 +1216,11 @@ struct Sm90RowBroadcast { Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) return ConsumerStoreCallbacks( - tGS_gRow, - tGS_sRow, - tGS_cRow, tiled_g2s, - tSR_sRow, - tSR_rRow, + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, args.residue_cD, params); } @@ -1378,12 +1374,12 @@ struct Sm90ColBroadcast { Tensor tCgCol_vec = recast(coalesce(tCgCol_flt)); Tensor tCrCol_vec = recast(coalesce(tCrCol_flt)); Tensor tCcCol_vec = tensor<1>(zipped_divide(tCcCol_flt, MCL.compose(Int{}))); - auto pred_fn = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tCcCol_vec(coords...), residue_tCcCol); }; - copy_if(pred_fn, tCgCol_vec, tCrCol_vec); + Tensor tCpCol_vec = cute::lazy::transform(tCcCol_vec, [&](auto const& c){ return elem_less(c, residue_tCcCol); }); + copy_if(tCpCol_vec, tCgCol_vec, tCrCol_vec); } else { - auto pred_fn = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tCcCol_flt(coords...), residue_tCcCol); }; - copy_if(pred_fn, tCgCol_flt, tCrCol_flt); + Tensor tCpCol_flt = cute::lazy::transform(tCcCol_flt, [&](auto const& c){ return elem_less(c, residue_tCcCol); }); + copy_if(tCpCol_flt, tCgCol_flt, tCrCol_flt); } constexpr int FrgSize = size(tCrCol_flt); diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index 4a8e5f8c..29b9d1d1 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -412,17 +412,13 @@ struct Sm90AuxStore< constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){}; constexpr int V = cute::min(Alignment, size(MCL)); - Tensor tC_cAux_mn = tC_cAux(_,_,_,epi_m,epi_n); - Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux_mn), MCL.compose(Int{}))); - Tensor tC_gAux_vec = recast>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); Tensor tC_rAux_vec = recast>(coalesce(tC_rAux)); - auto pred_fn = [&] (auto const&... coords) { - return elem_less(tC_cAux_vec(coords...), problem_shape_mnl); - }; + Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux(_,_,_,epi_m,epi_n)), MCL.compose(Int{}))); + Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, problem_shape_mnl); }); - copy_if(pred_fn, tC_rAux_vec, tC_gAux_vec); + copy_if(tC_pAux_vec, tC_rAux_vec, tC_gAux_vec); } }; diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 265a75ee..8412b503 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -540,6 +540,23 @@ struct HardSwish > { } }; +template +struct HardSwish > { + using T = half_t; + static const bool kIsHeavy = false; + static constexpr float kOneSixth = 0.16666667f; + + CUTLASS_HOST_DEVICE + Array operator()(Array const &value) const { + minimum > mn; + maximum > mx; + multiplies > mul; + plus > add; + + return mul(mul(mn(mx(add(value, T(3)), T(0)), T(6)), value), T(kOneSixth)); + } +}; + template using ScaledHardSwish = Scale>; diff --git a/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp b/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp index ab877d4d..d894b114 100644 --- a/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp +++ b/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp @@ -542,12 +542,8 @@ struct VisitorColBroadcast { } } clear(tC_rCol); - Tensor pred = make_tensor(shape(tC_gCol)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(pred); ++i) { - pred(i) = get<0>(tC_cCol(i)) < m; - } - copy_if(pred, tC_gCol, tC_rCol); + Tensor tC_pCol = cute::lazy::transform(tC_cCol, [&] (auto const& c) { return get<0>(c) < m; }); + copy_if(tC_pCol, tC_gCol, tC_rCol); } template diff --git a/include/cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp b/include/cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp index 7ac336a3..7968849a 100644 --- a/include/cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp +++ b/include/cutlass/experimental/distributed/device/dist_gemm_universal_wrapper.hpp @@ -446,7 +446,7 @@ public: Status construct_graph(bool launch_with_pdl) { -#if ((__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 4)) +#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)) Status status = Status::kSuccess; // Destroy existing graph, if created diff --git a/include/cutlass/experimental/distributed/device/full_barrier.hpp b/include/cutlass/experimental/distributed/device/full_barrier.hpp index 54b8348d..ab91cf89 100644 --- a/include/cutlass/experimental/distributed/device/full_barrier.hpp +++ b/include/cutlass/experimental/distributed/device/full_barrier.hpp @@ -47,7 +47,7 @@ void launch_full_barrier( cudaStream_t stream, bool launch_with_pdl) { -#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4)) +#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 6)) // Legacy (kernel) launch with PDL cudaLaunchAttribute attributes[1]; attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; diff --git a/include/cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl index 7ebc9b8b..b8824c23 100644 --- a/include/cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl @@ -268,7 +268,7 @@ struct CollectiveBuilder< // Calculate SMEM matrix A and B buffers' pipeline stages and the accumulator stages. static constexpr uint32_t AccumulatorNPerCta = cute::size<1>(TileShape_MNK{}); static constexpr uint32_t AccumulatorPipelineStageCount = (AccumulatorNPerCta == 256) ? 1 : 2; - static constexpr uint32_t SchedulerPipelineStageCount = 1; + static constexpr uint32_t SchedulerPipelineStageCount = 2; using SmemTileShape = cute::Shape; diff --git a/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl index af122b4d..3556fad6 100644 --- a/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl @@ -238,7 +238,7 @@ struct CollectiveBuilder< static constexpr bool IsArrayOfPointersGemm = cute::is_base_of_v; // Grouped GEMM(where Stride type is Stride*) uses specific static tile scheduler. static constexpr bool IsGroupGemm = !cute::is_same_v; - static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return(8, 1); + static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return(8, 2); static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout< ClusterShape_MNK, diff --git a/include/cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl b/include/cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl index 8d8a9c5e..fc4aa4a2 100644 --- a/include/cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl +++ b/include/cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl @@ -51,6 +51,8 @@ struct Sm100DenseGemmTmaUmmaCarveout { static constexpr auto LoadOrderBarrierStorage = sizeof(typename cutlass::OrderedSequenceBarrier<1,2>::SharedStorage); // CLC (scheduler) response static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // CLC Throttle pipeline storage + static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync::SharedStorage); // Tmem dealloc static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); // Tmem ptr storage @@ -64,6 +66,7 @@ struct Sm100DenseGemmTmaUmmaCarveout { CLCPipelineStorage + LoadOrderBarrierStorage + TmemDeallocStorage + + CLCThrottlePipelineStorage + CLCResponseStorage + TmemBasePtrsStorage + TensorMapStorage @@ -80,6 +83,8 @@ struct Sm100SparseGemmTmaUmmaCarveout { static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); // AccumulatorPipeline = PipelineUmmaAsync static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); + // CLC Throttle pipeline storage + static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync::SharedStorage); // Tmem dealloc static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); @@ -87,6 +92,7 @@ struct Sm100SparseGemmTmaUmmaCarveout { cutlass::round_up(LoadOrderBarrierStorage, 16) + cutlass::round_up(CLCPipelineStorage, 16) + cutlass::round_up(AccumulatorPipelineStorage, 16) + + cutlass::round_up(CLCThrottlePipelineStorage, 16) + cutlass::round_up(TmemDeallocStorage, 16), 16)); diff --git a/include/cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl index 20714e4c..c7d380a9 100644 --- a/include/cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl @@ -371,7 +371,7 @@ struct CollectiveBuilder< // Calculate SMEM matrix A and B buffers' pipeline stages and the accumulator stages. static constexpr uint32_t AccumulatorNPerCta = cute::size<1>(TileShape_MNK{}); static constexpr uint32_t AccumulatorPipelineStageCount = AccumulatorNPerCta > 224 ? 1 : 2; - static constexpr uint32_t SchedulerPipelineStageCount = 1; + static constexpr uint32_t SchedulerPipelineStageCount = 2; using SmemTileShape = cute::Shape; diff --git a/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl index 4ebfdb52..e7f5235a 100644 --- a/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl @@ -267,7 +267,7 @@ struct CollectiveBuilder< static constexpr bool IsArrayOfPointersGemm = (cute::is_base_of_v); // Grouped GEMM(where Stride type is Stride*) uses specific static tile scheduler. static constexpr bool IsGroupGemm = !cute::is_same_v; - static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return(8, 1); + static constexpr uint32_t SchedulerPipelineStageCount = cute::conditional_return(8, 2); static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout< ClusterShape_MNK, diff --git a/include/cutlass/gemm/collective/builders/sm120_blockwise_mma_builder.inl b/include/cutlass/gemm/collective/builders/sm120_blockwise_mma_builder.inl new file mode 100644 index 00000000..4b5858cf --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm120_blockwise_mma_builder.inl @@ -0,0 +1,264 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/gemm/collective/builders/sm120_common.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template < + int CapacityBytes, + class ElementA, + class ElementB, + class ElementScalar, + class TileShapeMNK, + class ScaleShapeMNK, + class MainloopPipelineStorage, + int stages +> +constexpr int +sm120_compute_stage_count_or_override_blockwise(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity. +template < + int CapacityBytes, + class ElementA, + class ElementB, + class ElementScalar, + class TileShapeMNK, + class ScaleShapeMNK, + class MainloopPipelineStorage, + int carveout_bytes +> +constexpr auto +sm120_compute_stage_count_or_override_blockwise(StageCountAutoCarveout stage_count) { + // For F6/F4 sub-bytes, ElementA/B will be passed in as uint8_t + + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr auto scale_bits = cute::sizeof_bits_v; + constexpr auto mainloop_pipeline_bytes = sizeof(MainloopPipelineStorage); + + constexpr int stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(scale_bits * size<0>(ScaleShapeMNK{}) * size<2>(ScaleShapeMNK{})) + + cutlass::bits_to_bytes(scale_bits * size<1>(ScaleShapeMNK{}) * size<2>(ScaleShapeMNK{})) + + static_cast(mainloop_pipeline_bytes); + + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +template < + class ElementA, + class GmemLayoutATagPair, + int AlignmentA, + class ElementB, + class GmemLayoutBTagPair, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class BuilderScheduleTag +> +struct CollectiveBuilder< + arch::Sm120, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATagPair, + AlignmentA, + ElementB, + GmemLayoutBTagPair, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + BuilderScheduleTag, + cute::enable_if_t< + not cute::is_tuple_v && not cute::is_tuple_v && + not cute::is_complex_v && not cute::is_complex_v && + cute::is_tuple_v && cute::is_tuple_v && + (cute::is_base_of_v || + cute::is_same_v) && + detail::sm1xx_gemm_is_aligned()>> +{ + + static_assert(detail::is_sm10x_f8f6f4_element() && detail::is_sm10x_f8f6f4_element(), + "SM120 TmaWarpSpecialized blockwise scaling builder currently only supports F8F6F4 MMA."); + static_assert(cute::is_static_v, "TileShape has to be static"); + static_assert(cute::is_static_v, "Cluster has to be static"); + + using GmemLayoutATag = cute::remove_cvref_t(GmemLayoutATagPair{}))>; + using GmemLayoutSFATag = cute::remove_cvref_t(GmemLayoutATagPair{}))>; + using GmemLayoutBTag = cute::remove_cvref_t(GmemLayoutBTagPair{}))>; + using GmemLayoutSFBTag = cute::remove_cvref_t(GmemLayoutBTagPair{}))>; + + static_assert(cute::depth(cute::remove_pointer_t{}) == 2 and + cute::depth(cute::remove_pointer_t{}) == 2, + "Expect SFA and SFB layout to be depth of two with shape ((SFVecMN, restMN),(SFVecK, restK), L)"); + static_assert(size<1, 0>(cute::remove_pointer_t{}) == + size<1, 0>(cute::remove_pointer_t{}), + "SFA and SFB must have equivalent SF vector sizes along K"); + + static constexpr cute::UMMA::Major UmmaMajorA = detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = detail::tag_to_umma_major_B(); + static_assert((UmmaMajorA == UMMA::Major::K && UmmaMajorB == UMMA::Major::K), "Only TN layout is supported."); + + using PermTileM = decltype(cute::min(size<0>(TileShape_MNK{}), _128{})); + using PermTileN = decltype(cute::min(size<1>(TileShape_MNK{}), _32{})); + + static constexpr bool IsCooperative = !cute::is_base_of_v; + using AtomLayoutMNK = cute::conditional_t>, Layout>>; + + // Data type used by MMA instruction + using ElementAMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + using ElementBMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + + static_assert(detail::sm1xx_gemm_check_for_f8f6f4_mix8bit_requirement(), + "TileSize and MNK Major does not met with MMA Mix 8-bit TMA load requirement" ); + + // Setup TiledMma + using TiledMma = decltype(cute::make_tiled_mma( + cute::rr_op_selector_sm120(), + AtomLayoutMNK{}, + Tile{} + )); + + // DType check + static constexpr bool UseF8f6f4 = detail::is_sm120_f8f6f4(); + static_assert(UseF8f6f4, "Non-blockscaled collective builder only supports F8F6F4 MMA.\n"); + + // Element type + using SmemAllocTypeA = cute::conditional_t; + using SmemAllocTypeB = cute::conditional_t; + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::sm120_rr_smem_selector(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::sm120_rr_smem_selector(TileShape_MNK{}))>()); + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using StrideSFA = cutlass::gemm::TagToStrideA_t; + using StrideSFB = cutlass::gemm::TagToStrideB_t; + + static constexpr int ScaleGranularityM = size<0,0>(cute::remove_pointer_t{}); + static constexpr int ScaleGranularityN = size<0,0>(cute::remove_pointer_t{}); + static constexpr int ScaleGranularityK = size<1,0>(cute::remove_pointer_t{}); + + static_assert(size<0>(TileShape_MNK{}) % ScaleGranularityM == 0, "Scale Granularity M must evenly divide the tile shape M."); + static_assert(size<1>(TileShape_MNK{}) % ScaleGranularityN == 0, "Scale Granularity N must evenly divide the tile shape N."); + static_assert(size<2>(TileShape_MNK{}) == ScaleGranularityK , "Scale Granularity K must be equal to the tile shape K."); + + using BlockTileScale_M = Int(TileShape_MNK{}) / ScaleGranularityM>; + using BlockTileScale_N = Int(TileShape_MNK{}) / ScaleGranularityN>; + using BlockTileScale_K = Int(TileShape_MNK{}) / ScaleGranularityK>; + + using ScaleTileShape = cute::Shape; + + + // Setup Stages and DispatchPolicy + using MainloopPipelineStorage = typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage; + + static constexpr int PipelineStages = detail::sm120_compute_stage_count_or_override_blockwise< + detail::sm120_smem_capacity_bytes, SmemAllocTypeA, + SmemAllocTypeB, ElementAccumulator, + TileShape_MNK, ScaleTileShape, MainloopPipelineStorage>(StageCountType{}); + static constexpr uint32_t SchedulerPipelineStageCount = 2; + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v, StrideA>; + using KernelSchedule = cute::conditional_t, + KernelPtrArrayTmaWarpSpecializedPingpongBlockwiseScalingSm120>, + // Non-PtrArray + cute::conditional_t, + KernelTmaWarpSpecializedPingpongBlockwiseScalingSm120>>; + + using DispatchPolicy = cute::conditional_t, + MainloopSm120TmaWarpSpecializedBlockwiseScaling>; + + using SmemCopyAtomA = Copy_Atom()), SmemAllocTypeA>; + using SmemCopyAtomB = Copy_Atom()), SmemAllocTypeB>; + + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + cute::tuple, + ElementB, + cute::tuple, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/builders/sm120_mma_builder.inl b/include/cutlass/gemm/collective/builders/sm120_mma_builder.inl index 5426aa4c..b75573a2 100644 --- a/include/cutlass/gemm/collective/builders/sm120_mma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm120_mma_builder.inl @@ -66,6 +66,8 @@ struct CollectiveBuilder< StageCountType, BuilderScheduleTag, cute::enable_if_t< + not cute::is_tuple_v && not cute::is_tuple_v && + not cute::is_tuple_v && not cute::is_tuple_v && // Dense Gemm (cute::is_base_of_v || cute::is_base_of_v || diff --git a/include/cutlass/gemm/collective/collective_builder.hpp b/include/cutlass/gemm/collective/collective_builder.hpp index d07aca2e..b03c79c8 100644 --- a/include/cutlass/gemm/collective/collective_builder.hpp +++ b/include/cutlass/gemm/collective/collective_builder.hpp @@ -50,6 +50,7 @@ #include "cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl" #include "cutlass/gemm/collective/builders/sm120_sparse_mma_builder.inl" #include "cutlass/gemm/collective/builders/sm120_blockscaled_sparse_mma_builder.inl" +#include "cutlass/gemm/collective/builders/sm120_blockwise_mma_builder.inl" #endif ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index 5a1e93eb..f65dd70b 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -67,6 +67,8 @@ #include "cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp" #include "cutlass/gemm/collective/sm120_sparse_mma_tma.hpp" #include "cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp" +#include "cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp" +#include "cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp" #endif // !defined(__CUDACC_RTC__) diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp index b51d1256..2665ef1c 100644 --- a/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp @@ -28,10 +28,6 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ - - - - #pragma once #include "cutlass/cutlass.h" @@ -51,7 +47,6 @@ #include "cute/arch/cluster_sm90.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -169,7 +164,6 @@ struct CollectiveMma< using InternalStrideB = cute::remove_pointer_t; static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); - static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || @@ -210,19 +204,15 @@ struct CollectiveMma< AtomThrShapeMNK>; using MainloopPipelineState = typename MainloopPipeline::PipelineState; - static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); - static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, - "SmemLayoutAtom must evenly divide tile shape."); - static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, - "SmemLayoutAtom must evenly divide tile shape."); + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtomA must evenly divide the tile shape."); static_assert(cute::is_void_v, "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); - static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); - static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, - "SmemLayoutAtom must evenly divide tile shape."); - static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, - "SmemLayoutAtom must evenly divide tile shape."); + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtomB must evenly divide the tile shape."); static_assert(cute::is_void_v, "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); @@ -275,8 +265,8 @@ struct CollectiveMma< using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; - using BitTypeElementA = uint_bit_t>; - using BitTypeElementB = uint_bit_t>; + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; using ArrayElementA = cute::conditional_t; using ArrayElementB = cute::conditional_t; @@ -308,15 +298,22 @@ struct CollectiveMma< using TensorMapStorage = typename SharedStorage::TensorMapStorage; using PipelineStorage = typename SharedStorage::PipelineStorage; + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly static constexpr uint32_t SFTransactionBytes = cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); - // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly static constexpr uint32_t ABTmaTransactionBytes = cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); static constexpr uint32_t TmaTransactionBytes = ABTmaTransactionBytes + SFTransactionBytes; + template + struct TmemStorage { + AccTensor accumulators; + SfaTensor tCtSFA; + SfbTensor tCtSFB; + }; + // Host side kernel arguments struct Arguments { ArrayElementA const** ptr_A{nullptr}; @@ -401,7 +398,11 @@ struct CollectiveMma< CUTLASS_DEVICE CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) : cluster_shape_(cluster_shape) - , block_rank_in_cluster_(block_rank_in_cluster) { + , block_rank_in_cluster_(block_rank_in_cluster) + , layout_SFA_(params.layout_SFA) + , layout_SFB_(params.layout_SFB) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { if constexpr (IsDynamicCluster) { const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); @@ -613,18 +614,48 @@ struct CollectiveMma< } /// Construct A Single Stage's Accumulator Shape - CUTLASS_DEVICE auto + CUTLASS_DEVICE static + auto partition_accumulator_shape() { auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) return acc_shape; } + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return tmem_storage.accumulators(_,_,_,stage); + } - template - CUTLASS_DEVICE auto - slice_accumulator(cute::Tensor const& accumulators, int stage) { - return accumulators(_,_,_,stage); + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + Tensor tCtSFA = make_tensor(shape(SmemLayoutAtomSFA{})); + Tensor tCtSFB = make_tensor(shape(SmemLayoutAtomSFB{})); + + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + tmem_storage.tCtSFA = tCtSFA; + tmem_storage.tCtSFB = tCtSFB; + + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + tmem_storage.tCtSFA.data() = tmem_storage.accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.accumulators); + tmem_storage.tCtSFB.data() = tmem_storage.tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.tCtSFA); } /// Set up the data needed by this collective for load. @@ -693,9 +724,9 @@ struct CollectiveMma< } else if constexpr (IsCtaN64) { Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB)); - auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp), + auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp), make_shape(_2{} , shape<0,1>(mSFB_tmp))), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); - auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), make_stride(_0{}, stride<0,1>(mSFB_tmp))), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); } @@ -707,7 +738,6 @@ struct CollectiveMma< Tensor gSFA_mkl = local_tile(mSFA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape_SF{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) - // Partition for this CTA ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); @@ -770,17 +800,15 @@ struct CollectiveMma< } /// Set up the data needed by this collective for mma compute. - template + template CUTLASS_DEVICE auto mma_init( - Params const& params, - [[maybe_unused]] cute::Tensor const& accumulators, - TensorStorage& shared_tensors, - uint32_t const tmem_offset) const { + TmemStorage tmem_storage, + TensorStorage& shared_tensors) const { // Allocate "fragments/descriptors" for A and B matrices - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) // Allocate "fragments/descriptors" for A and B matrices Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) @@ -792,13 +820,8 @@ struct CollectiveMma< // // Scale Factor // - Tensor tCtSFA = make_tensor(shape(SmemLayoutAtomSFA{})); - // Set tCtSFA and tCtSFB start addresses. Only update the TMEM column address by masking the address with 0x000001FF. - // TMEM allocations for SFA and SFB will always start at DP 0. - tCtSFA.data() = tmem_offset; - Tensor tCtSFB = make_tensor(shape(SmemLayoutAtomSFB{})); - tCtSFB.data() = tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tCtSFA); - + Tensor tCtSFA = tmem_storage.tCtSFA; + Tensor tCtSFB = tmem_storage.tCtSFB; // Setup smem descriptors for UTCCP Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); @@ -831,8 +854,10 @@ struct CollectiveMma< TiledMma tiled_mma; if constexpr (IsRuntimeDataType) { - tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; - tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. + tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; } return cute::make_tuple( @@ -997,45 +1022,52 @@ struct CollectiveMma< // tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; - if (k_tile_count > 0) { // first iteraion - // WAIT on mainloop_pipe_consumer_state until its data are available - // (phase bit flips from mainloop_pipe_consumer_state.phase() value) - mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + if constexpr (IsOverlappingAccum) { + // first iteration manual unroll for tmem overlap kernel + if (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); - // Compute on k_tile - int read_stage = mainloop_pipe_consumer_state.index(); - // Save current mainlop pipeline read state - auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; - // Advance mainloop_pipe - ++mainloop_pipe_consumer_state; - --k_tile_count; - skip_wait = k_tile_count <= 0; - // Peek at next iteration - barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); - if (cute::elect_one_sync()) { - copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); - copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); - } + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); + } - if constexpr (IsOverlappingAccum) { + // Wait for tmem accumulator buffer to become empty with a flipped phase accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); - } - // Unroll the K mode manually so we can set scale C to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M) x (V,N) => (V,M,N) - cute::gemm(tiled_mma.with(tiled_mma.accumulate_, - tCtSFA(_,_,k_block), - tCtSFB_mma(_,_,k_block)), - tCrA(_,_,k_block,read_stage), - tCrB(_,_,k_block,read_stage), - accumulators); - tiled_mma.accumulate_ = UMMA::ScaleOut::One; + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); } - mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + else { + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); } CUTLASS_PRAGMA_NO_UNROLL @@ -1073,6 +1105,7 @@ struct CollectiveMma< accumulators); tiled_mma.accumulate_ = UMMA::ScaleOut::One; } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); } @@ -1273,6 +1306,11 @@ protected: typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; + LayoutSFA layout_SFA_; + LayoutSFB layout_SFB_; + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + ClusterShape cluster_shape_; uint32_t block_rank_in_cluster_; }; diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp index 1ec07261..79a97bed 100644 --- a/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp @@ -47,7 +47,6 @@ #include "cute/arch/cluster_sm90.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -123,7 +122,7 @@ struct CollectiveMma< "Static cluster shape used: TileShape should be evenly divided by TiledMma"); using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); - static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 64 or + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 64 or shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, "Cta N should be one of 64/128/192/256"); @@ -726,9 +725,9 @@ struct CollectiveMma< } else if constexpr (IsCtaN64) { Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); - auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp), + auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp), make_shape(_2{} , shape<0,1>(mSFB_tmp))), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); - auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), make_stride(_0{}, stride<0,1>(mSFB_tmp))), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); } diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp index f61c9da1..bcf88620 100644 --- a/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp @@ -48,7 +48,6 @@ #include "cute/arch/cluster_sm90.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -911,9 +910,9 @@ struct CollectiveMma< } else if constexpr (IsCtaN64) { Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); - auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp), + auto new_shape = make_shape(make_shape(shape<0,0>(mSFB_tmp), make_shape(_2{} , shape<0,1>(mSFB_tmp))), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); - auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), make_stride(_0{}, stride<0,1>(mSFB_tmp))), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); } diff --git a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp index c0c3ed15..d832a1fc 100644 --- a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp @@ -30,7 +30,6 @@ **************************************************************************************************/ - #pragma once #include "cutlass/cutlass.h" @@ -43,12 +42,12 @@ #include "cutlass/trace.h" #include "cutlass/kernel_hardware_info.hpp" #include "cutlass/cuda_host_adapter.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" #include "cute/algorithm/functional.hpp" #include "cute/arch/cluster_sm90.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -131,7 +130,7 @@ struct CollectiveMma< using ElementB = ElementB_; using ElementBMma = typename TiledMma::ValTypeB; using StrideB = StrideB_; - using InternalStrideB = cute::remove_pointer_t; + using InternalStrideB = cute::remove_pointer_t; static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); @@ -212,8 +211,8 @@ struct CollectiveMma< using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; - using BitTypeElementA = uint_bit_t>; - using BitTypeElementB = uint_bit_t>; + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; using ArrayElementA = cute::conditional_t; using ArrayElementB = cute::conditional_t; @@ -221,6 +220,8 @@ struct CollectiveMma< using RuntimeDataTypeA = cute::conditional_t; using RuntimeDataTypeB = cute::conditional_t; + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + struct SharedStorage { struct TensorStorage : cute::aligned_struct<128, _0> { cute::ArrayEngine> smem_A; @@ -246,7 +247,10 @@ struct CollectiveMma< cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); - static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + template + struct TmemStorage { + AccTensor accumulators; + }; // Host side kernel arguments struct Arguments { @@ -298,9 +302,11 @@ struct CollectiveMma< CUTLASS_DEVICE CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) : cluster_shape_(cluster_shape) - , block_rank_in_cluster_(block_rank_in_cluster) { + , block_rank_in_cluster_(block_rank_in_cluster) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { if constexpr (IsDynamicCluster) { - const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; @@ -357,7 +363,6 @@ struct CollectiveMma< auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); - typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( GmemTiledCopyA{}, tensor_a, @@ -421,7 +426,7 @@ struct CollectiveMma< return cutlass::Status::kSuccess; } - template + template static bool can_implement( ProblemShape problem_shapes, @@ -450,17 +455,38 @@ struct CollectiveMma< } /// Construct A Single Stage's Accumulator Shape - CUTLASS_DEVICE auto + CUTLASS_DEVICE static + auto partition_accumulator_shape() { - auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) - - return acc_shape; + return partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) } - template - CUTLASS_DEVICE auto - slice_accumulator(cute::Tensor const& accumulators, int stage) { - return accumulators(_,_,_,stage); + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return tmem_storage.accumulators(_,_,_,stage); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; } /// Set up the data needed by this collective for load. @@ -535,13 +561,13 @@ struct CollectiveMma< } /// Set up the data needed by this collective for mma compute. - template + template CUTLASS_DEVICE auto mma_init( - Params const& params, - [[maybe_unused]] cute::Tensor const& accumulators, - TensorStorage& shared_tensors, - [[maybe_unused]] uint32_t const tmem_nonaccum_offset) const { + [[maybe_unused]] TmemStorage tmem_storage, + TensorStorage& shared_tensors) const { + + // Allocate "fragments/descriptors" for A and B matrices Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) @@ -550,15 +576,15 @@ struct CollectiveMma< Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE TiledMma tiled_mma; if constexpr (IsRuntimeDataType) { // Update instruction descriptor according to runtime argument. // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. - tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; - tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; } return cute::make_tuple(tiled_mma, tCrA, tCrB); @@ -672,6 +698,8 @@ struct CollectiveMma< // PIPELINED MAIN LOOP // tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); CUTLASS_PRAGMA_NO_UNROLL while (k_tile_count > 0) { @@ -776,9 +804,9 @@ struct CollectiveMma< TmaInternalElementB const* ptr_B = nullptr; Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); - cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_a_, tensor_a, + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_a_, tensor_a, prob_shape_A, prob_stride_A); - cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b, + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b, prob_shape_B, prob_stride_B); // Convert strides to byte strides @@ -852,6 +880,8 @@ protected: typename Params::TMA_A const* observed_tma_load_a_{nullptr}; typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; ClusterShape cluster_shape_; uint32_t block_rank_in_cluster_; diff --git a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp index 8fc171e8..812553af 100644 --- a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp @@ -45,7 +45,6 @@ #include "cute/arch/cluster_sm90.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -131,7 +130,7 @@ struct CollectiveMma< using ElementBMma = typename TiledMma::ValTypeB; using StrideB = cute::remove_cvref_t(StridePairB_{}))>; using LayoutSFB = cute::remove_cvref_t(StridePairB_{}))>; - using InternalStrideB = cute::remove_pointer_t; + using InternalStrideB = cute::remove_pointer_t; using InternalLayoutSFB = cute::remove_pointer_t; static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); @@ -143,9 +142,9 @@ struct CollectiveMma< "ElementA and ElementB should be both runtime or both static."); static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; - + static constexpr int ScaleGranularityM = size<0,0>(InternalLayoutSFA{}); - + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; static_assert(size<0>(TileShape{}) % ScaleGranularityM == 0 and ScaleGranularityM <= size<0>(TileShape{}), "Scale Granularity M must divide Tile Shape"); @@ -166,12 +165,12 @@ struct CollectiveMma< static_assert(size<1>(CtaShape_MNK{}) >= ScaleGranularityN, "Scale Granularity must be smaller than or equal to the tile shape"); static_assert(size<2>(CtaShape_MNK{}) >= ScaleGranularityK, "Scale Granularity must be smaller than or equal to the tile shape"); - using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig(InternalLayoutSFA{}.stride()) == 1 ? UMMA::Major::MN : UMMA::Major::K, size<0,1>(InternalLayoutSFB{}.stride()) == 1 ? UMMA::Major::MN : UMMA::Major::K>; - + using SmemLayoutAtomSFA = decltype(ScaleConfig::smem_atom_layoutSFA(CtaShape_MNK{})); using SmemLayoutAtomSFB = decltype(ScaleConfig::smem_atom_layoutSFB(CtaShape_MNK{})); @@ -193,9 +192,9 @@ struct CollectiveMma< static constexpr int CopyAlignmentSFA = GmemTiledCopySFA::AtomNumVal::value * sizeof(typename GmemTiledCopySFA::ValType) / sizeof(ElementAccumulator); static constexpr int CopyAlignmentSFB = GmemTiledCopySFB::AtomNumVal::value * sizeof(typename GmemTiledCopySFB::ValType) / sizeof(ElementAccumulator); - static constexpr int AlignmentSFA = CopyAlignmentSFA * (GmemTiledCopySFA::AtomNumVal::value > 1 ? + static constexpr int AlignmentSFA = CopyAlignmentSFA * (GmemTiledCopySFA::AtomNumVal::value > 1 ? (size<0,1>(InternalLayoutSFA{}.stride()) == 1 ? ScaleGranularityM : ScaleGranularityK) : 1); - static constexpr int AlignmentSFB = CopyAlignmentSFB * (GmemTiledCopySFB::AtomNumVal::value > 1 ? + static constexpr int AlignmentSFB = CopyAlignmentSFB * (GmemTiledCopySFB::AtomNumVal::value > 1 ? (size<0,1>(InternalLayoutSFB{}.stride()) == 1 ? ScaleGranularityN : ScaleGranularityK) : 1); @@ -383,7 +382,7 @@ struct CollectiveMma< : cluster_shape_(cluster_shape) , block_rank_in_cluster_(block_rank_in_cluster) { if constexpr (IsDynamicCluster) { - const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; @@ -666,7 +665,7 @@ struct CollectiveMma< auto layout_SFA = [&]() CUTLASS_LAMBDA_FUNC_INLINE { if constexpr (IsGroupedGemmKernel) { return params.layout_SFA[current_group]; - } + } else { return params.layout_SFA; } @@ -675,7 +674,7 @@ struct CollectiveMma< auto layout_SFB = [&]() CUTLASS_LAMBDA_FUNC_INLINE { if constexpr (IsGroupedGemmKernel) { return params.layout_SFB[current_group]; - } + } else { return params.layout_SFB; } @@ -690,14 +689,14 @@ struct CollectiveMma< Tensor SFB_nkl_ident = make_identity_tensor(shape(layout_SFB)); // Tile the tensors and defer the slice - Tensor gSFA_mkl = local_tile(mSFA_mkl, CtaShape_MNK{}, + Tensor gSFA_mkl = local_tile(mSFA_mkl, CtaShape_MNK{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) - Tensor gSFB_nkl = local_tile(mSFB_nkl, CtaShape_MNK{}, + Tensor gSFB_nkl = local_tile(mSFB_nkl, CtaShape_MNK{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) - Tensor identSFA_mkl = local_tile(SFA_mkl_ident, CtaShape_MNK{}, + Tensor identSFA_mkl = local_tile(SFA_mkl_ident, CtaShape_MNK{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) - Tensor identSFB_nkl = local_tile(SFB_nkl_ident, CtaShape_MNK{}, + Tensor identSFB_nkl = local_tile(SFB_nkl_ident, CtaShape_MNK{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) static_assert(rank(decltype(gSFA_mkl){}) == 5); @@ -710,16 +709,16 @@ struct CollectiveMma< ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x % size(scale_copy_a)); ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x % size(scale_copy_b)); - Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutScaleA{}); // (CTA_M,CTA_K,P) - Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutScaleB{}); // (CTA_M,CTA_K,P) Tensor tSFAgSFA_mkl = thr_scale_copy_a.partition_S(gSFA_mkl); // (CPY, BLK_M, BLK_K, m, k, l) Tensor tSFAIdentSFA_mkl = thr_scale_copy_a.partition_S(identSFA_mkl); // (CPY, BLK_M, BLK_K, m, k, l) Tensor tSFAsSFA = thr_scale_copy_a.partition_D(sSFA); - + Tensor tSFBgSFB_nkl = thr_scale_copy_b.partition_S(gSFB_nkl); // (CPY, BLK_N, BLK_K, m, k, l) Tensor tSFBIdentSFB_nkl = thr_scale_copy_b.partition_S(identSFB_nkl); // (CPY, BLK_N, BLK_K, m, k, l) Tensor tSFBsSFB = thr_scale_copy_b.partition_D(sSFB); @@ -731,16 +730,16 @@ struct CollectiveMma< tSFAgSFA_mkl, tSFBgSFB_nkl, tSFAsSFA, tSFBsSFB, tSFAIdentSFA_mkl, tSFBIdentSFB_nkl, - layout_SFA, layout_SFB); + layout_SFA, layout_SFB); } /// Setup data needed for transform CUTLASS_DEVICE auto accum_init( TensorStorage& shared_tensors) const { - Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutScaleA{}); // (CTA_M,CTA_K,P) - Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutScaleB{}); // (CTA_M,CTA_K,P) return cute::make_tuple(sSFA, sSFB); @@ -763,20 +762,20 @@ struct CollectiveMma< CUTE_STATIC_ASSERT_V(rank(tCrA_) == _4{}); - auto mma_tile_shape_A = make_shape(get<0>(shape(tCrA_.layout())), - get<1>(shape(tCrA_.layout())), - Int{}, + auto mma_tile_shape_A = make_shape(get<0>(shape(tCrA_.layout())), + get<1>(shape(tCrA_.layout())), + Int{}, _1{}); - auto mma_tile_shape_B = make_shape(get<0>(shape(tCrB_.layout())), - get<1>(shape(tCrB_.layout())), - Int{}, + auto mma_tile_shape_B = make_shape(get<0>(shape(tCrB_.layout())), + get<1>(shape(tCrB_.layout())), + Int{}, _1{}); - Tensor tCrA = flat_divide(tCrA_, + Tensor tCrA = flat_divide(tCrA_, mma_tile_shape_A)(_,_,_,_0{},_0{},_0{},_,_); // (MMA,MMA_M,MMA_K_PER_SCALE,MMA_K_REST,PIPE) - Tensor tCrB = flat_divide(tCrB_, + Tensor tCrB = flat_divide(tCrB_, mma_tile_shape_B)(_,_,_,_0{},_0{},_0{},_,_); // (MMA,MMA_N,MMA_K_PER_SCALE,MMA_K_REST,PIPE) CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE @@ -884,10 +883,10 @@ struct CollectiveMma< load_sf( MainloopSFPipeline mainloop_sf_pipeline, MainloopSFPipelineState mainloop_sf_pipe_producer_state, - cute::tuple const& mainloop_sf_inputs, @@ -921,19 +920,19 @@ struct CollectiveMma< CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(thr_tile_pSFA); ++i) { - Tensor thr_tile_SFA = filter_zeros(thr_tile_SFA_k(_,_,*k_tile_iter), tSFAgSFA(_0{},_,_,_0{}).stride()); + Tensor thr_tile_SFA = filter_zeros(thr_tile_SFA_k(_,_,*k_tile_iter), tSFAgSFA(_0{},_,_,_0{}).stride()); thr_tile_pSFA(i) = elem_less(thr_tile_SFA(i), shape(filter_zeros(layout_SFA))) && threadIdx.x % 32 < size(scale_copy_a); } - + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(thr_tile_pSFB); ++i) { - Tensor thr_tile_SFB = filter_zeros(thr_tile_SFB_k(_,_,*k_tile_iter), tSFBgSFB(_0{},_,_,_0{}).stride()); + Tensor thr_tile_SFB = filter_zeros(thr_tile_SFB_k(_,_,*k_tile_iter), tSFBgSFB(_0{},_,_,_0{}).stride()); thr_tile_pSFB(i) = elem_less(thr_tile_SFB(i), shape(filter_zeros(layout_SFB))) && threadIdx.x % 32 < size(scale_copy_b); } copy_if(scale_copy_a, thr_tile_pSFA, filter_zeros(tSFAgSFA(_,_,_,*k_tile_iter)), filter_zeros(tSFAsSFA(_,_,_,mainloop_sf_pipe_producer_state.index()))); copy_if(scale_copy_b, thr_tile_pSFB, filter_zeros(tSFBgSFB(_,_,_,*k_tile_iter)), filter_zeros(tSFBsSFB(_,_,_,mainloop_sf_pipe_producer_state.index()))); - mainloop_sf_pipeline.producer_commit(mainloop_sf_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive_noinc); + mainloop_sf_pipeline.producer_commit(mainloop_sf_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive_noinc); __syncwarp(); @@ -949,7 +948,7 @@ struct CollectiveMma< /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster CUTLASS_DEVICE void load_sf_tail( - MainloopSFPipeline mainloop_sf_pipeline, + MainloopSFPipeline mainloop_sf_pipeline, MainloopSFPipelineState mainloop_sf_pipe_producer_state) { // Issue the epilogue waits // This helps avoid early exit of ctas in Cluster @@ -1050,7 +1049,7 @@ struct CollectiveMma< class CopyOpT2R, class EpilogueTile > - CUTLASS_DEVICE auto + CUTLASS_DEVICE auto accum( cute::tuple pipelines, cute::tuple consumer_states, @@ -1068,7 +1067,7 @@ struct CollectiveMma< // // PIPELINED Transform // - + Tensor acc = slice_accumulator(accumulators, _0{}); Tensor tAcc = acc(make_coord(_,_),_0{},_0{}); Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) @@ -1100,15 +1099,15 @@ struct CollectiveMma< int thread_idx = threadIdx.x % size(tiled_t2r_epi); - ThrCopy thread_t2r_epi = tiled_t2r_epi.get_slice(thread_idx); + ThrCopy thread_t2r_epi = tiled_t2r_epi.get_slice(thread_idx); Tensor acc_ident_epi = make_identity_tensor(shape(tAcc_epi)); - + Tensor tTR_rAcc_epi = thread_t2r_epi.partition_D(acc_ident_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) Tensor tTR_sSFA_epi = thread_t2r_epi.partition_D(sSFA_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) Tensor tTR_sSFB_epi = thread_t2r_epi.partition_D(sSFB_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) - + static_assert(rank(decltype(tTR_sSFA_epi){}) == 7); Tensor tTR_FullAcc = make_tensor(shape(tTR_rAcc_epi)); @@ -1137,10 +1136,10 @@ struct CollectiveMma< CUTE_STATIC_ASSERT_V(cosize(tTR_rSFA_layout) == size(tTR_rSFA_compact)); CUTE_STATIC_ASSERT_V(cosize(tTR_rSFB_layout) == size(tTR_rSFB_compact)); - + Tensor tTR_rSFA = make_tensor(tTR_rSFA_compact.data(), tTR_rSFA_layout); Tensor tTR_rSFB = make_tensor(tTR_rSFB_compact.data(), tTR_rSFB_layout); - + mainloop_sf_pipeline.consumer_release(mainloop_sf_pipe_state); ++mainloop_sf_pipe_state; @@ -1166,19 +1165,19 @@ struct CollectiveMma< // Compute tmem load predication if necessary copy(tiled_t2r_epi, tTR_tAcc(_,_,_,epi_m,epi_n), tTR_PartAcc); cutlass::arch::fence_view_async_tmem_load(); - + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(full_acc); ++i) { ElementAccumulator scale = scale_a(i) * scale_b(i); full_acc(i) += scale * tTR_PartAcc(i); } } - } + } cutlass::arch::fence_view_async_tmem_load(); accumulator_pipeline.consumer_release(accumulator_pipe_state); // release acc ++accumulator_pipe_state; - } + } --k_tile_count; } @@ -1255,9 +1254,9 @@ struct CollectiveMma< TmaInternalElementB const* ptr_B = nullptr; Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); - cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_a_, tensor_a, + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_a_, tensor_a, prob_shape_A, prob_stride_A); - cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b, + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b, prob_shape_B, prob_stride_B); // Convert strides to byte strides diff --git a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp index 2d23bcde..0a90566d 100644 --- a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp @@ -48,7 +48,6 @@ #include "cute/atom/mma_atom.hpp" #include "cute/atom/copy_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/arch/mma_sm100.hpp" #include "cutlass/trace.h" #include "cutlass/kernel_hardware_info.hpp" @@ -138,13 +137,13 @@ struct CollectiveMma< using ElementA = float; using PackedElementA = float2; using StrideA = StrideA_; - using InternalStrideA = cute::remove_pointer_t; + using InternalStrideA = cute::remove_pointer_t; using ElementAMma = typename TiledMma::ValTypeA; using PackedElementAMma = uint32_t; using ElementB = float; using PackedElementB = float2; using StrideB = StrideB_; - using InternalStrideB = cute::remove_pointer_t; + using InternalStrideB = cute::remove_pointer_t; using ElementBMma = typename TiledMma::ValTypeB; using PackedElementBMma = uint32_t; using ElementAccumulator = typename TiledMma::ValTypeC; @@ -308,7 +307,7 @@ struct CollectiveMma< // Device side kernel params struct Params { - using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), make_tile(typename TiledMma::AtomThrID{}))); using TMA_A = decltype(make_tma_atom_A_sm100( @@ -342,11 +341,11 @@ struct CollectiveMma< : cluster_shape_(cluster_shape) , block_rank_in_cluster_(block_rank_in_cluster) { if constexpr (IsDynamicCluster) { - const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; - } + } else { observed_tma_load_a_ = ¶ms.tma_load_a; observed_tma_load_b_ = ¶ms.tma_load_b; @@ -369,7 +368,7 @@ struct CollectiveMma< Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(M,K,mock_L), args.dA)); Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(N,K,mock_L), args.dB)); - + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); // Cluster layout for TMA construction auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); @@ -458,7 +457,7 @@ struct CollectiveMma< } /// Construct A Single Stage's Accumulator Shape - CUTLASS_DEVICE auto + CUTLASS_DEVICE auto partition_accumulator_shape() { auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) @@ -925,7 +924,7 @@ struct CollectiveMma< CUTLASS_DEVICE auto accum_init(cute::Tensor const& accumulators, TmemCopyAtom tmem_cp_atom, EpilogueTile epilogue_tile) { // Obtain a single accumulator - Tensor tAcc = tensor<0>(accumulators(_,_,_,_0{})); + Tensor tAcc = tensor<0>(accumulators(_,_,_,_0{})); // Apply epilogue subtiling Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) // Create the TMEM copy for single EpilogueTile. @@ -937,7 +936,7 @@ struct CollectiveMma< Tensor tTR_rGlobAcc = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) Tensor tTR_rAcc_float2 = recast>(tTR_rAcc); // (T2R/2,T2R_M,T2R_N) Tensor tTR_rGlobAcc_float2 = recast>(tTR_rGlobAcc); // (T2R/2,T2R_M,T2R_N) - + // Apply epilogue subtiling to bulk accumulator // We need to tile the whole bulk_tmem allocation with EpilogueTile. // The accumulation should be aware of the AccumulatorPipelineStages @@ -967,7 +966,7 @@ struct CollectiveMma< uint32_t skip_wait = 0; auto mma2accum_flag = mma2accum_pipeline.consumer_try_wait(mma2accum_pipeline_consumer_state, skip_wait); - + // 1. Global periodic accumulation in registers CUTLASS_PRAGMA_NO_UNROLL for (; k_tile_count > 0; --k_tile_count) { diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp index b0d2d0f6..fe5ee3cd 100644 --- a/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp @@ -47,7 +47,6 @@ #include "cute/arch/cluster_sm90.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp index d86c58be..e76818d4 100644 --- a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp @@ -47,7 +47,6 @@ #include "cute/arch/cluster_sm90.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -142,9 +141,9 @@ struct CollectiveMma< static constexpr int K_BLOCK_MMAS_PER_SCALE_K = ScaleGranularityK / size<2>(typename TiledMma::AtomShape_MNK{}); - using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig(LayoutSFA{}.stride()) == 1 ? UMMA::Major::MN : UMMA::Major::K, size<0,1>(LayoutSFB{}.stride()) == 1 ? UMMA::Major::MN : UMMA::Major::K>; @@ -204,9 +203,9 @@ struct CollectiveMma< static constexpr int CopyAlignmentSFA = GmemTiledCopySFA::AtomNumVal::value * sizeof(typename GmemTiledCopySFA::ValType) / sizeof(ElementAccumulator); static constexpr int CopyAlignmentSFB = GmemTiledCopySFB::AtomNumVal::value * sizeof(typename GmemTiledCopySFB::ValType) / sizeof(ElementAccumulator); - static constexpr int AlignmentSFA = CopyAlignmentSFA * (GmemTiledCopySFA::AtomNumVal::value > 1 ? + static constexpr int AlignmentSFA = CopyAlignmentSFA * (GmemTiledCopySFA::AtomNumVal::value > 1 ? (size<0,1>(LayoutSFA{}.stride()) == 1 ? ScaleGranularityM : ScaleGranularityK) : 1); - static constexpr int AlignmentSFB = CopyAlignmentSFB * (GmemTiledCopySFB::AtomNumVal::value > 1 ? + static constexpr int AlignmentSFB = CopyAlignmentSFB * (GmemTiledCopySFB::AtomNumVal::value > 1 ? (size<0,1>(LayoutSFB{}.stride()) == 1 ? ScaleGranularityN : ScaleGranularityK) : 1); @@ -399,7 +398,7 @@ struct CollectiveMma< > struct AccumTransformParams { // for scheduler - + STensorScaleA sSFA; STensorScaleB sSFB; @@ -468,7 +467,7 @@ struct CollectiveMma< , runtime_data_type_a_(params.runtime_data_type_a) , runtime_data_type_b_(params.runtime_data_type_b) { if constexpr (IsDynamicCluster) { - const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; @@ -567,7 +566,7 @@ struct CollectiveMma< implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); - + if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); } @@ -627,7 +626,7 @@ struct CollectiveMma< } /// Set up the data needed by this collective for load. - /// Return load params containing + /// Return load params containing /// gA_mkl - The tiled tma tensor for input A /// gB_nkl - The tiled tma tensor for input B /// tAsA - partitioned smem tensor for A @@ -691,12 +690,12 @@ struct CollectiveMma< } /// Set up the data needed by this collective for load. - /// Return load params containing + /// Return load params containing /// tSFAgSFA_mkl - partitioned gmem tensor for SFA /// tSFBgSFB_nkl - partitioned gmem tensor for SFB /// tSFAIdentSFA_mkl - partitioned identity tensor for SFA in gmem /// tSFBIdentSFB_nkl - partitioned identity tensor for SFB in gmem - /// tSFAsSFA - partitioned smem tensor for SFA + /// tSFAsSFA - partitioned smem tensor for SFA /// tSFBsSFB - partitioned smem tensor for SFB /// layout_SFA - layout of SFA in gmem /// layout_SFB - layout of SFB in gmem @@ -720,14 +719,14 @@ struct CollectiveMma< Tensor SFB_nkl_ident = make_identity_tensor(shape(mainloop_params.layout_SFB)); // Tile the tensors and defer the slice - Tensor gSFA_mkl = local_tile(mSFA_mkl, CtaShape_MNK{}, + Tensor gSFA_mkl = local_tile(mSFA_mkl, CtaShape_MNK{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) - Tensor gSFB_nkl = local_tile(mSFB_nkl, CtaShape_MNK{}, + Tensor gSFB_nkl = local_tile(mSFB_nkl, CtaShape_MNK{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) - Tensor identSFA_mkl = local_tile(SFA_mkl_ident, CtaShape_MNK{}, + Tensor identSFA_mkl = local_tile(SFA_mkl_ident, CtaShape_MNK{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) - Tensor identSFB_nkl = local_tile(SFB_nkl_ident, CtaShape_MNK{}, + Tensor identSFB_nkl = local_tile(SFB_nkl_ident, CtaShape_MNK{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) static_assert(rank(decltype(gSFA_mkl){}) == 5); @@ -740,16 +739,16 @@ struct CollectiveMma< ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x % size(scale_copy_a)); ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x % size(scale_copy_b)); - Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutScaleA{}); // (CTA_M,CTA_K,P) - Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutScaleB{}); // (CTA_M,CTA_K,P) Tensor tSFAgSFA_mkl = thr_scale_copy_a.partition_S(gSFA_mkl); // (CPY, BLK_M, BLK_K, m, k, l) Tensor tSFAIdentSFA_mkl = thr_scale_copy_a.partition_S(identSFA_mkl); // (CPY, BLK_M, BLK_K, m, k, l) Tensor tSFAsSFA = thr_scale_copy_a.partition_D(sSFA); - + Tensor tSFBgSFB_nkl = thr_scale_copy_b.partition_S(gSFB_nkl); // (CPY, BLK_N, BLK_K, m, k, l) Tensor tSFBIdentSFB_nkl = thr_scale_copy_b.partition_S(identSFB_nkl); // (CPY, BLK_N, BLK_K, m, k, l) Tensor tSFBsSFB = thr_scale_copy_b.partition_D(sSFB); @@ -784,20 +783,20 @@ struct CollectiveMma< CUTE_STATIC_ASSERT_V(rank(tCrA_) == _4{}); - auto mma_tile_shape_A = make_shape(get<0>(shape(tCrA_.layout())), - get<1>(shape(tCrA_.layout())), - Int{}, + auto mma_tile_shape_A = make_shape(get<0>(shape(tCrA_.layout())), + get<1>(shape(tCrA_.layout())), + Int{}, _1{}); - auto mma_tile_shape_B = make_shape(get<0>(shape(tCrB_.layout())), - get<1>(shape(tCrB_.layout())), - Int{}, + auto mma_tile_shape_B = make_shape(get<0>(shape(tCrB_.layout())), + get<1>(shape(tCrB_.layout())), + Int{}, _1{}); - Tensor tCrA = flat_divide(tCrA_, + Tensor tCrA = flat_divide(tCrA_, mma_tile_shape_A)(_,_,_,_0{},_0{},_0{},_,_); // (MMA,MMA_M,MMA_K_PER_SCALE,MMA_K_REST,PIPE) - Tensor tCrB = flat_divide(tCrB_, + Tensor tCrB = flat_divide(tCrB_, mma_tile_shape_B)(_,_,_,_0{},_0{},_0{},_,_); // (MMA,MMA_N,MMA_K_PER_SCALE,MMA_K_REST,PIPE) @@ -830,9 +829,9 @@ struct CollectiveMma< // Separate out problem shape for convenience auto [M,N,K,L] = problem_shape_MNKL; - Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.begin()), + Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutScaleA{}); // (ScaleMsPerTile,ScakeKsPerTile,P) - Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.begin()), + Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutScaleB{}); // (ScaleNsPerTile,ScaleKsPerTile,P) @@ -852,7 +851,7 @@ struct CollectiveMma< CUTLASS_DEVICE auto load_ab( MainloopABPipeline mainloop_pipeline, - MainloopABPipelineState mainloop_pipe_producer_state, + MainloopABPipelineState mainloop_pipe_producer_state, LoadABParams const& load_inputs, TileCoordMNKL const& cta_coord_mnkl, KTileIterator k_tile_iter, int k_tile_count) { @@ -896,7 +895,7 @@ struct CollectiveMma< /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster CUTLASS_DEVICE void load_ab_tail( - MainloopABPipeline mainloop_pipeline, + MainloopABPipeline mainloop_pipeline, MainloopABPipelineState mainloop_pipe_producer_state) { // Issue the epilogue waits // This helps avoid early exit of ctas in Cluster @@ -923,7 +922,7 @@ struct CollectiveMma< KTileIterator k_tile_iter, int k_tile_count) { auto [unused_k_tiles, - tSFAgSFA_mkl, tSFBgSFB_nkl, + tSFAgSFA_mkl, tSFBgSFB_nkl, tSFAIdentSFA_mkl, tSFBIdentSFB_nkl, tSFAsSFA, tSFBsSFB, layout_SFA, layout_SFB] = load_inputs; @@ -950,19 +949,19 @@ struct CollectiveMma< CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(thr_tile_pSFA); ++i) { - Tensor thr_tile_SFA = filter_zeros(thr_tile_SFA_k(_,_,*k_tile_iter), tSFAgSFA(_0{},_,_,_0{}).stride()); + Tensor thr_tile_SFA = filter_zeros(thr_tile_SFA_k(_,_,*k_tile_iter), tSFAgSFA(_0{},_,_,_0{}).stride()); thr_tile_pSFA(i) = elem_less(thr_tile_SFA(i), shape(filter_zeros(layout_SFA))) && threadIdx.x % 32 < size(scale_copy_a); } - + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(thr_tile_pSFB); ++i) { - Tensor thr_tile_SFB = filter_zeros(thr_tile_SFB_k(_,_,*k_tile_iter), tSFBgSFB(_0{},_,_,_0{}).stride()); + Tensor thr_tile_SFB = filter_zeros(thr_tile_SFB_k(_,_,*k_tile_iter), tSFBgSFB(_0{},_,_,_0{}).stride()); thr_tile_pSFB(i) = elem_less(thr_tile_SFB(i), shape(filter_zeros(layout_SFB))) && threadIdx.x % 32 < size(scale_copy_b); } copy_if(scale_copy_a, thr_tile_pSFA, filter_zeros(tSFAgSFA(_,_,_,*k_tile_iter)), filter_zeros(tSFAsSFA(_,_,_,mainloop_sf_pipe_producer_state.index()))); copy_if(scale_copy_b, thr_tile_pSFB, filter_zeros(tSFBgSFB(_,_,_,*k_tile_iter)), filter_zeros(tSFBsSFB(_,_,_,mainloop_sf_pipe_producer_state.index()))); - mainloop_sf_pipeline.producer_commit(mainloop_sf_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive_noinc); + mainloop_sf_pipeline.producer_commit(mainloop_sf_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive_noinc); __syncwarp(); @@ -977,7 +976,7 @@ struct CollectiveMma< /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster CUTLASS_DEVICE void load_sf_tail( - MainloopSFPipeline mainloop_sf_pipeline, + MainloopSFPipeline mainloop_sf_pipeline, MainloopSFPipelineState mainloop_sf_pipe_producer_state) { // Issue the epilogue waits // This helps avoid early exit of ctas in Cluster @@ -1007,10 +1006,10 @@ struct CollectiveMma< int k_tile_count) { auto [tiled_mma, tCrA, tCrB] = mma_inputs; - auto [mainloop_pipeline, + auto [mainloop_pipeline, accumulator_pipeline] = pipelines; - auto [mainloop_pipe_consumer_state, + auto [mainloop_pipe_consumer_state, accumulator_pipe_producer_state] = pipeline_states; uint32_t skip_wait = k_tile_count <= 0; @@ -1077,7 +1076,7 @@ struct CollectiveMma< class CopyOpT2R, class EpilogueTile > - CUTLASS_DEVICE auto + CUTLASS_DEVICE auto accum( cute::tuple pipelines, cute::tuple consumer_states, @@ -1095,7 +1094,7 @@ struct CollectiveMma< // // PIPELINED Transform // - + Tensor acc = get<0>(slice_accumulator(tmem_storage, _0{})); Tensor tAcc = acc(make_coord(_,_),_0{},_0{}); @@ -1130,15 +1129,15 @@ struct CollectiveMma< int thread_idx = threadIdx.x % size(tiled_t2r_epi); - ThrCopy thread_t2r_epi = tiled_t2r_epi.get_slice(thread_idx); + ThrCopy thread_t2r_epi = tiled_t2r_epi.get_slice(thread_idx); Tensor acc_ident_epi = make_identity_tensor(shape(tAcc_epi)); - + Tensor tTR_rAcc_epi = thread_t2r_epi.partition_D(acc_ident_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) Tensor tTR_sSFA_epi = thread_t2r_epi.partition_D(sSFA_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) Tensor tTR_sSFB_epi = thread_t2r_epi.partition_D(sSFB_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) - + static_assert(rank(decltype(tTR_sSFA_epi){}) == 7); Tensor tTR_FullAcc = make_tensor(shape(tTR_rAcc_epi)); @@ -1167,10 +1166,10 @@ struct CollectiveMma< CUTE_STATIC_ASSERT_V(cosize(tTR_rSFA_layout) == size(tTR_rSFA_compact)); CUTE_STATIC_ASSERT_V(cosize(tTR_rSFB_layout) == size(tTR_rSFB_compact)); - + Tensor tTR_rSFA = make_tensor(tTR_rSFA_compact.data(), tTR_rSFA_layout); Tensor tTR_rSFB = make_tensor(tTR_rSFB_compact.data(), tTR_rSFB_layout); - + mainloop_sf_pipeline.consumer_release(mainloop_sf_pipe_state); ++mainloop_sf_pipe_state; @@ -1196,19 +1195,19 @@ struct CollectiveMma< // Compute tmem load predication if necessary copy(tiled_t2r_epi, tTR_tAcc(_,_,_,epi_m,epi_n), tTR_PartAcc); cutlass::arch::fence_view_async_tmem_load(); - + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(full_acc); ++i) { ElementAccumulator scale = scale_a(i) * scale_b(i); full_acc(i) += scale * tTR_PartAcc(i); } } - } + } cutlass::arch::fence_view_async_tmem_load(); accumulator_pipeline.consumer_release(accumulator_pipe_state); // release acc ++accumulator_pipe_state; - } + } --k_tile_count; } diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp index 993335e3..54c3bd58 100644 --- a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp @@ -48,7 +48,6 @@ #include "cute/atom/mma_atom.hpp" #include "cute/atom/copy_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/arch/mma_sm100.hpp" #include "cutlass/trace.h" #include "cutlass/kernel_hardware_info.hpp" @@ -313,7 +312,7 @@ struct CollectiveMma< // Device side kernel params struct Params { - using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), make_tile(typename TiledMma::AtomThrID{}))); using TMA_A = decltype(make_tma_atom_A_sm100( @@ -344,11 +343,11 @@ struct CollectiveMma< : cluster_shape_(cluster_shape) , block_rank_in_cluster_(block_rank_in_cluster) { if constexpr (IsDynamicCluster) { - const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; - } + } else { observed_tma_load_a_ = ¶ms.tma_load_a; observed_tma_load_b_ = ¶ms.tma_load_b; @@ -366,7 +365,7 @@ struct CollectiveMma< Tensor tensor_a = make_tensor(args.ptr_A, make_layout(make_shape(M,K,L), args.dA)); Tensor tensor_b = make_tensor(args.ptr_B, make_layout(make_shape(N,K,L), args.dB)); - + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); // Cluster layout for TMA construction auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); @@ -459,7 +458,7 @@ struct CollectiveMma< } /// Construct A Single Stage's Accumulator Shape - CUTLASS_DEVICE auto + CUTLASS_DEVICE auto partition_accumulator_shape() { auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) @@ -917,7 +916,7 @@ struct CollectiveMma< CUTLASS_DEVICE auto accum_init(cute::Tensor const& accumulators, TmemCopyAtom tmem_cp_atom, EpilogueTile epilogue_tile) { // Obtain a single accumulator - Tensor tAcc = tensor<0>(accumulators(_,_,_,_0{})); + Tensor tAcc = tensor<0>(accumulators(_,_,_,_0{})); // Apply epilogue subtiling Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) // Create the TMEM copy for single EpilogueTile. @@ -929,7 +928,7 @@ struct CollectiveMma< Tensor tTR_rGlobAcc = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) Tensor tTR_rAcc_float2 = recast>(tTR_rAcc); // (T2R/2,T2R_M,T2R_N) Tensor tTR_rGlobAcc_float2 = recast>(tTR_rGlobAcc); // (T2R/2,T2R_M,T2R_N) - + // Apply epilogue subtiling to bulk accumulator // We need to tile the whole bulk_tmem allocation with EpilogueTile. // The accumulation should be aware of the AccumulatorPipelineStages @@ -959,7 +958,7 @@ struct CollectiveMma< uint32_t skip_wait = 0; auto mma2accum_flag = mma2accum_pipeline.consumer_try_wait(mma2accum_pipeline_consumer_state, skip_wait); - + // 1. Global periodic accumulation in registers CUTLASS_PRAGMA_NO_UNROLL for (; k_tile_count > 0; --k_tile_count) { diff --git a/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp index 5ba16b41..d2d8172f 100644 --- a/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp @@ -46,7 +46,6 @@ #include "cute/arch/cluster_sm90.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp b/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp index fc7bc988..6d0f5a15 100755 --- a/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp +++ b/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp @@ -45,7 +45,6 @@ #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/functional.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -114,7 +113,7 @@ struct CollectiveMma< using ElementB = remove_cvref_t(ElementPairB{}))>; using StrideB = remove_cvref_t(StridePairB{}))>; using InternalStrideB = cute::remove_pointer_t; - + // SFA and SFB using ElementSF = remove_cvref_t(ElementPairA{}))>; using LayoutSFA = remove_cvref_t(StridePairA{}))>; @@ -466,7 +465,7 @@ struct CollectiveMma< constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cutlass::sizeof_bits::value; constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cutlass::sizeof_bits::value; - + bool implementable = true; if (problem_shapes.is_host_problem_shape_available()) { // Check alignment for all problem sizes @@ -642,7 +641,7 @@ struct CollectiveMma< // Represent the full tensors -- get these from TMA Tensor mA_mkl = params.tma_load_a.get_tma_tensor(make_shape(M,K,init_L)); // (m,k,l) Tensor mB_nkl = params.tma_load_b.get_tma_tensor(make_shape(N,K,init_L)); // (n,k,l) - + // Represent the full tensor of Scale factors InternalLayoutSFA layout_SFA{}; InternalLayoutSFB layout_SFB{}; @@ -883,7 +882,7 @@ struct CollectiveMma< auto tCsB_stage = tCsB(_,_,_,read_stage); auto tCsSFA_stage = tCsSFA(_,_,_,read_stage); auto tCsSFB_stage = tCsSFB(_,_,_,read_stage); - + auto copy_kblock = [&](auto k_block) { // copy smem->rmem for A/B operand copy(smem_tiled_copy_A, tCsA_stage(_,_,k_block), tCrA_copy_view(_,_,k_block)); @@ -894,7 +893,7 @@ struct CollectiveMma< fp4_shift_A(MMAOp{}, tCrA_copy_view(_,_,k_block)); fp4_shift_B(MMAOp{}, tCrB_copy_view(_,_,k_block)); - + // Copy smem->rmem for SFA/SFB operand copy(tCsSFA_stage(_,_,k_block), tCrSFA_copy_view(_,_,k_block)); copy(tCsSFB_stage(_,_,k_block), tCrSFB_copy_view(_,_,k_block)); @@ -916,7 +915,7 @@ struct CollectiveMma< for_each(make_int_sequence{}, [&] (auto k_block) { auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 0 : (k_block + 1); - + if (k_block == K_BLOCK_MAX - 1) { cutlass::arch::NamedBarrier::sync( thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); @@ -943,7 +942,7 @@ struct CollectiveMma< for_each(make_int_sequence{}, [&] (auto k_block) { auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 0 : (k_block + 1); - + if (k_block == K_BLOCK_MAX - 1) { cutlass::arch::NamedBarrier::sync( thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); @@ -1154,7 +1153,7 @@ struct CollectiveMma< [[maybe_unused]] int32_t next_batch) { return input_tensors; } - + }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp b/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp index e4acdf32..84d1ab14 100755 --- a/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp +++ b/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp @@ -45,7 +45,6 @@ #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/functional.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -798,7 +797,7 @@ struct CollectiveMma< auto tCsB_stage = tCsB(_,_,_,read_stage); auto tCsSFA_stage = tCsSFA(_,_,_,read_stage); auto tCsSFB_stage = tCsSFB(_,_,_,read_stage); - + auto copy_kblock = [&](auto k_block) { // copy smem->rmem for A/B operand copy(smem_tiled_copy_A, tCsA_stage(_,_,k_block), tCrA_copy_view(_,_,k_block)); @@ -809,7 +808,7 @@ struct CollectiveMma< fp4_shift_A(MMAOp{}, tCrA_copy_view(_,_,k_block)); fp4_shift_B(MMAOp{}, tCrB_copy_view(_,_,k_block)); - + // Copy smem->rmem for SFA/SFB operand copy(tCsSFA_stage(_,_,k_block), tCrSFA_copy_view(_,_,k_block)); copy(tCsSFB_stage(_,_,k_block), tCrSFB_copy_view(_,_,k_block)); @@ -831,7 +830,7 @@ struct CollectiveMma< for_each(make_int_sequence{}, [&] (auto k_block) { auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 0 : (k_block + 1); - + if (k_block == K_BLOCK_MAX - 1) { cutlass::arch::NamedBarrier::sync( thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); @@ -858,7 +857,7 @@ struct CollectiveMma< for_each(make_int_sequence{}, [&] (auto k_block) { auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 0 : (k_block + 1); - + if (k_block == K_BLOCK_MAX - 1) { cutlass::arch::NamedBarrier::sync( thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); diff --git a/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp b/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp index 050cec18..6442eb3b 100755 --- a/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp +++ b/include/cutlass/gemm/collective/sm120_blockscaled_sparse_mma_tma.hpp @@ -46,7 +46,6 @@ #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/functional.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -707,7 +706,7 @@ struct CollectiveMma< Tensor mSFB_tmp = mainloop_params.tma_load_sfb.get_tma_tensor(shape(mainloop_params.layout_SFB)); auto x = stride<0,1>(mSFB_tmp); auto y = ceil_div(shape<0,1>(mSFB_tmp), _2{}); - auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), make_shape( make_shape(_2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), make_stride(make_stride(_0{}), x)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); @@ -964,10 +963,9 @@ struct CollectiveMma< Tensor tCcE = gmem_thr_copy_E.partition_S(cE_mk); // (CPY,CPY_M,CPY_K) auto [atom, vec] = get_copy_atom_and_common_vec(); // Coordinate comparison for out of bound (OOB) predication - Tensor tZcE = zipped_divide(tCcE, vec); - auto pred_fn = [&](auto coord){ return cute::elem_less(tZcE(Int<0>{}, coord), Shape_MK); }; + Tensor tZpE = cute::lazy::transform(zipped_divide(tCcE, vec), [&](auto const& c){ return cute::elem_less(c, Shape_MK); }); // Copy - cute::copy_if(atom, pred_fn, zipped_divide(tCgE, vec), zipped_divide(tCrE_copy_view, vec)); + cute::copy_if(atom, tZpE, zipped_divide(tCgE, vec), zipped_divide(tCrE_copy_view, vec)); } else { // Copy diff --git a/include/cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp new file mode 100644 index 00000000..3fc3d583 --- /dev/null +++ b/include/cutlass/gemm/collective/sm120_mma_array_tma_blockwise_scaling.hpp @@ -0,0 +1,1001 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + int SchedulerPipelineStageCount, + class ClusterShape, + class KernelScheduleType, + class TileShape_, + class ElementA_, + class StridePairA_, + class ElementB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm120ArrayTmaWarpSpecializedBlockwiseScaling, + TileShape_, + ElementA_, + StridePairA_, + ElementB_, + StridePairB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> { + // + // Type Aliases + // + using DispatchPolicy = MainloopSm120ArrayTmaWarpSpecializedBlockwiseScaling; + using TileShape = TileShape_; + using ElementA = remove_cvref_t; + using StrideA = cute::remove_cvref_t(StridePairA_{}))>; + using InternalStrideA = cute::remove_pointer_t; + using LayoutSFA = cute::remove_cvref_t(StridePairA_{}))>; + using InternalLayoutSFA = cute::remove_pointer_t; + + using ElementB = remove_cvref_t; + using StrideB = cute::remove_cvref_t(StridePairB_{}))>; + using InternalStrideB = cute::remove_pointer_t; + using LayoutSFB = cute::remove_cvref_t(StridePairB_{}))>; + using InternalLayoutSFB = cute::remove_pointer_t; + + using TiledMma = TiledMma_; + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using ElementAccumulator = typename TiledMma::ValTypeC; + using ElementSF = ElementAccumulator; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using RuntimeDataTypeA = void*; + using RuntimeDataTypeB = void*; + + static constexpr int ThreadCount = size(TiledMma{}); + + using MainloopPipeline = cutlass::PipelineTmaAsync; + + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename cutlass::PipelineState; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 33; + + static constexpr int ScaleGranularityM = size<0,0>(InternalLayoutSFA{}); + static constexpr int ScaleGranularityN = size<0,0>(InternalLayoutSFB{}); + static constexpr int ScaleGranularityK = size<1,0>(InternalLayoutSFB{}); + + static_assert(size<1, 0>(InternalLayoutSFA{}) == size<1, 0>(InternalLayoutSFB{}), "Vector size K must be equal for SFA and SFB"); + static_assert(size<0>(TileShape{}) % ScaleGranularityM == 0, "Scale Granularity M must evenly divide the tile shape M."); + static_assert(size<1>(TileShape{}) % ScaleGranularityN == 0, "Scale Granularity N must evenly divide the tile shape N."); + static_assert(size<2>(TileShape{}) == ScaleGranularityK , "Scale Granularity K must be equal to the tile shape K."); + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + + using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig(InternalLayoutSFA{}.stride()) == 1 ? UMMA::Major::MN : UMMA::Major::K, + size<0,1>(InternalLayoutSFB{}.stride()) == 1 ? UMMA::Major::MN : UMMA::Major::K>; + + static constexpr int AlignmentSFA = 1; + static constexpr int AlignmentSFB = 1; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for A operand smem->rmem reads."); + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for B operand smem->rmem reads."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // Block scaling gmem-to-smem copy atom + // we can have partial tiles in M or N, so don't vectorize those loads + using SmemBlockScalingCopyAtomA = Copy_Atom, ElementSF>; + using SmemBlockScalingCopyAtomB = Copy_Atom, ElementSF>; + + // Block scaling smem layout + using SmemLayoutScaleA = Layout, Int>>; + using SmemLayoutScaleB = Layout, Int>>; + + + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + not cute::is_base_of::value, + "MMA atom must source both A and B operands from rmem for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm120_f8f6f4(); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + using TmaInternalElementA = cute::conditional_t, + cutlass::tfloat32_t, + cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + using TmaInternalElementB = cute::conditional_t, + cutlass::tfloat32_t, + cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + + using SmemAllocTypeA = cute::conditional_t; + using SmemAllocTypeB = cute::conditional_t; + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + static constexpr uint32_t TmaTransactionBytesMK = static_cast( + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutA{})) * sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = static_cast( + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutB{})) * sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + alignas(1024) cute::array_aligned> smem_A; + alignas(1024) cute::array_aligned> smem_B; + cute::array_aligned> smem_scale_A; + cute::array_aligned> smem_scale_B; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + alignas(16) PipelineStorage pipeline_storage; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A{nullptr}; + StrideA dA{}; + ElementB const** ptr_B{nullptr}; + StrideB dB{}; + ElementAccumulator const** ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementAccumulator const** ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,0), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,0), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + // Block scaling factors for A and B + cute::TmaDescriptor* tensormaps; + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + ElementSF const** ptr_SFA; + LayoutSFA layout_SFA; + ElementSF const** ptr_SFB; + LayoutSFB layout_SFB; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shapes, Arguments const& args, void* workspace) { + (void) workspace; + + auto init_shape = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + constexpr int tma_alignment_bits = 128; + auto init_M = tma_alignment_bits; + auto init_N = tma_alignment_bits; + auto init_K = tma_alignment_bits; + const uint32_t init_L = 1; + TmaInternalElementA const* ptr_A_first_batch = nullptr; + TmaInternalElementB const* ptr_B_first_batch = nullptr; + InternalStrideA stride_a; + InternalStrideB stride_b; + + if constexpr (IsGroupedGemmKernel) { + stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; + } + else { + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + stride_a = args.dA; + stride_b = args.dB; + } + + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M, init_K, init_L), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N, init_K, init_L), stride_b)); + + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + return { + tma_load_a, + tma_load_b, + TmaTransactionBytes, + TmaTransactionBytesMK, + TmaTransactionBytesNK, + reinterpret_cast(workspace), + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB, + reinterpret_cast(args.ptr_SFA), + args.layout_SFA, + reinterpret_cast(args.ptr_SFB), + args.layout_SFB + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTmaTensors = 2; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + return (NumInputTmaTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + [[maybe_unused]] Arguments const& args) { + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cutlass::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cutlass::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + for (int i = 0; i < problem_shapes.groups(); ++i) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M, N, K, L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + // Ensure complete scale blocks + implementable = implementable && (M % ScaleGranularityM == 0); + implementable = implementable && (N % ScaleGranularityN == 0); + + // We expect full tiles in K + implementable = implementable && (K % size<2>(TileShape{}) == 0); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for blockwise scaling.\n"); + } + } + } + + return implementable; + } + + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& mainloop_params, + ElementSF const* ptr_SFA = nullptr, + ElementSF const* ptr_SFB = nullptr, + InternalLayoutSFA const layout_SFA = InternalLayoutSFA{}, + InternalLayoutSFB const layout_SFB = InternalLayoutSFB{} + ) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + const int32_t init_L = 1; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,init_L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,init_L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + Tensor mSFA_mkl = make_tensor(make_gmem_ptr(ptr_SFA), filter(layout_SFA)); // (Ms, Ks) + Tensor mSFB_nkl = make_tensor(make_gmem_ptr(ptr_SFB), filter(layout_SFB)); // (Ns, Ks) + + return cute::make_tuple(gA_mkl, gB_nkl, mSFA_mkl, mSFB_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, + class TensorSFA, class TensorSFB, + class TensorMapA, class TensorMapB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + cute::tuple const& input_tensormaps, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Block scaling: load_scale has scaling tensors in global memory which are not tiled + Tensor mSFA_mkl = get<2>(load_inputs); + Tensor mSFB_nkl = get<3>(load_inputs); + auto scales_m = get<0>(mSFA_mkl.shape()); + auto scales_n = get<0>(mSFB_nkl.shape()); + + Tensor cSFA_mkl = make_identity_tensor(mSFA_mkl.shape()); + Tensor cSFB_nkl = make_identity_tensor(mSFB_nkl.shape()); + Tensor gSFA = local_tile( + mSFA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1) + Tensor cSFA = local_tile( + cSFA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); + Tensor gSFB = local_tile( + mSFB_nkl, make_tile(Int{}), + make_coord(n_coord,_,l_coord)); // (ScaleNsPerTile,k,1) + Tensor cSFB = local_tile( + cSFB_nkl, make_tile(Int{}), + make_coord(n_coord,_,l_coord)); + + TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, + Layout>{}, Layout>{}); + TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, + Layout>{}, Layout>{}); + + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(thread_idx); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(thread_idx); + + Tensor tAgA_SFA = thr_scale_copy_a.partition_S(gSFA); + Tensor tAcA_SFA = thr_scale_copy_a.partition_S(cSFA); + Tensor tAsA_SFA = thr_scale_copy_a.partition_D(sSFA); + + Tensor tBgB_SFB = thr_scale_copy_b.partition_S(gSFB); + Tensor tBcB_SFB = thr_scale_copy_b.partition_S(cSFB); + Tensor tBsB_SFB = thr_scale_copy_b.partition_D(sSFB); + + Tensor tApA_SFA = make_tensor(shape(tAsA_SFA(_,_,0))); + Tensor tBpB_SFB = make_tensor(shape(tBsB_SFB(_,_,0))); + + auto scale_m_lim = std::min(scales_m, (m_coord + 1) * ScaleMsPerTile); + auto scale_n_lim = std::min(scales_n, (n_coord + 1) * ScaleNsPerTile); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tApA_SFA); ++i) + tApA_SFA(i) = get<0>(tAcA_SFA(i)) < scale_m_lim; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tBpB_SFB); ++i) + tBpB_SFB(i) = get<0>(tBcB_SFB(i)) < scale_n_lim; + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + // TMA Multicast Masks + Layout cta_layout_mnk = make_layout(ClusterShape{}); + auto cta_coord_mnk = cta_layout_mnk.get_flat_coord(block_rank_in_cluster); + + uint16_t mcast_mask_a = create_tma_multicast_mask<1>(cta_layout_mnk, cta_coord_mnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<0>(cta_layout_mnk, cta_coord_mnk); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + int write_stage = smem_pipe_write.index(); + if (lane_predicate) { + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + copy(mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + } + + // Copy scale tensors + copy_if(scale_copy_a, tApA_SFA, tAgA_SFA(_,_,*k_tile_iter), tAsA_SFA(_,_,write_stage)); + copy_if(scale_copy_b, tBpB_SFB, tBgB_SFB(_,_,*k_tile_iter), tBsB_SFB(_,_,write_stage)); + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + __syncwarp(); + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + using namespace cute; + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + FrgTensorC tmp_accum; + clear(accum); + clear(tmp_accum); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Block scaling + Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), + Layout< + Shape, Int>, cute::tuple_element_t<1, TileShape>, Int>, + Stride, _0, Int> + >{}); // ((ScaleGranularityM,ScaleMsPerTile),TileShape_N,stage) + Tensor sScaleBViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), + Layout< + Shape, Shape, Int>, Int>, + Stride<_0, Stride<_0, _1>, Int> + >{}); // (TileShape_M,(ScaleGranularityN,ScaleNsPerTile),stage) + + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + // Allocate fragments and descriptors + Tensor tCrA = thread_mma.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thread_mma.partition_fragment_B(sB(_,_,Int<0>{})); // (MMA,MMA_N,MMA_K) + + Tensor tCsScaleAViewAsC = thread_mma.partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsScaleBViewAsC = thread_mma.partition_C(sScaleBViewAsC); // (MMA,MMA_M,MMA_N,PIPE) + + // + // Copy Atom A and B retiling + // + + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S( + as_position_independent_swizzle_tensor(sA)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + + auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S( + as_position_independent_swizzle_tensor(sB)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_M,CPY_K) + + Tensor tCrScaleAViewAsC = make_tensor_like(tCsScaleAViewAsC(_,_,_,_0{})); // (MMA,MMA_M,MMA_N) + Tensor tCrScaleBViewAsC = make_tensor_like(tCsScaleBViewAsC(_,_,_,_0{})); // (MMA,MMA_M,MMA_N) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); + + // + // PIPELINED MAIN LOOP + // + + // Size of the register pipeline + auto K_BLOCK_MAX = size<2>(tCrA); + + int read_stage = smem_pipe_read.index(); + auto tCsA_stage = tCsA(_,_,_,read_stage); + auto tCsB_stage = tCsB(_,_,_,read_stage); + + auto copy_kblock = [&](auto k_block) { + // copy smem->rmem for A/B operand + copy(smem_tiled_copy_A, tCsA_stage(_,_,k_block), tCrA_copy_view(_,_,k_block)); + copy(smem_tiled_copy_B, tCsB_stage(_,_,k_block), tCrB_copy_view(_,_,k_block)); + + // Left shift A,B for FP4 + using MMAOp = typename TiledMma::MMA_Op; + fp4_shift_A(MMAOp{}, tCrA_copy_view(_,_,k_block)); + fp4_shift_B(MMAOp{}, tCrB_copy_view(_,_,k_block)); + }; + + auto copy_scale_s2r = [&](auto read_stage) { + copy(tCsScaleAViewAsC(_, _, _, read_stage), tCrScaleAViewAsC); + copy(tCsScaleBViewAsC(_, _, _, read_stage), tCrScaleBViewAsC); + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + tCrScaleAViewAsC.data()[0] = tCrScaleAViewAsC.data()[0] * tCrScaleBViewAsC.data()[0]; + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + ElementSF scale_b = tCrScaleBViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrScaleAViewAsC); i++) { + tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; + } + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + ElementSF scale_a = tCrScaleAViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrScaleBViewAsC); i++) { + tCrScaleBViewAsC.data()[i] = tCrScaleBViewAsC.data()[i] * scale_a; + } + } + }; + + auto rescale = [&]() { + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC` + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementSF scale_ab = tCrScaleAViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum); ++i) { + accum(i) += tmp_accum(i) * scale_ab; + tmp_accum(i) = 0; + } + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum); ++i) { + accum(i) += tmp_accum(i) * tCrScaleAViewAsC(i); + tmp_accum(i) = 0; + } + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum); ++i) { + accum(i) += tmp_accum(i) * tCrScaleBViewAsC(i); + tmp_accum(i) = 0; + } + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum); ++i) { + accum(i) += tmp_accum(i) * tCrScaleAViewAsC(i) * tCrScaleBViewAsC(i); + tmp_accum(i) = 0; + } + } + }; + + auto gemm_kblock = [&](auto k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tmp_accum); + }; + + pipeline.consumer_wait(smem_pipe_read); + copy_scale_s2r(read_stage); + copy_kblock(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 1; --k_tile_count) { + // + // Compute on k_tile + // + for_each(make_int_sequence{}, [&] (auto k_block) { + + auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 0 : (k_block + 1); + + if (k_block == K_BLOCK_MAX - 1) { + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + // UNLOCK smem_pipe_read, done _computing_ on it + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + read_stage = smem_pipe_read.index(); + tCsA_stage = tCsA(_,_,_,read_stage); + tCsB_stage = tCsB(_,_,_,read_stage); + pipeline.consumer_wait(smem_pipe_read); + } + + copy_kblock(k_block_next); + gemm_kblock(k_block); + + if (k_block == K_BLOCK_MAX - 1) { + rescale(); + copy_scale_s2r(read_stage); + } + + }); + + } // k_tile_count + + // + // Hoist out last k_tile + // + for_each(make_int_sequence{}, [&] (auto k_block) { + + auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 0 : (k_block + 1); + + if (k_block == K_BLOCK_MAX - 1) { + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + // UNLOCK smem_pipe_read, done _computing_ on it + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + } + + if (k_block_next > 0) { + copy_kblock(k_block_next); + } + gemm_kblock(k_block); + + }); + rescale(); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline, PipelineState, int) { + } + + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { + cute::TmaDescriptor* gmem_tensormap = reinterpret_cast(mainloop_params.tensormaps); + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + } + __syncwarp(); + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + template + CUTLASS_DEVICE void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + + TmaInternalElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]); + + + TmaInternalElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, + prob_shape_B, prob_stride_B); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_mnkl); + } + } + } + + template + CUTLASS_DEVICE void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps) { + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } + + template + CUTLASS_DEVICE InputTensors + tensors_perform_update( + InputTensors const& input_tensors, + Params const& mainloop_params, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + if constexpr (IsGroupedGemmKernel) { + return load_init( + problem_shape_mnkl, + mainloop_params, + mainloop_params.ptr_SFA[next_batch], + mainloop_params.ptr_SFB[next_batch], + mainloop_params.layout_SFA[next_batch], + mainloop_params.layout_SFB[next_batch] + ); + } + else { + auto [gA_mkl, gB_nkl, mSFA_mkl, mSFB_nkl] = input_tensors; + + mSFA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_SFA[next_batch]), mainloop_params.layout_SFA[next_batch]); + mSFB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_SFB[next_batch]), mainloop_params.layout_SFB[next_batch]); + + return cute::make_tuple(gA_mkl, gB_nkl, mSFA_mkl, mSFB_nkl); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm120_mma_tma.hpp b/include/cutlass/gemm/collective/sm120_mma_tma.hpp index 76c1028d..65f83330 100644 --- a/include/cutlass/gemm/collective/sm120_mma_tma.hpp +++ b/include/cutlass/gemm/collective/sm120_mma_tma.hpp @@ -44,7 +44,6 @@ #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/functional.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -113,7 +112,7 @@ struct CollectiveMma< using RuntimeDataTypeA = void*; using RuntimeDataTypeB = void*; - + static constexpr int ThreadCount = size(TiledMma{}); using MainloopPipeline = cutlass::PipelineTmaAsync; @@ -505,7 +504,7 @@ struct CollectiveMma< int read_stage = smem_pipe_read.index(); auto tCsA_stage = tCsA(_,_,_,read_stage); auto tCsB_stage = tCsB(_,_,_,read_stage); - + auto copy_kblock = [&](auto k_block) { // copy smem->rmem for A/B operand copy(smem_tiled_copy_A, tCsA_stage(_,_,k_block), tCrA_copy_view(_,_,k_block)); @@ -533,7 +532,7 @@ struct CollectiveMma< for_each(make_int_sequence{}, [&] (auto k_block) { auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 0 : (k_block + 1); - + if (k_block == K_BLOCK_MAX - 1) { cutlass::arch::NamedBarrier::sync( thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); @@ -558,7 +557,7 @@ struct CollectiveMma< for_each(make_int_sequence{}, [&] (auto k_block) { auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 0 : (k_block + 1); - + if (k_block == K_BLOCK_MAX - 1) { cutlass::arch::NamedBarrier::sync( thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); diff --git a/include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp new file mode 100644 index 00000000..2f77d664 --- /dev/null +++ b/include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp @@ -0,0 +1,779 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/trace.h" +#include "cutlass/numeric_types.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + int SchedulerPipelineStageCount, + class ClusterShape, + class KernelScheduleType, + class TileShape_, + class ElementA_, + class StridePairA_, + class ElementB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm120TmaWarpSpecializedBlockwiseScaling, + TileShape_, + ElementA_, + StridePairA_, + ElementB_, + StridePairB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> { + // + // Type Aliases + // + using DispatchPolicy = MainloopSm120TmaWarpSpecializedBlockwiseScaling; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = cute::remove_cvref_t(StridePairA_{}))>; + using LayoutSFA = cute::remove_cvref_t(StridePairA_{}))>; + using ElementB = ElementB_; + using StrideB = cute::remove_cvref_t(StridePairB_{}))>; + using LayoutSFB = cute::remove_cvref_t(StridePairB_{}))>; + using TiledMma = TiledMma_; + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using ElementAccumulator = typename TiledMma::ValTypeC; + using ElementSF = ElementAccumulator; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using RuntimeDataTypeA = void*; + using RuntimeDataTypeB = void*; + + static constexpr int ThreadCount = size(TiledMma{}); + + using MainloopPipeline = cutlass::PipelineTmaAsync; + + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename cutlass::PipelineState; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 33; + + static constexpr int ScaleGranularityM = size<0,0>(LayoutSFA{}); + static constexpr int ScaleGranularityN = size<0,0>(LayoutSFB{}); + static constexpr int ScaleGranularityK = size<1,0>(LayoutSFB{}); + + static_assert(size<1, 0>(LayoutSFA{}) == size<1, 0>(LayoutSFB{}), "Vector size K must be equal for SFA and SFB"); + static_assert(size<0>(TileShape{}) % ScaleGranularityM == 0, "Scale Granularity M must evenly divide the tile shape M."); + static_assert(size<1>(TileShape{}) % ScaleGranularityN == 0, "Scale Granularity N must evenly divide the tile shape N."); + static_assert(size<2>(TileShape{}) == ScaleGranularityK , "Scale Granularity K must be equal to the tile shape K."); + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + + using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig(LayoutSFA{}.stride()) == 1 ? UMMA::Major::MN : UMMA::Major::K, + size<0,1>(LayoutSFB{}.stride()) == 1 ? UMMA::Major::MN : UMMA::Major::K>; + + static constexpr int AlignmentSFA = 1; + static constexpr int AlignmentSFB = 1; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for A operand smem->rmem reads."); + static_assert(not cute::is_void_v, + "SM120 mainloop must specify a copy atom for B operand smem->rmem reads."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // Block scaling gmem-to-smem copy atom + // we can have partial tiles in M or N, so don't vectorize those loads + using SmemBlockScalingCopyAtomA = Copy_Atom, ElementSF>; + using SmemBlockScalingCopyAtomB = Copy_Atom, ElementSF>; + + // Block scaling smem layout + using SmemLayoutScaleA = Layout, Int>>; + using SmemLayoutScaleB = Layout, Int>>; + + + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + not cute::is_base_of::value, + "MMA atom must source both A and B operands from rmem for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm120_f8f6f4(); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + using TmaInternalElementA = cute::conditional_t, + cutlass::tfloat32_t, + cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + using TmaInternalElementB = cute::conditional_t, + cutlass::tfloat32_t, + cute::conditional_t, + cutlass::detail::float_e2m1_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e2m3_unpacksmem_t, + cute::conditional_t, + cutlass::detail::float_e3m2_unpacksmem_t, + uint_bit_t>>>>>; + + using SmemAllocTypeA = cute::conditional_t; + using SmemAllocTypeB = cute::conditional_t; + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + static constexpr uint32_t TmaTransactionBytesMK = static_cast( + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutA{})) * sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = static_cast( + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutB{})) * sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + alignas(1024) cute::array_aligned> smem_A; + alignas(1024) cute::array_aligned> smem_B; + cute::array_aligned> smem_scale_A; + cute::array_aligned> smem_scale_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + alignas(16) PipelineStorage pipeline_storage; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{nullptr}; + StrideA dA{}; + ElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementAccumulator const* ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementAccumulator const* ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,0), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,0), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + // Block scaling factors for A and B + ElementSF const* ptr_SFA; + LayoutSFA layout_SFA; + ElementSF const* ptr_SFB; + LayoutSFB layout_SFB; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return { + tma_load_a, + tma_load_b, + TmaTransactionBytes, + TmaTransactionBytesMK, + TmaTransactionBytesNK, + args.ptr_SFA, + args.layout_SFA, + args.ptr_SFB, + args.layout_SFB + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + // Ensure complete scale blocks + implementable = implementable && (M % ScaleGranularityM == 0); + implementable = implementable && (N % ScaleGranularityN == 0); + + // We expect full tiles in K + implementable = implementable && (K % size<2>(TileShape{}) == 0); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the alignment requirements for blockwise scaling.\n"); + } + + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + Tensor mSFA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_SFA), filter(mainloop_params.layout_SFA)); // (Ms, Ks) + Tensor mSFB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_SFB), filter(mainloop_params.layout_SFB)); // (Ns, Ks) + + return cute::make_tuple(gA_mkl, gB_nkl, mSFA_mkl, mSFB_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, + class TensorSFA, class TensorSFB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Block scaling: load_scale has scaling tensors in global memory which are not tiled + Tensor mSFA_mkl = get<2>(load_inputs); + Tensor mSFB_nkl = get<3>(load_inputs); + auto scales_m = get<0>(mSFA_mkl.shape()); + auto scales_n = get<0>(mSFB_nkl.shape()); + + Tensor cSFA_mkl = make_identity_tensor(mSFA_mkl.shape()); + Tensor cSFB_nkl = make_identity_tensor(mSFB_nkl.shape()); + Tensor gSFA = local_tile( + mSFA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1) + Tensor cSFA = local_tile( + cSFA_mkl, make_tile(Int{}), + make_coord(m_coord,_,l_coord)); + Tensor gSFB = local_tile( + mSFB_nkl, make_tile(Int{}), + make_coord(n_coord,_,l_coord)); // (ScaleNsPerTile,k,1) + Tensor cSFB = local_tile( + cSFB_nkl, make_tile(Int{}), + make_coord(n_coord,_,l_coord)); + + TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, + Layout>{}, Layout>{}); + TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, + Layout>{}, Layout>{}); + + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(thread_idx); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(thread_idx); + + Tensor tAgA_SFA = thr_scale_copy_a.partition_S(gSFA); + Tensor tAcA_SFA = thr_scale_copy_a.partition_S(cSFA); + Tensor tAsA_SFA = thr_scale_copy_a.partition_D(sSFA); + + Tensor tBgB_SFB = thr_scale_copy_b.partition_S(gSFB); + Tensor tBcB_SFB = thr_scale_copy_b.partition_S(cSFB); + Tensor tBsB_SFB = thr_scale_copy_b.partition_D(sSFB); + + Tensor tApA_SFA = make_tensor(shape(tAsA_SFA(_,_,0))); + Tensor tBpB_SFB = make_tensor(shape(tBsB_SFB(_,_,0))); + + auto scale_m_lim = std::min(scales_m, (m_coord + 1) * ScaleMsPerTile); + auto scale_n_lim = std::min(scales_n, (n_coord + 1) * ScaleNsPerTile); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tApA_SFA); ++i) + tApA_SFA(i) = get<0>(tAcA_SFA(i)) < scale_m_lim; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tBpB_SFB); ++i) + tBpB_SFB(i) = get<0>(tBcB_SFB(i)) < scale_n_lim; + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + // TMA Multicast Masks + Layout cta_layout_mnk = make_layout(ClusterShape{}); + auto cta_coord_mnk = cta_layout_mnk.get_flat_coord(block_rank_in_cluster); + + uint16_t mcast_mask_a = create_tma_multicast_mask<1>(cta_layout_mnk, cta_coord_mnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<0>(cta_layout_mnk, cta_coord_mnk); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + int write_stage = smem_pipe_write.index(); + if (lane_predicate) { + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + } + + // Copy scale tensors + copy_if(scale_copy_a, tApA_SFA, tAgA_SFA(_,_,*k_tile_iter), tAsA_SFA(_,_,write_stage)); + copy_if(scale_copy_b, tBpB_SFB, tBgB_SFB(_,_,*k_tile_iter), tBsB_SFB(_,_,write_stage)); + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + using namespace cute; + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + FrgTensorC tmp_accum; + clear(accum); + clear(tmp_accum); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Block scaling + Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), + Layout< + Shape, Int>, cute::tuple_element_t<1, TileShape>, Int>, + Stride, _0, Int> + >{}); // ((ScaleGranularityM,ScaleMsPerTile),TileShape_N,stage) + Tensor sScaleBViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), + Layout< + Shape, Shape, Int>, Int>, + Stride<_0, Stride<_0, _1>, Int> + >{}); // (TileShape_M,(ScaleGranularityN,ScaleNsPerTile),stage) + + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + // Allocate fragments and descriptors + Tensor tCrA = thread_mma.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thread_mma.partition_fragment_B(sB(_,_,Int<0>{})); // (MMA,MMA_N,MMA_K) + + Tensor tCsScaleAViewAsC = thread_mma.partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsScaleBViewAsC = thread_mma.partition_C(sScaleBViewAsC); // (MMA,MMA_M,MMA_N,PIPE) + + // + // Copy Atom A and B retiling + // + + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S( + as_position_independent_swizzle_tensor(sA)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + + auto smem_tiled_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S( + as_position_independent_swizzle_tensor(sB)); // (CPY,CPY_M,CPY_K,PIPE) + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); // (CPY,CPY_M,CPY_K) + + Tensor tCrScaleAViewAsC = make_tensor_like(tCsScaleAViewAsC(_,_,_,_0{})); // (MMA,MMA_M,MMA_N) + Tensor tCrScaleBViewAsC = make_tensor_like(tCsScaleBViewAsC(_,_,_,_0{})); // (MMA,MMA_M,MMA_N) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(accum)); + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); + + // + // PIPELINED MAIN LOOP + // + + // Size of the register pipeline + auto K_BLOCK_MAX = size<2>(tCrA); + + int read_stage = smem_pipe_read.index(); + auto tCsA_stage = tCsA(_,_,_,read_stage); + auto tCsB_stage = tCsB(_,_,_,read_stage); + + auto copy_kblock = [&](auto k_block) { + // copy smem->rmem for A/B operand + copy(smem_tiled_copy_A, tCsA_stage(_,_,k_block), tCrA_copy_view(_,_,k_block)); + copy(smem_tiled_copy_B, tCsB_stage(_,_,k_block), tCrB_copy_view(_,_,k_block)); + + // Left shift A,B for FP4 + using MMAOp = typename TiledMma::MMA_Op; + fp4_shift_A(MMAOp{}, tCrA_copy_view(_,_,k_block)); + fp4_shift_B(MMAOp{}, tCrB_copy_view(_,_,k_block)); + }; + + auto copy_scale_s2r = [&](auto read_stage) { + copy(tCsScaleAViewAsC(_, _, _, read_stage), tCrScaleAViewAsC); + copy(tCsScaleBViewAsC(_, _, _, read_stage), tCrScaleBViewAsC); + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + tCrScaleAViewAsC.data()[0] = tCrScaleAViewAsC.data()[0] * tCrScaleBViewAsC.data()[0]; + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + ElementSF scale_b = tCrScaleBViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrScaleAViewAsC); i++) { + tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; + } + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + ElementSF scale_a = tCrScaleAViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrScaleBViewAsC); i++) { + tCrScaleBViewAsC.data()[i] = tCrScaleBViewAsC.data()[i] * scale_a; + } + } + }; + + auto rescale = [&]() { + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC` + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementSF scale_ab = tCrScaleAViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum); ++i) { + accum(i) += tmp_accum(i) * scale_ab; + tmp_accum(i) = 0; + } + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum); ++i) { + accum(i) += tmp_accum(i) * tCrScaleAViewAsC(i); + tmp_accum(i) = 0; + } + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum); ++i) { + accum(i) += tmp_accum(i) * tCrScaleBViewAsC(i); + tmp_accum(i) = 0; + } + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum); ++i) { + accum(i) += tmp_accum(i) * tCrScaleAViewAsC(i) * tCrScaleBViewAsC(i); + tmp_accum(i) = 0; + } + } + }; + + auto gemm_kblock = [&](auto k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tmp_accum); + }; + + pipeline.consumer_wait(smem_pipe_read); + copy_scale_s2r(read_stage); + copy_kblock(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 1; --k_tile_count) { + // + // Compute on k_tile + // + for_each(make_int_sequence{}, [&] (auto k_block) { + + auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 0 : (k_block + 1); + + if (k_block == K_BLOCK_MAX - 1) { + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + // UNLOCK smem_pipe_read, done _computing_ on it + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + read_stage = smem_pipe_read.index(); + tCsA_stage = tCsA(_,_,_,read_stage); + tCsB_stage = tCsB(_,_,_,read_stage); + pipeline.consumer_wait(smem_pipe_read); + } + + copy_kblock(k_block_next); + gemm_kblock(k_block); + + if (k_block == K_BLOCK_MAX - 1) { + rescale(); + copy_scale_s2r(read_stage); + } + + }); + + } // k_tile_count + + // + // Hoist out last k_tile + // + for_each(make_int_sequence{}, [&] (auto k_block) { + + auto k_block_next = ((k_block + 1) == K_BLOCK_MAX) ? 0 : (k_block + 1); + + if (k_block == K_BLOCK_MAX - 1) { + cutlass::arch::NamedBarrier::sync( + thr_size(tiled_mma), cutlass::arch::ReservedNamedBarriers::Sm120MainloopBarrier); + // UNLOCK smem_pipe_read, done _computing_ on it + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + } + + if (k_block_next > 0) { + copy_kblock(k_block_next); + } + gemm_kblock(k_block); + + }); + rescale(); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline, PipelineState, int) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm120_sparse_mma_tma.hpp b/include/cutlass/gemm/collective/sm120_sparse_mma_tma.hpp index b3566e2e..3316308b 100644 --- a/include/cutlass/gemm/collective/sm120_sparse_mma_tma.hpp +++ b/include/cutlass/gemm/collective/sm120_sparse_mma_tma.hpp @@ -45,7 +45,6 @@ #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/functional.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -129,14 +128,14 @@ struct CollectiveMma< using RuntimeDataTypeA = void*; using RuntimeDataTypeB = void*; - + static constexpr int ThreadCount = size(TiledMma{}); static constexpr int ElementAMmaSparsity = ElementAMma::sparsity; static constexpr int ElementEMmaSparsity = ElementEMma::sparsity; // Asymmetric buffering - // Tensor A/B could have different buffering, with TILEK, and STAGEs. - // It let AsymmetricKRatio equals TILEK_A / TILEK_B, to make sure A/B's + // Tensor A/B could have different buffering, with TILEK, and STAGEs. + // It let AsymmetricKRatio equals TILEK_A / TILEK_B, to make sure A/B's // pipeline keep same steps when procude / consume data. static constexpr int AsymmetricKRatio = DispatchPolicy::StagesA != DispatchPolicy::StagesB ? 2 : 1; @@ -418,7 +417,7 @@ struct CollectiveMma< make_tile(make_layout(size<1>(thr_layout_vmnk)), make_layout(size<3>(thr_layout_vmnk)))); auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) - + // Fragment layout return thr_tensor; } @@ -694,16 +693,15 @@ struct CollectiveMma< if constexpr (IsELoadPred) { // Get predication based on logical element coordinates. Tensor cE_mk = local_tile( - make_identity_tensor(Shape_MK), - make_shape(get<0>(TileShape{}), get<2>(TileShape{})), + make_identity_tensor(Shape_MK), + make_shape(get<0>(TileShape{}), get<2>(TileShape{})), make_shape(m_coord, k_coord)); // (BLK_M, BLK_K) Tensor tCcE = gmem_thr_copy_E.partition_S(cE_mk); // (CPY,CPY_M,CPY_K) auto [atom, vec] = get_copy_atom_and_common_vec(); // Coordinate comparison for out of bound (OOB) predication - Tensor tZcE = zipped_divide(tCcE, vec); - auto pred_fn = [&](auto coord){ return cute::elem_less(tZcE(Int<0>{}, coord), Shape_MK); }; + Tensor tZpE = cute::lazy::transform(zipped_divide(tCcE, vec), [&](auto const& c){ return cute::elem_less(c, Shape_MK); }); // Copy - cute::copy_if(atom, pred_fn, zipped_divide(tCgE, vec), zipped_divide(tCrE_copy_view, vec)); + cute::copy_if(atom, tZpE, zipped_divide(tCgE, vec), zipped_divide(tCrE_copy_view, vec)); } else { // Copy @@ -712,7 +710,7 @@ struct CollectiveMma< } return tCrE; } - + /// Perform a collective-scoped matrix multiply-accumulate /// Consumer Perspective template < @@ -849,7 +847,7 @@ struct CollectiveMma< // Copy E from SMEM to register auto copy_E = [&](auto m_block, auto k_block) CUTLASS_LAMBDA_FUNC_INLINE { // copy smem->rmem for E operand - copy( recast(tCsE(_,m_block,k_block,smem_pipe_read_mk.index())), + copy( recast(tCsE(_,m_block,k_block,smem_pipe_read_mk.index())), recast(tCrE_copy_view(_,m_block,k_block))); }; @@ -877,8 +875,8 @@ struct CollectiveMma< copy_E(m_block, k_block); // Gemm - cute::gemm(tiled_mma, - make_zip_tensor(tCrA(_,m_block,k_block), tCrE(_,m_block,k_block)), + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,m_block,k_block), tCrE(_,m_block,k_block)), tCrB(_,n_block,k_block), accum(_,m_block,n_block)); }); @@ -914,8 +912,8 @@ struct CollectiveMma< copy_transform_A(m_block, k_block); // Gemm - cute::gemm(tiled_mma, - make_zip_tensor(tCrA(_,m_block,k_block), tCrE(_,m_block,k_block)), + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,m_block,k_block), tCrE(_,m_block,k_block)), tCrB(_,n_block,k_block), accum(_,m_block,n_block)); }); @@ -941,8 +939,8 @@ struct CollectiveMma< copy_transform_A(m_block, k_block_a); // Gemm - cute::gemm(tiled_mma, - make_zip_tensor(tCrA(_,m_block,k_block_a), tCrE(_,m_block,k_block_a)), + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,m_block,k_block_a), tCrE(_,m_block,k_block_a)), tCrB(_,n_block,k_block), accum(_,m_block,n_block)); }); @@ -970,7 +968,7 @@ struct CollectiveMma< gemm_loop_with_SmemE(); } // Case when A/B with different stages, and keep E in GMEM. - else { + else { gemm_loop_with_GmemE(); } // end if diff --git a/include/cutlass/gemm/collective/sm70_mma_twostage.hpp b/include/cutlass/gemm/collective/sm70_mma_twostage.hpp index c2eda0ab..a0c8f2a8 100644 --- a/include/cutlass/gemm/collective/sm70_mma_twostage.hpp +++ b/include/cutlass/gemm/collective/sm70_mma_twostage.hpp @@ -37,7 +37,6 @@ #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" #include "cute/atom/mma_atom.hpp" -#include "cute/tensor_predicate.hpp" #include "cutlass/gemm/collective/collective_mma_decl.hpp" diff --git a/include/cutlass/gemm/collective/sm80_mma_multistage.hpp b/include/cutlass/gemm/collective/sm80_mma_multistage.hpp index e8de7c31..97758404 100644 --- a/include/cutlass/gemm/collective/sm80_mma_multistage.hpp +++ b/include/cutlass/gemm/collective/sm80_mma_multistage.hpp @@ -36,7 +36,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" @@ -100,7 +99,7 @@ struct CollectiveMma< using TransformA = TransformA_; using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; - // Follow the change in TestSmall: TileShape switch to CtaShape + // Follow the change in TestSmall: TileShape switch to CtaShape // For sm80 arch, CtaShape should euqal to TileShape using CtaShape_MNK = TileShape; @@ -318,7 +317,7 @@ struct CollectiveMma< copy(gmem_tiled_copy_A, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); copy(gmem_tiled_copy_B, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); cp_async_fence(); - + // Advance the tile --k_tile_count; if (k_tile_count > 0) { ++k_tile_iter; } @@ -390,7 +389,7 @@ struct CollectiveMma< Stages, ClusterShape_>; using TileShape = TileShape_; - // Follow the change in TestSmall: TileShape switch to CtaShape + // Follow the change in TestSmall: TileShape switch to CtaShape // In legacy arch, it should be same using CtaShape_MNK = TileShape; using ElementA = ElementA_; diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp index dc30ae56..653db90a 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -43,7 +43,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -95,7 +94,7 @@ public: ConvertAndScale, ConvertAndScaleWithZero }; - + // // Type Aliases // @@ -105,10 +104,10 @@ public: private: template friend struct detail::MixedInputUtils; - using CollectiveType = CollectiveMma; public: - static_assert(cute::is_tuple::value ^ cute::is_tuple::value, + static_assert(cute::is_tuple::value ^ cute::is_tuple::value, "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale], [ElementZero]}. Inputs in [] are optional."); - + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; static constexpr bool IsATransformed = cute::is_tuple::value; @@ -140,23 +139,23 @@ public: using InternalStrideA = cute::remove_pointer_t; using StrideB = StrideB_; using InternalStrideB = cute::remove_pointer_t; - + using StrideScale = cute::Stride, int64_t, int64_t>; using NonVoidStrideScale = cute::conditional_t, cute::Stride<_1, int64_t, int64_t>, StrideScale>; - static_assert(( IsATransformed && (cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value)) || + static_assert(( IsATransformed && (cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value)) || (!IsATransformed && (cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value)), "The transformed type must be K-major."); static_assert(( IsATransformed && (sizeof(ElementB) == 2)) || (!IsATransformed && (sizeof(ElementA) == 2)) || - ((cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value) && - (cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value)), + ((cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value) && + (cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value)), "The unscaled element must be 2 bytes OR both inputs must be K-major"); - static_assert(cutlass::gemm::detail::is_mn_major(), + static_assert(cutlass::gemm::detail::is_mn_major(), "Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled]."); - + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); using TiledMma = TiledMma_; using ElementAccumulator = typename TiledMma::ValTypeC; @@ -224,10 +223,10 @@ public: /// Tile along modes in a way that maximizes the TMA box size. using SmemLayoutA = decltype(detail::get_smem_layout(SwappedSmemLayoutAtomA{}, select<0,2>(TileShape{}), InternalSwappedStrideA{})); using SmemLayoutB = decltype(detail::get_smem_layout(SwappedSmemLayoutAtomB{}, select<1,2>(TileShape{}), InternalSwappedStrideB{})); - + // It is assumed that the scales and zero-points share the same smem layout using SmemLayoutScale = decltype(tile_to_shape( - SmemLayoutAtomScale{}, + SmemLayoutAtomScale{}, make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}), Int{}), cute::conditional_t< ::cutlass::gemm::detail::is_major<0,NonVoidStrideScale>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); @@ -245,18 +244,18 @@ public: static_assert(size<1>(SmemLayoutAtomScale{}) == 1, "size<1>(SmemLayoutAtomScale) must be 1."); private: - static constexpr ConversionMode + static constexpr ConversionMode get_conversion_mode() { if constexpr (cute::is_void_v) { return ConversionMode::DirectConvert; - } + } else if constexpr (cute::is_void_v) { return ConversionMode::ConvertAndScale; } else { return ConversionMode::ConvertAndScaleWithZero; } - } + } public: static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); @@ -264,7 +263,7 @@ public: KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; static constexpr bool UseScaleLookupTable = KernelConversionMode == ConversionMode::ConvertAndScale && cutlass::detail::is_Array_v; - static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); @@ -341,7 +340,7 @@ public: SmemLayoutScale{}(_,_,cute::Int<0>{}), ScaleTileShape{}, _1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel - + TMA_A tma_load_a; TMA_B tma_load_b; uint32_t tma_transaction_bytes = TmaTransactionBytes; @@ -415,7 +414,7 @@ public: dA = InternalSwappedStrideA{}; if constexpr (is_layout::value) { dA = make_layout( - transform_leaf(dA.shape(), [](auto x){ + transform_leaf(dA.shape(), [](auto x){ if constexpr (not is_static_v) { return static_cast(1); } else { @@ -521,15 +520,15 @@ public: _1{}); // mcast along N mode for this M load, if any return SwapAB ? args_setup(args.ptr_B, args.ptr_A, scale_k, args.chunk_size, (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})) : args_setup(args.ptr_A, args.ptr_B, scale_k, args.chunk_size, (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})); - - } + + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); } - } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); - } + } } template @@ -545,15 +544,15 @@ public: if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies return calculate_workspace_size(2); - } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies, followed by scale tensormap copies return calculate_workspace_size(3); - } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies, followed by scale and zeros tensormap copies return calculate_workspace_size(4); - } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in get_workspace_size."); } @@ -612,7 +611,7 @@ public: constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); implementable = implementable && (args.ptr_Z != nullptr); - } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); } @@ -661,7 +660,7 @@ public: if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { return cute::make_tuple(gA_mkl, gB_nkl); - } + } else if constexpr (ModeHasScales) { const int scale_mn = SwapAB ? N : M; auto scale_k = mainloop_params.scale_k; @@ -678,7 +677,7 @@ public: else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); } - } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); } @@ -694,7 +693,7 @@ public: CUTLASS_DEVICE void load( Params const& mainloop_params, - MainloopPipeline pipeline, + MainloopPipeline pipeline, PipelineState smem_pipe_write, cute::tuple const& load_inputs, cute::tuple const& input_tensormaps, @@ -707,15 +706,15 @@ public: if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs"); static_assert(sizeof... (TMs) == 2, "Direct convert needs two tensormaps"); - } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs"); static_assert(sizeof... (TMs) == 3, "Scaled convert needs three tensormaps"); - } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { static_assert(sizeof... (Ts) == 4, "Scaled and zero convert needs four inputs"); static_assert(sizeof... (TMs) == 4, "Scaled and zero convert needs four tensormaps"); - } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); } @@ -809,7 +808,7 @@ public: if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { // Nothing extra to do - } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { auto tZgZ = get<2>(extra_input_partitions); auto tZsZ = get<3>(extra_input_partitions); @@ -819,8 +818,8 @@ public: } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); - } - } + } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); } @@ -839,9 +838,9 @@ public: // Issue the epilogue waits if (lane_predicate) { // This helps avoid early exit of blocks in Cluster. - // Waits for all stages to either be released (all + // Waits for all stages to either be released (all // Consumer UNLOCKs), or if the stage was never used - // then it would just be acquired since the phase was + // then it would just be acquired since the phase was // still inverted from make_producer_start_state. pipeline.producer_tail(smem_pipe_write); } @@ -875,7 +874,7 @@ public: int warp_idx = canonical_warp_idx_sync(); [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; - + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) @@ -889,7 +888,7 @@ public: // Layout of warp group to thread mapping static_assert(stride<0>(typename TiledMma::BLayout{}) == 0 and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; @@ -938,7 +937,7 @@ public: CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE - + // // PIPELINED MAIN LOOP // @@ -967,15 +966,15 @@ public: // copy smem->rmem for A operand - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 0, read_stage); if (K_BLOCK_MAX > 1) { - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 1, read_stage); } - + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); - + // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { @@ -986,26 +985,26 @@ public: warpgroup_commit_batch(); if (k_block < K_BLOCK_MAX - 2) { - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); } if (k_block < K_BLOCK_MAX - 1) { Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); } - } + } --k_tile_count; if (k_tile_count > 0) { - // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to overwrite the A registers for the first mma. + // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to overwrite the A registers for the first mma. pipeline.consumer_wait(smem_pipe_read, barrier_token); - - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); - - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); - - warpgroup_wait(); + + warpgroup_wait(); Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); } } @@ -1030,7 +1029,7 @@ public: // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { - + warpgroup_arrive(); // (V,M) x (V,N) => (V,M,N) cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); @@ -1047,18 +1046,18 @@ public: barrier_token = pipeline.consumer_try_wait(smem_pipe_read); } - if (k_block == K_BLOCK_MAX - 1) { + if (k_block == K_BLOCK_MAX - 1) { pipeline.consumer_wait(smem_pipe_read, barrier_token); - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); - } + } else { if (k_block < K_BLOCK_MAX - 2) { - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); } Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); @@ -1078,7 +1077,7 @@ public: int read_stage = smem_pipe_read.index(); warpgroup_fence_operand(accum); - + // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { @@ -1097,7 +1096,7 @@ public: } if (k_block < K_BLOCK_MAX - 2) { - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); } if (k_block < K_BLOCK_MAX - 1) { @@ -1117,7 +1116,7 @@ public: k_tile_count -= prologue_mma_count; smem_pipe_release.advance(k_tile_count); - + // Wait on all GMMAs to complete warpgroup_wait<0>(); @@ -1153,7 +1152,7 @@ public: copy(recast(pA_tensormap), recast(sA_tensormap)); copy(recast(pB_tensormap), recast(sB_tensormap)); } - + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { Tensor pS_tensormap = make_tensor(mainloop_params.tma_load_scale.get_tma_descriptor(), Int<1>{}, Int<1>{}); Tensor sS_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_scale), Int<1>{}, Int<1>{}); @@ -1225,7 +1224,7 @@ public: const uint32_t M = (SwapAB? get<1>(problem_shape_mnkl) : get<0>(problem_shape_mnkl)); const uint32_t N = (SwapAB? get<0>(problem_shape_mnkl) : get<1>(problem_shape_mnkl)); const uint32_t K = get<2>(problem_shape_mnkl); - + // Replace all dims for consistency constexpr int MaxTensorRank = 5; cute::array prob_shape_A = {1,1,1,1,1}; @@ -1243,23 +1242,23 @@ public: SwappedElementB const* ptr_B = nullptr; Tensor tensor_b = make_tensor(ptr_B, detail::get_gmem_layout(make_shape(N,K,Int<1>{}), mainloop_params.ptr_dB[next_group])); - cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, prob_shape_A, prob_stride_A); - cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, prob_shape_B, prob_stride_B); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { NonVoidElementScale const* ptr_S = nullptr; auto scale_k = ceil_div(K, mainloop_params.chunk_size); Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]); - cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_scale, tensor_scale, + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_scale, tensor_scale, prob_shape_scale, prob_stride_scale); } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { ElementZero const* ptr_Z = nullptr; auto scale_k = ceil_div(K, mainloop_params.chunk_size); Tensor tensor_zero = make_tensor(detail::get_logical_ptr(ptr_Z), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]); - cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_zero, tensor_zero, + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_zero, tensor_zero, prob_shape_zero, prob_stride_zero); } else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ @@ -1300,7 +1299,7 @@ public: } else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_replace_global_tensor_properties."); - } + } } template diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp index da16d118..6786cec5 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp @@ -42,7 +42,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -359,7 +358,7 @@ struct CollectiveMma< CUTLASS_DEVICE void load( Params const& mainloop_params, - MainloopPipeline pipeline, + MainloopPipeline pipeline, PipelineState smem_pipe_write, cute::tuple const& load_inputs, cute::tuple const& input_tensormaps, @@ -451,9 +450,9 @@ struct CollectiveMma< // Issue the epilogue waits if (lane_predicate) { // This helps avoid early exit of blocks in Cluster. - // Waits for all stages to either be released (all + // Waits for all stages to either be released (all // Consumer UNLOCKs), or if the stage was never used - // then it would just be acquired since the phase was + // then it would just be acquired since the phase was // still inverted from make_producer_start_state. pipeline.producer_tail(smem_pipe_write); } @@ -489,10 +488,10 @@ struct CollectiveMma< // Layout of warp group to thread mapping - static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and stride<0>(typename TiledMma::BLayout{}) == 0 and size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; @@ -611,7 +610,7 @@ struct CollectiveMma< k_tile_count -= prologue_mma_count; smem_pipe_release.advance(k_tile_count); - + // Wait on all GMMAs to complete warpgroup_wait<0>(); @@ -690,9 +689,9 @@ struct CollectiveMma< InternalElementB const* ptr_B = nullptr; Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); - cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, prob_shape_A, prob_stride_A); - cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, prob_shape_B, prob_stride_B); // Convert strides to byte strides diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp index 53348dff..916c6db8 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp @@ -42,7 +42,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/tensor.hpp" #include "cute/numeric/arithmetic_tuple.hpp" @@ -490,10 +489,10 @@ struct CollectiveMma< // Layout of warp group to thread mapping - static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and stride<0>(typename TiledMma::BLayout{}) == 0 and size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; @@ -620,7 +619,7 @@ struct CollectiveMma< k_tile_count -= prologue_mma_count; smem_pipe_release.advance(k_tile_count); - + // Wait on all GMMAs to complete warpgroup_wait<0>(); @@ -699,9 +698,9 @@ struct CollectiveMma< ElementB const* ptr_B = nullptr; Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); - cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, prob_shape_A, prob_stride_A); - cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, prob_shape_B, prob_stride_B); // Convert strides to byte strides diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index 6cec1862..67c82688 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -43,7 +43,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" #include "cutlass/detail/blockwise_scale_layout.hpp" @@ -168,11 +167,11 @@ struct CollectiveMma< make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); - // Block scaling gmem-to-smem copy atom + // Block scaling gmem-to-smem copy atom // we can have partial tiles in M or N, so don't vectorize those loads using CopyAtomSFA = Copy_Atom, ElementBlockScale>; using CopyAtomSFB = Copy_Atom, ElementBlockScale>; - + static constexpr int AlignmentSFA = 1; static constexpr int AlignmentSFB = 1; @@ -265,7 +264,7 @@ struct CollectiveMma< InternalElementB const** ptr_B; StrideB dB; // Block scaling factors for A and B - ElementBlockScale const** ptr_SFA; + ElementBlockScale const** ptr_SFA; LayoutSFA layout_SFA; ElementBlockScale const** ptr_SFB; LayoutSFB layout_SFB; @@ -423,9 +422,9 @@ struct CollectiveMma< // Make the tiled views of scale tensors - Tensor mSFA_mkl = make_tensor(make_gmem_ptr(ptr_SFA), + Tensor mSFA_mkl = make_tensor(make_gmem_ptr(ptr_SFA), ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, init_L))); // (scale_m,k,l) - Tensor mSFB_nkl = make_tensor(make_gmem_ptr(ptr_SFB), + Tensor mSFB_nkl = make_tensor(make_gmem_ptr(ptr_SFB), ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, init_L))); // (scale_n,k,l) return cute::make_tuple(gA_mkl, gB_nkl, mSFA_mkl, mSFB_nkl); @@ -443,7 +442,7 @@ struct CollectiveMma< CUTLASS_DEVICE void load( Params const& mainloop_params, - MainloopPipeline pipeline, + MainloopPipeline pipeline, PipelineState smem_pipe_write, cute::tuple const& load_inputs, cute::tuple const& input_tensormaps, @@ -457,9 +456,9 @@ struct CollectiveMma< if (lane_predicate) { Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), + Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), SmemLayoutSFA{}); // (BLK_M,BLK_K,P) - Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), + Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), SmemLayoutSFB{}); // (BLK_N,BLK_K,P) // @@ -561,9 +560,9 @@ struct CollectiveMma< // Issue the epilogue waits if (lane_predicate) { // This helps avoid early exit of blocks in Cluster. - // Waits for all stages to either be released (all + // Waits for all stages to either be released (all // Consumer UNLOCKs), or if the stage was never used - // then it would just be acquired since the phase was + // then it would just be acquired since the phase was // still inverted from make_producer_start_state. pipeline.producer_tail(smem_pipe_write); } @@ -579,11 +578,11 @@ struct CollectiveMma< CUTLASS_DEVICE void load_auxiliary( Params const& mainloop_params, - MainloopPipeline pipeline, + MainloopPipeline pipeline, PipelineState smem_pipe_write, - cute::tuple const& load_inputs, BlockCoord const& blk_coord, KTileIterator k_tile_iter, int k_tile_count, @@ -591,9 +590,9 @@ struct CollectiveMma< uint32_t block_rank_in_cluster, TensorStorage& shared_tensors) { int lane_predicate = cute::elect_one_sync(); - Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), + Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), SmemLayoutSFA{}); // (BLK_M,BLK_K,P) - Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), + Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), SmemLayoutSFB{}); // (BLK_N,BLK_K,P) // Partition the inputs based on the current block coordinates. @@ -741,22 +740,22 @@ struct CollectiveMma< // Block scaling Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), make_layout( - make_shape(shape<0>(SmemLayoutSFA{}), - get<1>(TileShape{}), - make_shape(shape<1>(SmemLayoutSFA{}), + make_shape(shape<0>(SmemLayoutSFA{}), + get<1>(TileShape{}), + make_shape(shape<1>(SmemLayoutSFA{}), shape<2>(SmemLayoutSFA{}))), - make_stride(stride<0>(SmemLayoutSFA{}), _0{}, - make_stride(stride<1>(SmemLayoutSFA{}), + make_stride(stride<0>(SmemLayoutSFA{}), _0{}, + make_stride(stride<1>(SmemLayoutSFA{}), stride<2>(SmemLayoutSFA{}))) )); // (BLK_M,BLK_N,(BLK_K,P)) Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), make_layout( - make_shape(get<0>(TileShape{}), - shape<0>(SmemLayoutSFB{}), - make_shape(shape<1>(SmemLayoutSFB{}), + make_shape(get<0>(TileShape{}), + shape<0>(SmemLayoutSFB{}), + make_shape(shape<1>(SmemLayoutSFB{}), shape<2>(SmemLayoutSFB{}))), - make_stride(_0{}, - stride<0>(SmemLayoutSFB{}), - make_stride(stride<1>(SmemLayoutSFB{}), + make_stride(_0{}, + stride<0>(SmemLayoutSFB{}), + make_stride(stride<1>(SmemLayoutSFB{}), stride<2>(SmemLayoutSFB{}))) )); // (BLK_M,BLK_N,(BLK_K,P)) @@ -767,10 +766,10 @@ struct CollectiveMma< // Layout of warp group to thread mapping - static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and stride<0>(typename TiledMma::BLayout{}) == 0 and size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; @@ -1139,9 +1138,9 @@ struct CollectiveMma< InternalElementB const* ptr_B = nullptr; Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); - cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, prob_shape_A, prob_stride_A); - cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, prob_shape_B, prob_stride_B); // Convert strides to byte strides diff --git a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp index fd13c409..4289bc81 100644 --- a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp @@ -42,7 +42,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -162,7 +161,7 @@ struct CollectiveMma< static constexpr bool TransposeB = !IsInputSizeTwoBytes && IsLayoutAmnBmn; using TransposeOperandB = decltype(cutlass::transform::collective::detail::make_transpose_operand_b( 0, 0, TiledMma{}, SmemLayoutB{}, InternalSmemLayoutAtomB{}, - InternalElementB{}, cute::bool_constant{})); + InternalElementB{}, cute::bool_constant{})); static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); static_assert(not cute::is_base_of::value && @@ -187,7 +186,7 @@ struct CollectiveMma< struct SharedStorage { - struct TensorStorage : cute::aligned_struct<256, _0> { + struct TensorStorage : cute::aligned_struct<256, _0> { cute::array_aligned, 256> smem_A; cute::array_aligned, 256> smem_B; } tensors; @@ -264,7 +263,7 @@ struct CollectiveMma< static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; static constexpr int K_PIPE_MMAS = 1; - + /// Perform a collective-scoped matrix multiply-accumulate /// Producer Perspective template < @@ -275,7 +274,7 @@ struct CollectiveMma< > CUTLASS_DEVICE void load( - MainloopPipeline pipeline, + MainloopPipeline pipeline, PipelineState smem_pipe_write, TensorA const& gA_in, TensorB const& gB_in, @@ -358,7 +357,7 @@ struct CollectiveMma< clear(tBsB(_,_,k,write_stage)); } } - + ++k_tile_iter; --k_tile_count; @@ -392,13 +391,13 @@ struct CollectiveMma< /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster CUTLASS_DEVICE void load_tail( - MainloopPipeline pipeline, + MainloopPipeline pipeline, PipelineState smem_pipe_write) { // Issue the epilogue waits /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all + * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was + * then would just be acquired since the phase was * still inverted from make_producer_start_state */ pipeline.producer_tail(smem_pipe_write); @@ -432,7 +431,7 @@ struct CollectiveMma< // Obtain warp index int warp_idx = canonical_warp_idx_sync(); [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; - + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) @@ -448,11 +447,11 @@ struct CollectiveMma< // Layout of warp group to thread mapping static_assert(stride<0>(typename TiledMma::BLayout{}) == 0 and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(Int{}, + Layout warp_group_thread_layout = make_layout(Int{}, Int{}); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); @@ -499,8 +498,8 @@ struct CollectiveMma< tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; TransposeOperandB transpose = cutlass::transform::collective::detail::make_transpose_operand_b( - warp_idx, warp_group_thread_idx, tiled_mma, SmemLayoutB{}, - InternalSmemLayoutAtomB{}, InternalElementB{}, + warp_idx, warp_group_thread_idx, tiled_mma, SmemLayoutB{}, + InternalSmemLayoutAtomB{}, InternalElementB{}, cute::bool_constant{}); warpgroup_fence_operand(accum); @@ -535,8 +534,8 @@ struct CollectiveMma< } warpgroup_wait<2>(); - - + + if (k_tile_count - 1 > 0) { if (!skip_wait) { pipeline.consumer_wait(smem_pipe_read); @@ -589,7 +588,7 @@ struct CollectiveMma< transpose(sB, gmma_sB, read_stage, 1); } } - + warpgroup_arrive(); // (V,M) x (V,N) => (V,M,N) cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); @@ -638,7 +637,7 @@ struct CollectiveMma< ++smem_pipe_release; } } - + warpgroup_arrive(); // (V,M) x (V,N) => (V,M,N) cute::gemm(tiled_mma, tCrA(_,_,size<2>(tCrA) - 1), tCrB(_,_,size<2>(tCrA) - 1,read_stage), accum); @@ -659,7 +658,7 @@ struct CollectiveMma< k_tile_count -= prologue_mma_count; smem_pipe_release.advance(k_tile_count); - + // Wait on all GMMAs to complete warpgroup_wait<0>(); diff --git a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp index fd02d043..fbbe971c 100644 --- a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp @@ -38,7 +38,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" #include "cutlass/pipeline/pipeline.hpp" #include "cutlass/trace.h" @@ -191,7 +190,7 @@ struct CollectiveMma< static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; static constexpr int K_PIPE_MMAS = 1; - + /// Perform a collective-scoped matrix multiply-accumulate /// Producer Perspective template < @@ -202,7 +201,7 @@ struct CollectiveMma< > CUTLASS_DEVICE void load( - MainloopPipeline pipeline, + MainloopPipeline pipeline, PipelineState smem_pipe_write, TensorA const& gA_in, TensorB const& gB_in, @@ -318,13 +317,13 @@ struct CollectiveMma< /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster CUTLASS_DEVICE void load_tail( - MainloopPipeline pipeline, + MainloopPipeline pipeline, PipelineState smem_pipe_write) { // Issue the epilogue waits /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all + * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was + * then would just be acquired since the phase was * still inverted from make_producer_start_state */ pipeline.producer_tail(smem_pipe_write); @@ -363,14 +362,14 @@ struct CollectiveMma< // Layout of warp group to thread mapping - static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and stride<0>(typename TiledMma::BLayout{}) == 0 and size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(Int{}, + Layout warp_group_thread_layout = make_layout(Int{}, Int{}); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); @@ -461,7 +460,7 @@ struct CollectiveMma< pipeline.consumer_wait(smem_pipe_read, barrier_token); int read_stage = smem_pipe_read.index(); - + warpgroup_fence_operand(accum); warpgroup_arrive(); // (V,M,K) x (V,N,K) => (V,M,N) @@ -491,7 +490,7 @@ struct CollectiveMma< k_tile_count -= prologue_mma_count; smem_pipe_release.advance(k_tile_count); - + // Wait on all GMMAs to complete warpgroup_wait<0>(); diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp index 898fe828..f8e05437 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp @@ -45,7 +45,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -170,7 +169,7 @@ struct CollectiveMma< static constexpr bool TransposeB = !IsInputSizeTwoBytes && IsLayoutAmnBmn; using TransposeOperandB = decltype(cutlass::transform::collective::detail::make_transpose_operand_b( 0, 0, TiledMma{}, SmemLayoutB{}, InternalSmemLayoutAtomB{}, - InternalElementB{}, cute::bool_constant{})); + InternalElementB{}, cute::bool_constant{})); static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); static_assert(not cute::is_base_of::value && @@ -207,8 +206,8 @@ struct CollectiveMma< static_assert(!uses_universal_transposition(), "Warp specialized ARF kernels have not supported universal B transposition yet."); - - static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); @@ -216,7 +215,7 @@ struct CollectiveMma< struct SharedStorage { - struct TensorStorage : cute::aligned_struct { + struct TensorStorage : cute::aligned_struct { cute::array_aligned, SmemAlignmentA> smem_A; cute::array_aligned, SmemAlignmentB> smem_B; } tensors; @@ -330,7 +329,7 @@ struct CollectiveMma< constexpr int tma_alignment_bits = 128; auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M,N,K,L] = problem_shape_MNKL; - + bool implementable = true; constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); @@ -410,7 +409,7 @@ struct CollectiveMma< // // Prepare the TMA loads for A and B // - + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; @@ -483,9 +482,9 @@ struct CollectiveMma< // Issue the epilogue waits if (lane_predicate) { /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all + * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was + * then would just be acquired since the phase was * still inverted from make_producer_start_state */ pipeline.producer_tail(smem_pipe_write); @@ -518,15 +517,15 @@ struct CollectiveMma< // Obtain warp index int warp_idx = canonical_warp_idx_sync(); [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; - + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) - + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_M,BLK_K,PIPE) // If TransposeB, GMMA will read from transposed B layout SMEM - Tensor gmma_sB_position_dependent = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), + Tensor gmma_sB_position_dependent = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), GmmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) Tensor gmma_sB = as_position_independent_swizzle_tensor(gmma_sB_position_dependent); // (BLK_N,BLK_K,PIPE) @@ -537,11 +536,11 @@ struct CollectiveMma< // Layout of warp group to thread mapping static_assert(stride<0>(typename TiledMma::BLayout{}) == 0 and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(Int{}, + Layout warp_group_thread_layout = make_layout(Int{}, Int{}); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); @@ -590,12 +589,12 @@ struct CollectiveMma< tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; TransposeOperandB transpose = cutlass::transform::collective::detail::make_transpose_operand_b( - warp_idx, warp_group_thread_idx, tiled_mma, SmemLayoutB{}, - InternalSmemLayoutAtomB{}, InternalElementB{}, + warp_idx, warp_group_thread_idx, tiled_mma, SmemLayoutB{}, + InternalSmemLayoutAtomB{}, InternalElementB{}, cute::bool_constant{}); warpgroup_fence_operand(accum); - + ConsumerToken barrier_token = {BarrierStatus::WaitAgain}; // first k tile { @@ -611,7 +610,7 @@ struct CollectiveMma< copy(smem_tiled_copy_A, tCsA_copy_view(_,_,0,read_stage), tCrA_copy_view(_,_,0)); // transpose B operand in SMEM transpose(sB, gmma_sB, read_stage, 0); - + // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA) - 1; ++k_block) { @@ -628,7 +627,7 @@ struct CollectiveMma< } warpgroup_wait<2>(); - + warpgroup_arrive(); // (V,M) x (V,N) => (V,M,N) cute::gemm(tiled_mma, tCrA(_,_,size<2>(tCrA) - 1), tCrB(_,_,size<2>(tCrA) - 1,read_stage), accum); @@ -667,14 +666,14 @@ struct CollectiveMma< copy(smem_tiled_copy_A, tCsA_copy_view(_,_,0,smem_pipe_read.index()), tCrA_copy_view(_,_,0)); // transpose B operand in SMEM transpose(sB, gmma_sB, smem_pipe_read.index(), 0); - } + } else { copy(smem_tiled_copy_A, tCsA_copy_view(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); // transpose B operand in SMEM transpose.synchronize(k_block); // make transpose of k_block available transpose(sB, gmma_sB, read_stage, k_block + 1); } - + warpgroup_arrive(); // (V,M) x (V,N) => (V,M,N) cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); @@ -700,7 +699,7 @@ struct CollectiveMma< int read_stage = smem_pipe_read.index(); warpgroup_fence_operand(accum); - + // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA) - 1; ++k_block) { @@ -719,7 +718,7 @@ struct CollectiveMma< ++smem_pipe_release; } } - + warpgroup_arrive(); // (V,M) x (V,N) => (V,M,N) cute::gemm(tiled_mma, tCrA(_,_,size<2>(tCrA) - 1), tCrB(_,_,size<2>(tCrA) - 1,read_stage), accum); @@ -728,7 +727,7 @@ struct CollectiveMma< warpgroup_fence_operand(accum); } - + /// Perform a Consumer Epilogue to release all buffers CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { @@ -737,7 +736,7 @@ struct CollectiveMma< k_tile_count -= prologue_mma_count; smem_pipe_release.advance(k_tile_count); - + // Wait on all GMMAs to complete warpgroup_wait<0>(); diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp index a3f35c5c..4e435299 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -50,7 +50,6 @@ #include "cute/atom/mma_atom.hpp" #include "cute/atom/copy_traits_sm90_tma.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -102,7 +101,7 @@ public: ConvertAndScale, ConvertAndScaleWithZero }; - + // // Type Aliases // @@ -112,10 +111,10 @@ public: private: template friend struct detail::MixedInputUtils; - using CollectiveType = CollectiveMma; public: - static_assert(cute::is_tuple::value ^ cute::is_tuple::value, + static_assert(cute::is_tuple::value ^ cute::is_tuple::value, "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale]," "[ElementZero]}. Inputs in [] are optional."); @@ -149,17 +148,17 @@ public: using NonVoidStrideScale = cute::conditional_t< cute::is_void_v, cute::Stride<_1, int64_t, int64_t>, StrideScale>; - static_assert(( IsATransformed && (cutlass::gemm::detail::is_k_major() || is_layout::value)) || + static_assert(( IsATransformed && (cutlass::gemm::detail::is_k_major() || is_layout::value)) || (!IsATransformed && (cutlass::gemm::detail::is_k_major() || is_layout::value)), "The transformed type must be K-major."); static_assert(( IsATransformed && (sizeof(ElementB) == 2)) || (!IsATransformed && (sizeof(ElementA) == 2)) || - ((cutlass::gemm::detail::is_k_major() || is_layout::value) && - (cutlass::gemm::detail::is_k_major() || is_layout::value)), + ((cutlass::gemm::detail::is_k_major() || is_layout::value) && + (cutlass::gemm::detail::is_k_major() || is_layout::value)), "The unscaled element must be 2 bytes OR both inputs must be K-major"); - static_assert(cutlass::gemm::detail::is_mn_major(), + static_assert(cutlass::gemm::detail::is_mn_major(), "Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled]."); using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); @@ -185,7 +184,7 @@ public: using SwappedSmemLayoutAtomB = cute::conditional_t; using SwappedSmemCopyAtomA = cute::conditional_t; using SwappedSmemCopyAtomB = cute::conditional_t; - + // TMA converts f32 input to tf32 when copying from GMEM to SMEM // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. static constexpr bool ConvertF32toTF32A = cute::is_same_v; @@ -238,10 +237,10 @@ public: using SmemLayoutA = decltype(detail::get_smem_layout(SwappedSmemLayoutAtomA{}, select<0,2>(TileShape{}), SwappedStrideA{})); using SmemLayoutB = decltype(detail::get_smem_layout(SwappedSmemLayoutAtomB{}, select<1,2>(TileShape{}), SwappedStrideB{})); - + // It is assumed that the scales and zero-points share the same smem layout using SmemLayoutScale = decltype(tile_to_shape( - SmemLayoutAtomScale{}, + SmemLayoutAtomScale{}, make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}), Int{}), cute::conditional_t< ::cutlass::gemm::detail::is_major<0,NonVoidStrideScale>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); @@ -260,11 +259,11 @@ public: static_assert(size<1>(SmemLayoutAtomScale{}) == 1, "size<1>(SmemLayoutAtomScale) must be 1."); private: - static constexpr ConversionMode + static constexpr ConversionMode get_conversion_mode() { if constexpr (cute::is_void_v) { return ConversionMode::DirectConvert; - } + } else if constexpr (cute::is_void_v) { return ConversionMode::ConvertAndScale; } @@ -279,7 +278,7 @@ public: KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; static constexpr bool UseScaleLookupTable = KernelConversionMode == ConversionMode::ConvertAndScale && cutlass::detail::is_Array_v; - static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); @@ -424,7 +423,7 @@ public: uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK; if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, tma_transaction_bytes, 1, dA, dB }; - } + } else if constexpr (ModeHasScales) { auto scale_k = ceil_div(K, args.group_size); ElementScale const* ptr_S = args.ptr_S; @@ -452,7 +451,7 @@ public: } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); } - } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); } @@ -480,7 +479,7 @@ public: if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { check_mode_args = check_mode_args && (args.ptr_S == nullptr); check_mode_args = check_mode_args && (args.ptr_Z == nullptr); - } + } else if constexpr (ModeHasScales) { const int scale_mn = SwapAB ? N : M; const int scale_k = ceil_div(K, args.group_size); @@ -497,7 +496,7 @@ public: constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; check_aligned_Z = cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), args.dS); check_mode_args = check_mode_args && (args.ptr_Z != nullptr); - } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); } @@ -539,18 +538,18 @@ public: if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { // Nothing extra to do - } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); cute::prefetch_tma_descriptor(mainloop_params.tma_load_zero.get_tma_descriptor()); - } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA prefetch."); } - + } /// Set up the data needed by this collective for load and mma. @@ -577,7 +576,7 @@ public: if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { return cute::make_tuple(gA_mkl, gB_nkl); - } + } else if constexpr (ModeHasScales) { auto scale_k = mainloop_params.scale_k; Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l) @@ -593,11 +592,11 @@ public: else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); } - } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); } - } + } /// Perform a collective-scoped matrix multiply-accumulate /// Producer Perspective @@ -609,7 +608,7 @@ public: CUTLASS_DEVICE void load( Params const& mainloop_params, - MainloopPipeline pipeline, + MainloopPipeline pipeline, PipelineState smem_pipe_write, cute::tuple const& load_inputs, BlockCoord const& blk_coord, @@ -619,13 +618,13 @@ public: TensorStorage& shared_tensors) { if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs"); - } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs"); - } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { static_assert(sizeof... (Ts) == 4, "Scaled and zero convert needs four inputs"); - } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); } @@ -638,7 +637,7 @@ public: // // Prepare the TMA loads for A, B and Scales // - + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; @@ -717,7 +716,7 @@ public: if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { // Nothing extra to do - } + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { auto tZgZ = get<2>(extra_input_partitions); auto tZsZ = get<3>(extra_input_partitions); @@ -725,8 +724,8 @@ public: } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); - } - } + } + } else { static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); } @@ -744,9 +743,9 @@ public: // Issue the epilogue waits if (cute::elect_one_sync()) { /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all + * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was + * then would just be acquired since the phase was * still inverted from make_producer_start_state */ pipeline.producer_tail(smem_pipe_write); @@ -779,10 +778,10 @@ public: // Obtain warp index int warp_idx = canonical_warp_idx_sync(); [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; - + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) - + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) // @@ -792,11 +791,11 @@ public: // Layout of warp group to thread mapping static_assert(stride<0>(typename TiledMma::BLayout{}) == 0 and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(Int{}, + Layout warp_group_thread_layout = make_layout(Int{}, Int{}); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); @@ -819,7 +818,7 @@ public: return make_tensor_like(tCsA(_,_,_,Int<0>{})); } }(); - + Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) @@ -871,14 +870,14 @@ public: barrier_token = pipeline.consumer_try_wait(smem_pipe_read); // copy smem->rmem for A operand - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 0, read_stage); if (K_BLOCK_MAX > 1) { // prefetch next block - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 1, read_stage); } Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); - + // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { @@ -889,25 +888,25 @@ public: warpgroup_commit_batch(); if (k_block < K_BLOCK_MAX - 2) { // prefetch next block - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); } if (k_block < K_BLOCK_MAX - 1) { Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); } - } + } --k_tile_count; if (k_tile_count > 0) { // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to overwrite the A registers for the first mma. pipeline.consumer_wait(smem_pipe_read, barrier_token); - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); if (K_BLOCK_MAX > 1) { // prefetch next block - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); } - warpgroup_wait(); + warpgroup_wait(); Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); } } @@ -932,7 +931,7 @@ public: // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { - + warpgroup_arrive(); // (V,M) x (V,N) => (V,M,N) cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); @@ -949,19 +948,19 @@ public: barrier_token = pipeline.consumer_try_wait(smem_pipe_read); } - if (k_block == K_BLOCK_MAX - 1) { + if (k_block == K_BLOCK_MAX - 1) { pipeline.consumer_wait(smem_pipe_read, barrier_token); - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); if (K_BLOCK_MAX > 1) { // prefetch next block - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); } Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); - } + } else { if (k_block < K_BLOCK_MAX - 2) { // prefetch next block - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); } Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); @@ -981,7 +980,7 @@ public: int read_stage = smem_pipe_read.index(); warpgroup_fence_operand(accum); - + // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { @@ -999,7 +998,7 @@ public: } if (k_block < K_BLOCK_MAX - 2) { // prefetch next block - Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); } if (k_block < K_BLOCK_MAX - 1) { @@ -1019,7 +1018,7 @@ public: k_tile_count -= prologue_mma_count; smem_pipe_release.advance(k_tile_count); - + // Wait on all GMMAs to complete warpgroup_wait<0>(); diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp index 324b397b..228c2589 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp @@ -41,7 +41,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -230,7 +229,7 @@ struct CollectiveMma< constexpr int tma_alignment_bits = 128; auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M,N,K,L] = problem_shape_MNKL; - + bool implementable = true; constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); @@ -401,14 +400,14 @@ struct CollectiveMma< // Layout of warp group to thread mapping - static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and stride<0>(typename TiledMma::BLayout{}) == 0 and size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(Int{}, + Layout warp_group_thread_layout = make_layout(Int{}, Int{}); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); @@ -457,7 +456,7 @@ struct CollectiveMma< } CUTLASS_PRAGMA_UNROLL - for (int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count) - 1; + for (int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count) - 1; prologue_mma_count > 0; --prologue_mma_count) { // WAIT on smem_pipe_read until it's data is available diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp index dbde656e..0e64bad5 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp @@ -41,7 +41,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -115,7 +114,7 @@ struct CollectiveMma< using PipelineParams = typename MainloopPipeline::Params; // One threads per CTA are producers (1 for operand tile) - static constexpr int NumProducerThreadEvents = 1; + static constexpr int NumProducerThreadEvents = 1; static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); @@ -248,7 +247,7 @@ struct CollectiveMma< constexpr int tma_alignment_bits = 128; auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M,N,K,L] = problem_shape_MNKL; - + bool implementable = true; constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); @@ -400,9 +399,9 @@ struct CollectiveMma< // Issue the epilogue waits if (lane_predicate) { /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all + * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was + * then would just be acquired since the phase was * still inverted from make_producer_start_state */ pipeline.producer_tail(smem_pipe_write); @@ -439,14 +438,14 @@ struct CollectiveMma< // Layout of warp group to thread mapping - static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and stride<0>(typename TiledMma::BLayout{}) == 0 and size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(Int{}, + Layout warp_group_thread_layout = make_layout(Int{}, Int{}); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); @@ -567,7 +566,7 @@ struct CollectiveMma< k_tile_count -= prologue_mma_count; smem_pipe_release.advance(k_tile_count); - + // Wait on all GMMAs to complete warpgroup_wait<0>(); diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp index e29bd60d..c7ea65a6 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp @@ -42,7 +42,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/tensor.hpp" #include "cute/numeric/arithmetic_tuple.hpp" @@ -244,7 +243,7 @@ struct CollectiveMma< constexpr int tma_alignment_bits = 128; auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M,N,K,L] = problem_shape_MNKL; - + bool implementable = true; constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); @@ -437,17 +436,17 @@ struct CollectiveMma< // // Define C accumulators and A/B partitioning // - + // Layout of warp group to thread mapping - static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and stride<0>(typename TiledMma::BLayout{}) == 0 and size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(Int{}, + Layout warp_group_thread_layout = make_layout(Int{}, Int{}); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index 19009d5d..ecbd59b5 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -42,7 +42,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" #include "cutlass/detail/blockwise_scale_layout.hpp" @@ -166,7 +165,7 @@ struct CollectiveMma< make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); - // Block scaling gmem-to-smem copy atom + // Block scaling gmem-to-smem copy atom // we can have partial tiles in M or N, so don't vectorize those loads using CopyAtomSFA = Copy_Atom, ElementBlockScale>; using CopyAtomSFB = Copy_Atom, ElementBlockScale>; @@ -217,7 +216,7 @@ struct CollectiveMma< StrideA dA; ElementB const* ptr_B; StrideB dB; - ElementBlockScale const* ptr_SFA; + ElementBlockScale const* ptr_SFA; LayoutSFA layout_SFA; ElementBlockScale const* ptr_SFB; LayoutSFB layout_SFB; @@ -607,7 +606,7 @@ struct CollectiveMma< CUTLASS_DEVICE void load_auxiliary( Params const& mainloop_params, - MainloopPipeline pipeline, + MainloopPipeline pipeline, PipelineState smem_pipe_write, cute::tuple const& load_inputs, BlockCoord const& blk_coord, @@ -639,7 +638,7 @@ struct CollectiveMma< TiledCopy scale_copy_a = make_tiled_copy(CopyAtomSFA{}, Layout>{}, Layout>{}); - TiledCopy scale_copy_b = make_tiled_copy(CopyAtomSFB{}, + TiledCopy scale_copy_b = make_tiled_copy(CopyAtomSFB{}, Layout>{}, Layout>{}); ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(thread_idx); ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(thread_idx); @@ -778,21 +777,21 @@ struct CollectiveMma< // Block scaling Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), make_layout( - make_shape(get<0>(shape(SmemLayoutSFA{})), - get<1>(TileShape{}), - make_shape(get<1>(shape(SmemLayoutSFA{})), + make_shape(get<0>(shape(SmemLayoutSFA{})), + get<1>(TileShape{}), + make_shape(get<1>(shape(SmemLayoutSFA{})), get<2>(shape(SmemLayoutSFA{})))), - make_stride(get<0>(stride(SmemLayoutSFA{})), _0{}, + make_stride(get<0>(stride(SmemLayoutSFA{})), _0{}, make_stride(get<1>(stride(SmemLayoutSFA{})), get<2>(stride(SmemLayoutSFA{})))) )); // (BLK_M,BLK_N,(BLK_K,P)) Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), make_layout( - make_shape(get<0>(TileShape{}), - get<0>(shape(SmemLayoutSFB{})), - make_shape(get<1>(shape(SmemLayoutSFB{})), + make_shape(get<0>(TileShape{}), + get<0>(shape(SmemLayoutSFB{})), + make_shape(get<1>(shape(SmemLayoutSFB{})), get<2>(shape(SmemLayoutSFB{})))), - make_stride(_0{}, - get<0>(stride(SmemLayoutSFB{})), - make_stride(get<1>(stride(SmemLayoutSFB{})), + make_stride(_0{}, + get<0>(stride(SmemLayoutSFB{})), + make_stride(get<1>(stride(SmemLayoutSFB{})), get<2>(stride(SmemLayoutSFB{})))) )); // (BLK_M,BLK_N,(BLK_K,P)) @@ -802,14 +801,14 @@ struct CollectiveMma< // Layout of warp group to thread mapping - static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and stride<0>(typename TiledMma::BLayout{}) == 0 and size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(Int{}, + Layout warp_group_thread_layout = make_layout(Int{}, Int{}); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); diff --git a/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp index d4f58c66..220e996a 100644 --- a/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp @@ -42,7 +42,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -194,7 +193,7 @@ struct CollectiveMma< cute::conditional_t, cutlass::tfloat32_t, uint_bit_t>>>; - using TmaInternalElementB = cute::conditional_t, + using TmaInternalElementB = cute::conditional_t, tfloat32_t, uint_bit_t>>; @@ -215,7 +214,7 @@ struct CollectiveMma< static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; static constexpr int K_PIPE_MMAS = 0; - static constexpr uint32_t TmaTransactionBytesMK = + static constexpr uint32_t TmaTransactionBytesMK = cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutA{})) * cute::sizeof_bits_v) + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutE{})) * cute::sizeof_bits_v); @@ -332,7 +331,7 @@ struct CollectiveMma< constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M,N,K,L] = problem_shape_MNKL; - + bool size_check = true; // Check Alignment A if constexpr (is_A_mn_major) { @@ -405,7 +404,7 @@ struct CollectiveMma< CUTLASS_DEVICE void load( Params const& mainloop_params, - MainloopPipeline pipeline, + MainloopPipeline pipeline, PipelineState smem_pipe_write, cute::tuple const& load_inputs, BlockCoord const& blk_coord, @@ -485,9 +484,9 @@ struct CollectiveMma< // Issue the epilogue waits if (lane_predicate) { /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all + * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was + * then would just be acquired since the phase was * still inverted from make_producer_start_state */ pipeline.producer_tail(smem_pipe_write); @@ -523,14 +522,14 @@ struct CollectiveMma< // Layout of warp group to thread mapping - static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and stride<0>(typename TiledMma::BLayout{}) == 0 and size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(Int{}, + Layout warp_group_thread_layout = make_layout(Int{}, Int{}); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); @@ -650,7 +649,7 @@ struct CollectiveMma< k_tile_count -= prologue_mma_count; smem_pipe_release.advance(k_tile_count); - + // Wait on all GMMAs to complete warpgroup_wait<0>(); diff --git a/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp b/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp index 0a57d31e..d993d9a1 100644 --- a/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp +++ b/include/cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp @@ -43,7 +43,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -195,7 +194,7 @@ struct CollectiveMma< cute::conditional_t, cutlass::tfloat32_t, uint_bit_t>>>; - using TmaInternalElementB = cute::conditional_t, + using TmaInternalElementB = cute::conditional_t, tfloat32_t, uint_bit_t>>; @@ -216,7 +215,7 @@ struct CollectiveMma< static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; static constexpr int K_PIPE_MMAS = 0; - static constexpr uint32_t TmaTransactionBytesMK = + static constexpr uint32_t TmaTransactionBytesMK = cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutA{})) * cute::sizeof_bits_v) + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutE{})) * cute::sizeof_bits_v); @@ -336,7 +335,7 @@ struct CollectiveMma< constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M,N,K,L] = problem_shape_MNKL; - + bool size_check = true; // Check Alignment A if constexpr (is_A_mn_major) { @@ -416,7 +415,7 @@ struct CollectiveMma< CUTLASS_DEVICE void load( Params const& mainloop_params, - MainloopPipeline pipeline, + MainloopPipeline pipeline, PipelineState smem_pipe_write, cute::tuple const& load_inputs, BlockCoord const& blk_coord, @@ -496,9 +495,9 @@ struct CollectiveMma< // Issue the epilogue waits if (lane_predicate) { /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all + * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was + * then would just be acquired since the phase was * still inverted from make_producer_start_state */ pipeline.producer_tail(smem_pipe_write); @@ -534,14 +533,14 @@ struct CollectiveMma< // Layout of warp group to thread mapping - static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and stride<0>(typename TiledMma::BLayout{}) == 0 and size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and - size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; - Layout warp_group_thread_layout = make_layout(Int{}, + Layout warp_group_thread_layout = make_layout(Int{}, Int{}); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); @@ -676,7 +675,7 @@ struct CollectiveMma< k_tile_count -= prologue_mma_count; smem_pipe_release.advance(k_tile_count); - + // Wait on all GMMAs to complete warpgroup_wait<0>(); diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 712fc1ba..aa19fbc2 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -511,61 +511,82 @@ struct KernelPtrArrayTmaWarpSpecializedInputTransformSm100 final { // SM120 kernel schedules -template< int SchedulerPipelineStageCount_> +template struct KernelTmaWarpSpecializedCooperativeSm120 : KernelTmaWarpSpecializedCooperative { static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; }; -template< int SchedulerPipelineStageCount_> +template struct KernelTmaWarpSpecializedPingpongSm120 : KernelTmaWarpSpecializedPingpong { static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; }; -template< int SchedulerPipelineStageCount_> +template struct KernelTmaWarpSpecializedCooperativeBlockScaledSm120 : KernelTmaWarpSpecializedCooperative { static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; }; -template< int SchedulerPipelineStageCount_> +template struct KernelTmaWarpSpecializedPingpongBlockScaledSm120 : KernelTmaWarpSpecializedPingpong { static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; }; // SM120 dense Ptr-array kernel schedules -template< int SchedulerPipelineStageCount_> +template struct KernelPtrArrayTmaWarpSpecializedCooperativeSm120 : KernelPtrArrayTmaWarpSpecializedCooperative { static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; }; -template< int SchedulerPipelineStageCount_> +template struct KernelPtrArrayTmaWarpSpecializedPingpongSm120 : KernelPtrArrayTmaWarpSpecializedPingpong { static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; }; -template< int SchedulerPipelineStageCount_> +template struct KernelPtrArrayTmaWarpSpecializedCooperativeBlockScaledSm120 : KernelPtrArrayTmaWarpSpecializedCooperative { static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; }; -template< int SchedulerPipelineStageCount_> +template struct KernelPtrArrayTmaWarpSpecializedPingpongBlockScaledSm120 : KernelPtrArrayTmaWarpSpecializedPingpong { static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; }; // SM120 sparse kernel schedules -template< int SchedulerPipelineStageCount_, bool isAsymmetric_> +template struct KernelTmaWarpSpecializedCooperativeSparseSm120 { static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; static constexpr bool isAsymmetric = isAsymmetric_; }; -template< int SchedulerPipelineStageCount_, bool isAsymmetric_> +template struct KernelTmaWarpSpecializedCooperativeSparseBlockScaledSm120 { static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; static constexpr bool isAsymmetric = isAsymmetric_; }; +// SM120 blockwise kernel schedules +template +struct KernelTmaWarpSpecializedCooperativeBlockwiseScalingSm120 : KernelTmaWarpSpecializedCooperative { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; +}; + +template +struct KernelTmaWarpSpecializedPingpongBlockwiseScalingSm120 : KernelTmaWarpSpecializedPingpong { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; +}; + +template +struct KernelPtrArrayTmaWarpSpecializedCooperativeBlockwiseScalingSm120 : KernelPtrArrayTmaWarpSpecializedCooperative { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; +}; + +template +struct KernelPtrArrayTmaWarpSpecializedPingpongBlockwiseScalingSm120 : KernelPtrArrayTmaWarpSpecializedPingpong { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; +}; + // Auxiliary Load Tag. namespace kernel::detail { @@ -776,6 +797,12 @@ struct KernelTmaWarpSpecializedMxf4Sm120 final : KernelScheduleMxNvf struct KernelTmaWarpSpecializedPingpongMxf4Sm120 final : KernelScheduleMxNvf4Sm120, KernelTmaWarpSpecializedPingpong { }; struct KernelTmaWarpSpecializedMxf8f6f4Sm120 final : KernelScheduleMxf8f6f4Sm120, KernelTmaWarpSpecializedCooperative { }; struct KernelTmaWarpSpecializedPingpongMxf8f6f4Sm120 final : KernelScheduleMxf8f6f4Sm120, KernelTmaWarpSpecializedPingpong { }; +// Blockwise Scaled GEMM +struct KernelScheduleSm120Blockwise: KernelScheduleSm120 { }; +struct KernelTmaWarpSpecializedBlockwiseCooperativeSm120 final : KernelScheduleSm120Blockwise, KernelTmaWarpSpecializedCooperative { }; +struct KernelTmaWarpSpecializedBlockwisePingpongSm120 final : KernelScheduleSm120Blockwise, KernelTmaWarpSpecializedPingpong { }; + + /////////////////////////////////////////////////////////////////////////////////////////////////////// // SM120 Sparse GEMM Dispatch Policies /////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1120,6 +1147,43 @@ struct MainloopSm120TmaWarpSpecializedSparseBlockScaled { using Schedule = KernelTmaWarpSpecializedCooperativeSparseBlockScaledSm120; }; +template < + int Stages_, + int SchedulerPipelineStageCount_, + class ClusterShape_, + class KernelSchedule_ +> +struct MainloopSm120TmaWarpSpecializedBlockwiseScaling { + constexpr static int Stages = Stages_; + constexpr static int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + using ClusterShape = ClusterShape_; + using Schedule = KernelSchedule_; + + constexpr static int PipelineAsyncMmaStages = 0; + using ArchTag = arch::Sm120; +}; + +template < + int Stages_, + int SchedulerPipelineStageCount_, + class ClusterShape_, + class KernelSchedule_ +> +struct MainloopSm120ArrayTmaWarpSpecializedBlockwiseScaling { + constexpr static int Stages = Stages_; + constexpr static int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + using ClusterShape = ClusterShape_; + using Schedule = KernelSchedule_; + + constexpr static int PipelineAsyncMmaStages = 0; + using ArchTag = arch::Sm120; + + static_assert(cute::is_base_of_v || + cute::is_base_of_v, + "KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative or Pingpong policies."); +}; + + ////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/group_array_problem_shape.hpp b/include/cutlass/gemm/group_array_problem_shape.hpp index 73bdce50..fe5e4c53 100644 --- a/include/cutlass/gemm/group_array_problem_shape.hpp +++ b/include/cutlass/gemm/group_array_problem_shape.hpp @@ -74,7 +74,7 @@ struct GroupProblemShape { CUTLASS_HOST_DEVICE bool - is_host_problem_shape_available() { + is_host_problem_shape_available() const { return host_problem_shapes != nullptr; } }; @@ -113,7 +113,7 @@ public: CUTLASS_HOST_DEVICE bool - is_host_problem_shape_available() { + is_host_problem_shape_available() const { return true; } private: diff --git a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp index 78401097..738f460f 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp @@ -29,8 +29,6 @@ * **************************************************************************************************/ - - #pragma once #include "cutlass/cutlass.h" @@ -50,6 +48,7 @@ #include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" #include "cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp" #include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" #include "cute/tensor.hpp" #include "cute/arch/tmem_allocator_sm100.hpp" @@ -190,7 +189,6 @@ public: // Kernel level shared memory storage struct SharedStorage { - // Barriers should be allocated in lower 8KB of SMEM for SM100 struct PipelineStorage : cute::aligned_struct<16, _1> { using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; @@ -206,7 +204,6 @@ public: alignas(16) AccumulatorPipelineStorage accumulator; alignas(16) CLCThrottlePipelineStorage clc_throttle; alignas(16) arch::ClusterBarrier tmem_dealloc; - alignas(16) arch::ClusterBarrier epilogue_throttle; } pipelines; alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; @@ -280,7 +277,12 @@ public: ProblemShape problem_shapes = args.problem_shape; // Get SM count if needed, otherwise use user supplied SM count int sm_count = args.hw_info.sm_count; - if (!IsGroupedGemmKernel && sm_count != 0) { + if (IsGroupedGemmKernel && sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + else if (!IsGroupedGemmKernel && sm_count != 0) { CUTLASS_TRACE_HOST(" WARNING: SM100 tile scheduler does not allow for user specified SM counts.\n" " To restrict a kernel's resource usage, consider using CUDA driver APIs instead (green contexts)."); } @@ -487,7 +489,7 @@ public: : WarpCategory::Epilogue; uint32_t lane_predicate = cute::elect_one_sync(); - auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape()); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}); int cluster_size = size(cluster_shape); uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); bool is_first_cta_in_cluster = IsSchedDynamicPersistent ? (cta_rank_in_cluster == 0) : true; @@ -542,7 +544,7 @@ public: epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; - epi_load_pipeline_params.initializing_warp = 4; + epi_load_pipeline_params.initializing_warp = 1; EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); // Epilogue Store pipeline @@ -554,12 +556,11 @@ public: typename LoadOrderBarrier::Params load_order_barrier_params; load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopLoad) ? 0 : 1; load_order_barrier_params.group_size = NumMainloopLoadThreads; - load_order_barrier_params.initializing_warp = 5; + load_order_barrier_params.initializing_warp = 3; LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); // CLC pipeline typename CLCPipeline::Params clc_pipeline_params; - if (WarpCategory::Sched == warp_category) { clc_pipeline_params.role = IsSchedDynamicPersistent ? CLCPipeline::ThreadCategory::ProducerConsumer : @@ -568,8 +569,7 @@ public: else { clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; } - - clc_pipeline_params.initializing_warp = 1; + clc_pipeline_params.initializing_warp = 4; clc_pipeline_params.producer_arv_count = 1; if constexpr (IsSchedDynamicPersistent) { @@ -608,7 +608,7 @@ public: // Only one producer thread arrives on this barrier. accumulator_pipeline_params.producer_arv_count = 1; accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; - accumulator_pipeline_params.initializing_warp = 2; + accumulator_pipeline_params.initializing_warp = 5; AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, accumulator_pipeline_params, cluster_shape, @@ -641,28 +641,20 @@ public: // Sync deallocation status between MMA warps of peer CTAs arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; - if constexpr(!IsOverlappingAccum) { - if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { - tmem_deallocation_result_barrier.init(NumMMAThreads); + if (WarpCategory::MMA == warp_category) { + if constexpr(!IsOverlappingAccum) { + if (has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumMMAThreads); + } } - } - else { - if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { - tmem_deallocation_result_barrier.init(NumEpilogueThreads*2); + else { + if (has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumEpilogueThreads*2); + } + else if (lane_predicate) { + tmem_deallocation_result_barrier.init(NumEpilogueThreads); + } } - else if (WarpCategory::MMA == warp_category && lane_predicate) { - tmem_deallocation_result_barrier.init(NumEpilogueThreads); - } - } - - - // Initialize smem barrier for prologue throttling. Epilogue warps are stalled until the prologue finishes. - arch::ClusterBarrier& epilogue_throttle_barrier = shared_storage.pipelines.epilogue_throttle; - if (WarpCategory::MMA == warp_category && lane_predicate) { - epilogue_throttle_barrier.init( NumMMAThreads + - (is_first_cta_in_cluster ? NumSchedThreads : 0) + - NumMainloopLoadThreads + - (is_epi_load_needed ? NumEpilogueLoadThreads : 0)); } // We need this to guarantee that the Pipeline init is visible @@ -689,22 +681,17 @@ public: // Calculate mask after cluster barrier arrival mainloop_pipeline.init_masks(cluster_shape, block_id_in_cluster); - accumulator_pipeline.init_masks(cluster_shape); + accumulator_pipeline.init_masks(cluster_shape, block_id_in_cluster); // TileID scheduler TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); - auto work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); // // TMEM "Allocation" // - // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. - TiledMma tiled_mma; - auto acc_shape = collective_mainloop.partition_accumulator_shape(); - Tensor accumulators = cutlass::detail::make_sm100_accumulator( - tiled_mma, acc_shape, EpilogueTile{}); - + auto tmem_storage = collective_mainloop.template init_tmem_tensors(EpilogueTile{}); pipeline_init_wait(cluster_size); if constexpr (IsGroupedGemmKernel) { @@ -719,16 +706,17 @@ public: auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(work_tile_info.L_idx), 1); if (is_participant.main_load) { + auto load_inputs = collective_mainloop.load_init( + problem_shape_MNKL, params.mainloop, + shared_storage.tensors.mainloop, + shared_storage.tensormaps.mainloop, + params.hw_info.sm_count, sm_id, work_tile_info.L_idx); + // Ensure that the prefetched kernel does not touch // unflushed global memory prior to this instruction cutlass::arch::wait_on_dependent_grids(); bool do_load_order_arrive = is_epi_load_needed; - auto load_inputs = collective_mainloop.load_init( - problem_shape_MNKL, params.mainloop, - shared_storage.tensors.mainloop, - shared_storage.tensormaps.mainloop, - params.hw_info.sm_count, sm_id, work_tile_info.L_idx); Tensor gA_mkl = get<0>(load_inputs); // Fetch a copy of tensormaps for the CTA from Params auto input_tensormaps = get(load_inputs); @@ -737,9 +725,6 @@ public: // Even the first tile for a CTA can be from any of the batches. // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. bool did_batch_change = true; - - // Signal the epilogue warps to proceed once the prologue is complete - epilogue_throttle_barrier.arrive(); bool requires_clc_query = true; do { @@ -824,9 +809,6 @@ public: } else if (is_participant.sched) { - // Signal the epilogue warps to proceed once the prologue is complete - epilogue_throttle_barrier.arrive(); - // Grouped GEMM uses static tile scheduler if constexpr (IsSchedDynamicPersistent) { // Whether a new CLC query must be performed. @@ -891,18 +873,8 @@ public: __syncwarp(); tmem_allocation_result_barrier.arrive(); uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; - accumulators.data() = tmem_base_ptr; - int tmem_non_accumulator_base = tmem_base_ptr + cutlass::detail::find_tmem_tensor_col_offset(accumulators); - - - auto mma_inputs = collective_mainloop.mma_init( - params.mainloop, - collective_mainloop.slice_accumulator(accumulators, 0), - shared_storage.tensors.mainloop, - tmem_non_accumulator_base /*Start SF TMEM allocation after the accumulator*/); - - // Signal the epilogue warps to proceed once the prologue is complete - epilogue_throttle_barrier.arrive(); + collective_mainloop.set_tmem_offsets(tmem_storage, tmem_base_ptr); + auto mma_inputs = collective_mainloop.mma_init(tmem_storage, shared_storage.tensors.mainloop); do { @@ -917,30 +889,29 @@ public: ++clc_pipe_consumer_state; } - // Wait for tmem accumulator buffer to become empty with a flipped phase - if constexpr (!IsOverlappingAccum) { - if (is_mma_leader_cta) { - accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); - } - } - if constexpr (IsGroupedGemmKernel) { problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(work_tile_info.L_idx), 1); } auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); - int acc_stage = (IsOverlappingAccum) ? (accumulator_pipe_producer_state.phase() ^ 1) : (accumulator_pipe_producer_state.index()); - auto accumulator = collective_mainloop.slice_accumulator(accumulators, acc_stage); + // Accumulator stage slice + int acc_stage = [&] () { + if constexpr (IsOverlappingAccum) { + return accumulator_pipe_producer_state.phase() ^ 1; + } + else { + return accumulator_pipe_producer_state.index(); + } + }(); + auto accumulator = collective_mainloop.slice_accumulator(tmem_storage, acc_stage); if (is_mma_leader_cta) { mainloop_pipe_consumer_state = collective_mainloop.mma( - cute::make_tuple( - mainloop_pipeline, accumulator_pipeline), - cute::make_tuple( - mainloop_pipe_consumer_state, accumulator_pipe_producer_state), + cute::make_tuple(mainloop_pipeline, accumulator_pipeline), + cute::make_tuple(mainloop_pipe_consumer_state, accumulator_pipe_producer_state), accumulator, mma_inputs, cta_coord_mnkl, k_tile_count - ); + ); accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); } ++accumulator_pipe_producer_state; @@ -969,7 +940,6 @@ public: tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, is_mma_leader_cta); } - } else { tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); @@ -997,9 +967,6 @@ public: bool did_batch_change = true; constexpr bool IsEpiLoad = true; - // Signal the epilogue warps to proceed once the prologue is complete - epilogue_throttle_barrier.arrive(); - do { int32_t curr_batch = work_tile_info.L_idx; if (did_batch_change) { @@ -1069,14 +1036,10 @@ public: } else if (is_participant.epilogue) { - // Throttle the epilogue warps to improve prologue performance - static constexpr int epilogue_throttle_phase_bit = 0; - epilogue_throttle_barrier.wait(epilogue_throttle_phase_bit); - // Wait for tmem allocate here tmem_allocation_result_barrier.arrive_and_wait(); uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; - accumulators.data() = tmem_base_ptr; + collective_mainloop.set_tmem_offsets(tmem_storage, tmem_base_ptr); auto warp_idx_in_epi = canonical_warp_idx_sync() - static_cast(WarpCategory::Epilogue); bool do_tail_store = false; @@ -1110,7 +1073,7 @@ public: ++clc_pipe_consumer_state; } - // Accumulator stage slice after making sure allocation has been performed + // Accumulator stage slice int acc_stage = [&] () { if constexpr (IsOverlappingAccum) { return accumulator_pipe_consumer_state.phase(); @@ -1119,6 +1082,7 @@ public: return accumulator_pipe_consumer_state.index(); } }(); + auto accumulator = collective_mainloop.slice_accumulator(tmem_storage, acc_stage); // Fusions may need problem shape for the current group if constexpr (IsGroupedGemmKernel) { @@ -1139,7 +1103,7 @@ public: cta_coord_mnkl, TileShape{}, TiledMma{}, - collective_mainloop.slice_accumulator(accumulators, acc_stage), + accumulator, shared_storage.tensors.epilogue, cute::make_tuple(epi_store_tensormap, did_batch_change) ); @@ -1175,7 +1139,6 @@ public: } else { - } } }; diff --git a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp index 83eebaf5..2ec1049b 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp @@ -289,7 +289,12 @@ public: ProblemShape problem_shapes = args.problem_shape; // Get SM count if needed, otherwise use user supplied SM count int sm_count = args.hw_info.sm_count; - if (!IsGroupedGemmKernel && sm_count != 0) { + if (IsGroupedGemmKernel && sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + else if (!IsGroupedGemmKernel && sm_count != 0) { CUTLASS_TRACE_HOST(" WARNING: SM100 tile scheduler does not allow for user specified SM counts.\n" " To restrict a kernel's resource usage, consider using CUDA driver APIs instead (green contexts)."); } diff --git a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp index 3989ffe3..85f87af2 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp @@ -172,6 +172,9 @@ public: using CLCPipeline = cutlass::PipelineCLCFetchAsync; using CLCPipelineState = typename CLCPipeline::PipelineState; + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; @@ -183,12 +186,14 @@ public: using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; using CLCPipelineStorage = typename CLCPipeline::SharedStorage; using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; alignas(16) MainloopPipelineStorage mainloop; alignas(16) EpiLoadPipelineStorage epi_load; alignas(16) LoadOrderBarrierStorage load_order; alignas(16) CLCPipelineStorage clc; alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) CLCThrottlePipelineStorage clc_throttle; alignas(16) arch::ClusterBarrier tmem_dealloc; } pipelines; @@ -530,6 +535,22 @@ public: cute::true_type{}, // Perform barrier init cute::false_type{}); // Delay mask calculation + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if (WarpCategory::MainloopLoad == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (WarpCategory::Sched == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + clc_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 3; + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.pipelines.clc_throttle, clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); + // Tmem allocator TmemAllocator tmem_allocator{}; @@ -599,6 +620,7 @@ public: cutlass::arch::wait_on_dependent_grids(); bool do_load_order_arrive = is_epi_load_needed; + bool requires_clc_query = true; do { // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. @@ -606,6 +628,14 @@ public: auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); auto k_tile_prologue = min(MainloopPipeline::Stages, k_tile_count); + if constexpr (IsSchedDynamicPersistent) { + if (is_first_cta_in_cluster && requires_clc_query) { + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + } + } + // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads auto [mainloop_producer_state_next, k_tile_iter_next] = collective_mainloop.load( mainloop_pipeline, @@ -639,6 +669,7 @@ public: ); work_tile_info = next_work_tile_info; cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; if (increment_pipe) { ++clc_pipe_consumer_state; } @@ -658,6 +689,11 @@ public: do { if (requires_clc_query) { + // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + // Query next clcID and update producer state clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); } diff --git a/include/cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp index 0932f5c6..a5f6eb9b 100644 --- a/include/cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp @@ -178,6 +178,9 @@ public: using CLCPipeline = cutlass::PipelineCLCFetchAsync; using CLCPipelineState = typename CLCPipeline::PipelineState; + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; @@ -190,12 +193,14 @@ public: using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; using CLCPipelineStorage = typename CLCPipeline::SharedStorage; using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; alignas(16) MainloopPipelineStorage mainloop; alignas(16) EpiLoadPipelineStorage epi_load; alignas(16) LoadOrderBarrierStorage load_order; alignas(16) CLCPipelineStorage clc; alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) CLCThrottlePipelineStorage clc_throttle; alignas(16) arch::ClusterBarrier tmem_dealloc; } pipelines; @@ -580,6 +585,22 @@ public: cute::true_type{}, // Perform barrier init cute::false_type{}); // Delay mask calculation + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if (WarpCategory::MainloopLoad == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (WarpCategory::Sched == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + clc_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 3; + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.pipelines.clc_throttle, clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); + // Tmem allocator TmemAllocator tmem_allocator{}; @@ -649,12 +670,21 @@ public: cutlass::arch::wait_on_dependent_grids(); bool do_load_order_arrive = is_epi_load_needed; + bool requires_clc_query = true; do { // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}, load_inputs.k_tiles); auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + if constexpr (IsSchedDynamicPersistent) { + if (is_first_cta_in_cluster && requires_clc_query) { + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + } + } + // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads auto [mainloop_producer_state_next, unused_] = collective_mainloop.load( mainloop_pipeline, @@ -678,6 +708,7 @@ public: ); work_tile_info = next_work_tile_info; cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; if (increment_pipe) { ++clc_pipe_consumer_state; } @@ -697,6 +728,11 @@ public: do { if (requires_clc_query) { + // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + // Query next clcID and update producer state clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); } diff --git a/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp b/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp index 1c92efbc..b5538a72 100755 --- a/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp +++ b/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp @@ -59,8 +59,7 @@ class PersistentTileSchedulerSm100Group { public: using UnderlyingScheduler = PersistentTileSchedulerSm90Group; - using UnderlyingProblemShape = typename GroupProblemShape::UnderlyingProblemShape; - using Params = PersistentTileSchedulerSm100GroupParams; + using Params = PersistentTileSchedulerSm100GroupParams; using WorkTileInfo = typename UnderlyingScheduler::WorkTileInfo; using Arguments = typename UnderlyingScheduler::Arguments; using RasterOrder = typename Params::RasterOrder; @@ -94,7 +93,6 @@ public: shape_div(tile_shape_mnk, selected_cluster_shape)); // Static Cluster: Blackwell builders expects TileShape to be Cluster's Tile Shape, Hopper doesn't dim3 problem_blocks = get_tiled_cta_shape_mnl( - problem_shapes.groups(), problem_shapes, hw_info, cta_shape, selected_cluster_shape); @@ -102,9 +100,7 @@ public: Params params; params.initialize( problem_blocks, - problem_shapes.groups(), - problem_shapes.problem_shapes, - problem_shapes.host_problem_shapes, + problem_shapes, to_gemm_coord(cta_shape), to_gemm_coord(selected_cluster_shape), hw_info, @@ -144,8 +140,8 @@ public: template CUTLASS_HOST_DEVICE static dim3 - get_tiled_cta_shape_mnl(int groups, GroupProblemShape problem_shapes, KernelHardwareInfo hw_info, BlockShape cta_shape, ClusterShape cluster_shape) { - return UnderlyingScheduler::get_tiled_cta_shape_mnl(groups, problem_shapes, hw_info, cta_shape, cluster_shape); + get_tiled_cta_shape_mnl(GroupProblemShape const &problem_shapes, KernelHardwareInfo hw_info, BlockShape cta_shape, ClusterShape cluster_shape) { + return UnderlyingScheduler::get_tiled_cta_shape_mnl(problem_shapes, hw_info, cta_shape, cluster_shape); } // Given the inputs, computes the physical grid we should launch. @@ -154,13 +150,12 @@ public: static dim3 get_grid_shape( Params const& params, - GroupProblemShape problem_shapes, + GroupProblemShape const& problem_shapes, BlockShape cta_shape, [[maybe_unused]] AtomThrShape atom_thr_shape, ClusterShape cluster_shape, KernelHardwareInfo hw_info) { dim3 problem_blocks = get_tiled_cta_shape_mnl( - problem_shapes.groups(), problem_shapes, hw_info, cta_shape, diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index dc5610fc..0b12aacd 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -442,7 +442,8 @@ public: // Mainloop Load pipeline using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) { + if (warp_group_role == WarpGroupRole::Producer && (producer_warp_role == ProducerWarpRole::Mainloop || + producer_warp_role == ProducerWarpRole::MainloopAux)) { mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; } if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index a3717853..fc4f5fc1 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -447,7 +447,8 @@ public: // Mainloop Load pipeline using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) { + if (warp_group_role == WarpGroupRole::Producer && (producer_warp_role == ProducerWarpRole::Mainloop + || producer_warp_role == ProducerWarpRole::MainloopAux)) { mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; } if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp index bf02e5fa..92749b19 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp @@ -94,7 +94,7 @@ public: }; using ProblemShape = typename GroupProblemShape::UnderlyingProblemShape; - using Params = PersistentTileSchedulerSm90GroupParams; + using Params = PersistentTileSchedulerSm90GroupParams; using RasterOrder = typename Params::RasterOrder; using RasterOrderOptions = typename Params::RasterOrderOptions; static constexpr bool IsDynamicPersistent = false; @@ -160,7 +160,6 @@ public: static_assert(cute::is_static::value); dim3 problem_blocks = get_tiled_cta_shape_mnl( - problem_shapes.groups(), problem_shapes, hw_info, tile_shape, cluster_shape); @@ -168,9 +167,7 @@ public: Params params; params.initialize( problem_blocks, - problem_shapes.groups(), - problem_shapes.problem_shapes, - problem_shapes.host_problem_shapes, + problem_shapes, to_gemm_coord(tile_shape), to_gemm_coord(cluster_shape), hw_info, @@ -187,7 +184,7 @@ public: dim3 get_grid_shape( [[maybe_unused]] Params const& params, - GroupProblemShape problem_shapes, + GroupProblemShape const& problem_shapes, TileShape tile_shape, ClusterShape cluster_shape, KernelHardwareInfo hw_info, @@ -195,7 +192,6 @@ public: bool truncate_by_problem_size=true) { dim3 problem_blocks = get_tiled_cta_shape_mnl( - problem_shapes.groups(), problem_shapes, hw_info, tile_shape, cluster_shape); @@ -215,7 +211,8 @@ public: template CUTLASS_HOST_DEVICE static dim3 - get_tiled_cta_shape_mnl(int groups, GroupProblemShape problem_shapes, KernelHardwareInfo hw_info, BlockShape cta_shape, ClusterShape cluster_shape) { + get_tiled_cta_shape_mnl(GroupProblemShape const& problem_shapes, KernelHardwareInfo hw_info, BlockShape cta_shape, ClusterShape cluster_shape) { + int groups = problem_shapes.groups(); uint32_t total_ctas = 0; uint32_t cta_in_N_dim = 1; // We linearize the blocks across all the problems here @@ -259,20 +256,21 @@ public: } int lane_idx = canonical_lane_idx(); - if (lane_idx < params_.groups_) { - cached_problem_shapes_[1] = params_.problem_shapes_[lane_idx]; + if (lane_idx < params_.problem_shapes_.groups()) { + cached_problem_shapes_[1] = params_.problem_shapes_.get_problem_shape(lane_idx); } total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z); uint64_t ctas_along_m, ctas_along_n; - if (is_tuple(params_.problem_shapes_[0]))>::value || - is_tuple(params_.problem_shapes_[0]))>::value) { - ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(params_.problem_shapes_[0]), scheduler_params.cta_shape_.m())); - ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(params_.problem_shapes_[0]), scheduler_params.cta_shape_.n())); + ProblemShape problem_shape = params_.problem_shapes_.get_problem_shape(0); + if (is_tuple(problem_shape))>::value || + is_tuple(problem_shape))>::value) { + ctas_along_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shape), scheduler_params.cta_shape_.m())); + ctas_along_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shape), scheduler_params.cta_shape_.n())); } else { - ctas_along_m = scheduler_params.divmod_cta_shape_m_.divide(cute::shape<0>(params_.problem_shapes_[0]) + scheduler_params.divmod_cta_shape_m_.divisor - 1); - ctas_along_n = scheduler_params.divmod_cta_shape_n_.divide(cute::shape<1>(params_.problem_shapes_[0]) + scheduler_params.divmod_cta_shape_n_.divisor - 1); + ctas_along_m = scheduler_params.divmod_cta_shape_m_.divide(cute::shape<0>(problem_shape) + scheduler_params.divmod_cta_shape_m_.divisor - 1); + ctas_along_n = scheduler_params.divmod_cta_shape_n_.divide(cute::shape<1>(problem_shape) + scheduler_params.divmod_cta_shape_n_.divisor - 1); } auto problem_blocks_m = round_up(ctas_along_m, (1 << params_.log_swizzle_size_) * params_.cluster_shape_.m()); auto problem_blocks_n = round_up(ctas_along_n, (1 << params_.log_swizzle_size_) * params_.cluster_shape_.n()); @@ -292,8 +290,7 @@ public: get_work_idx_m_and_n( uint64_t linear_idx, GroupInfo& group_info, - int32_t total_problem_groups, - ProblemShape* problem_shapes, + GroupProblemShape &problem_shapes, ProblemShape (&cached_problem_shapes)[2], GemmCoord cta_shape, GemmCoord cluster_shape, @@ -308,13 +305,14 @@ public: // Use a warp to "speculatively" check if the work tile maps to the next 32 groups int lane_idx = canonical_lane_idx(); + int total_problem_groups = problem_shapes.groups(); if (linear_idx >= group_info.total_tiles + group_info.start_linear_idx) { group_info.group_idx += lane_idx; for ( ; ; group_info.group_idx += NumThreadsPerWarp) { cached_problem_shapes[0] = cached_problem_shapes[1]; if (group_info.group_idx + NumThreadsPerWarp < total_problem_groups) { - cached_problem_shapes[1] = problem_shapes[group_info.group_idx + NumThreadsPerWarp]; + cached_problem_shapes[1] = problem_shapes.get_problem_shape(group_info.group_idx + NumThreadsPerWarp); } if (group_info.group_idx < total_problem_groups) { uint64_t ctas_along_m, ctas_along_n; @@ -354,7 +352,7 @@ public: group_info.total_tiles = __shfl_sync(0xffffffff, group_info.total_tiles, first_succeeding_thread); group_info.problem_blocks_along_raster_order = __shfl_sync(0xffffffff, group_info.problem_blocks_along_raster_order, first_succeeding_thread); if (group_info.group_idx + lane_idx < total_problem_groups) { - cached_problem_shapes[1] = problem_shapes[group_info.group_idx + lane_idx]; + cached_problem_shapes[1] = problem_shapes.get_problem_shape(group_info.group_idx + lane_idx); } break; } @@ -419,7 +417,6 @@ public: return get_work_idx_m_and_n( linear_idx, current_group_info_, - scheduler_params.groups_, scheduler_params.problem_shapes_, cached_problem_shapes_, scheduler_params.cta_shape_, diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp index 437f5af2..a298e06b 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp @@ -1083,7 +1083,6 @@ private: // output tile. This work will thus be subsumed by the previous stream-K unit. --unit_idx; } - return unit_idx; }; diff --git a/include/cutlass/gemm/kernel/tile_scheduler_params.h b/include/cutlass/gemm/kernel/tile_scheduler_params.h index 0a837bed..9a89bf21 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler_params.h +++ b/include/cutlass/gemm/kernel/tile_scheduler_params.h @@ -59,6 +59,14 @@ get_max_cta_occupancy(int max_sm_per_gpc, GemmCoord cluster_shape, int sm_count) int const min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc; int const max_cta_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % cluster_size); int cta_per_device = min_num_gpc * max_cta_occupancy_per_gpc; + // Suppose max_sm_per_gpc = 20, cluster_size = 8, sm_count = 148 + // min_num_gpc = 148 / 20 = 7 + // max_cta_occupancy_per_gpc = 20 - (20 % 8) = 16 + // cta_per_device = 7 * 16 = 112 + // num_gpc_residual = 148 % 20 = 8 + // max_cta_occupancy_per_residual_gpc = 8 - (8 % 8) = 8 + // cta_per_device += 8 = 120 + // cta_per_device = 120 < 148 ? 148 : 120 = 148 // The calculation below allows for larger grid size launch for different GPUs. int const num_gpc_residual = sm_count < max_sm_per_gpc ? 0 : sm_count % max_sm_per_gpc; @@ -658,7 +666,6 @@ struct PersistentTileSchedulerSm90StreamKParams { // number of K tiles per stream-K unit remains above min_iters_per_sk_unit_ uint32_t groups = platform::min(max_groups_problem, uint32_t(max_sk_groups_)); - // Grouping is disabled when separate reduction is used because grouping is primarily an attempt // to improve L2 locality, and L2-locality optimizations are unnecessary when the the kernel // is a single wave (which is the case for separate reduction). @@ -1616,19 +1623,10 @@ struct PersistentTileSchedulerSm90StreamKParams { //////////////////////////////////////////////////////////////////////////////// // Parameters for SM90 persistent group scheduler (only used for Grouped Gemms) -template +template struct PersistentTileSchedulerSm90GroupParams { - - enum class RasterOrder { - AlongM, - AlongN - }; - - enum class RasterOrderOptions { - Heuristic, - AlongM, - AlongN - }; + using RasterOrder = cutlass::gemm::kernel::detail::RasterOrder; + using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; FastDivmodU64Pow2 divmod_cluster_shape_major_{}; FastDivmodU64Pow2 divmod_cluster_shape_minor_{}; @@ -1640,8 +1638,7 @@ struct PersistentTileSchedulerSm90GroupParams { int32_t log_swizzle_size_ = 0; RasterOrder raster_order_ = RasterOrder::AlongN; - int32_t groups_ = 0; - ProblemShape* problem_shapes_ = nullptr; + GroupProblemShape problem_shapes_; GemmCoord cta_shape_; GemmCoord cluster_shape_; @@ -1651,9 +1648,7 @@ struct PersistentTileSchedulerSm90GroupParams { void initialize( dim3 problem_blocks, - int32_t groups, - ProblemShape* problem_shapes, - ProblemShape const* host_problem_shapes, + GroupProblemShape problem_shapes, GemmCoord cta_shape, GemmCoord cluster_shape, KernelHardwareInfo const& hw_info, @@ -1677,13 +1672,12 @@ struct PersistentTileSchedulerSm90GroupParams { // // Set members // - groups_ = groups; problem_shapes_ = problem_shapes; cta_shape_ = cta_shape; cluster_shape_ = cluster_shape; blocks_across_problem_ = problem_blocks.x * problem_blocks.y * problem_blocks.z; - pre_processed_problem_shapes = (host_problem_shapes == nullptr) ? false : true; + pre_processed_problem_shapes = problem_shapes.is_host_problem_shape_available(); log_swizzle_size_ = log_swizzle_size; raster_order_ = raster_order; @@ -2442,12 +2436,12 @@ struct PersistentTileSchedulerSm100StreamKParams { //////////////////////////////////////////////////////////////////////////////// // Parameters for SM100 persistent group scheduler (only used for Grouped Gemms) -template +template struct PersistentTileSchedulerSm100GroupParams { - using UnderlyingSm90Params = PersistentTileSchedulerSm90GroupParams; - using RasterOrder = typename UnderlyingSm90Params::RasterOrder; - using RasterOrderOptions = typename UnderlyingSm90Params::RasterOrderOptions; + using UnderlyingSm90Params = PersistentTileSchedulerSm90GroupParams; + using RasterOrder = cutlass::gemm::kernel::detail::RasterOrder; + using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; UnderlyingSm90Params params_sm90_{}; @@ -2457,9 +2451,7 @@ struct PersistentTileSchedulerSm100GroupParams { void initialize( dim3 problem_blocks, - int32_t groups, - ProblemShape* problem_shapes, - ProblemShape const* host_problem_shapes, + GroupProblemShape problem_shapes, GemmCoord cta_shape, GemmCoord cluster_shape, KernelHardwareInfo const& hw_info, @@ -2469,9 +2461,7 @@ struct PersistentTileSchedulerSm100GroupParams { params_sm90_.initialize( problem_blocks, - groups, problem_shapes, - host_problem_shapes, cta_shape, cluster_shape, hw_info, diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 33c7d8ec..886dc9f2 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -3968,7 +3968,6 @@ struct NumericArrayConverter { } }; - ///////////////////////////////////////////////////////////////////////////////////////////////// /// Partial specialization for Array <= Array @@ -4377,6 +4376,123 @@ namespace detail { } ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round, + int N +> +struct NumericArrayConverter { + using result_element = cutlass::half_t; + using source_element = cutlass::float_e2m1_t; + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using result_type_packed_2 = Array; + using source_type_packed_8 = Array; + using source_type_packed_4 = Array; + using source_type_packed_2 = Array; + + using ScalarConverter = NumericConverter; + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + CUTLASS_DEVICE + static result_type_packed_8 ptx_convert(source_type_packed_8 const &source) { + result_type_packed_8 out; + uint32_t* out_fp16 = reinterpret_cast(&out); + uint32_t const& src_packed = reinterpret_cast(source); + asm volatile( \ + "{\n" \ + ".reg .b8 byte0, byte1, byte2, byte3;\n" \ + "mov.b32 {byte0, byte1, byte2, byte3}, %4;\n" \ + "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" \ + "cvt.rn.f16x2.e2m1x2 %1, byte1;\n" \ + "cvt.rn.f16x2.e2m1x2 %2, byte2;\n" \ + "cvt.rn.f16x2.e2m1x2 %3, byte3;\n" \ + "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) , "=r"(out_fp16[2]), "=r"(out_fp16[3]): "r"(src_packed)); + return out; + } + + CUTLASS_DEVICE + static result_type_packed_4 ptx_convert(source_type_packed_4 const &source) { + result_type_packed_4 out; + uint32_t* out_fp16 = reinterpret_cast(&out); + uint16_t const& src_packed = reinterpret_cast(source); + asm volatile( \ + "{\n" \ + ".reg .b8 byte0, byte1;\n" \ + "mov.b16 {byte0, byte1}, %2;\n" \ + "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" \ + "cvt.rn.f16x2.e2m1x2 %1, byte1;\n" \ + "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "h"(src_packed)); + return out; + } + + CUTLASS_DEVICE + static result_type_packed_2 ptx_convert(source_type_packed_2 const &source) { + result_type_packed_2 out; + uint32_t* out_fp16 = reinterpret_cast(&out); + uint16_t const& src_packed = static_cast(reinterpret_cast(source)); + asm volatile( \ + "{\n" \ + ".reg .b8 byte0, byte1;\n" \ + "mov.b16 {byte0, byte1}, %1;\n" \ + "cvt.rn.f16x2.e2m1x2 %0, byte0;\n" \ + "}\n" : "=r"(out_fp16[0]) : "h"(src_packed)); + return out; + } + #endif + + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch."); + + #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) + return ptx_convert(source); + #else + PackedResultType result; + NumericConverter converter; + + const int k_packed = PackedResultType::kElements; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < k_packed; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; /// Partial specialization for Array <= Array template diff --git a/media/docs/cpp/blackwell_cluster_launch_control.md b/media/docs/cpp/blackwell_cluster_launch_control.md index 1504c144..28b5e464 100644 --- a/media/docs/cpp/blackwell_cluster_launch_control.md +++ b/media/docs/cpp/blackwell_cluster_launch_control.md @@ -6,7 +6,7 @@ A GEMM workload usually consists of three phases: prologue, mainloop and epilogu Consider a GEMM that has `20x20x1` output tiles, running on a GPU with `100` SMs. There is another kernel occupying all the resources of `20` SMs so only `80` SMs can be used. Assume cluster shape is `1x1x1`. The following diagram shows how the schedule would look like for such a kernel. -

A beautiful sunset

+![GEMM tiles are evenly divided among available SMs](../../images/non_persistent.png "GEMM Scheduling with Limited SM Resources") ### Static Scheduler @@ -14,7 +14,7 @@ CUTLASS has adopted a software technique named **persistent kernels**. Persisten However, static scheduler is susceptible to workload imbalance if the resources of some SMs are unavailable. The following diagram illustrates this issue. -

A beautiful sunset

+![GEMM tiles are unevenly divided among available SMs, leading to workload imbalance](../../images/persistent_static.png "Imbalanced Workload Scheduling due to Static Scheduler") ### Dynamic Scheduler with Cluster Launch Control A fundamental limitation of persistent scheduling is that the number of SMs this kernel can utilize is unknown in real time. Some SMs might be occupied by another kernel and thus their resources are unavailable. This makes it challenging to load-balance work across SMs. @@ -32,7 +32,7 @@ Cluster launch control follows the below rules: The following diagram shows how the schedule would look like with cluster launch control. -

A beautiful sunset

+![GEMM tiles are dynamically allocated among available SMs, leading to a balanced workload](../../images/persistent_clc.png "Dynamic Scheduler with Cluster Launch Control") ## Programming Model ### Pseudo Code diff --git a/media/docs/cpp/cute/02_layout_algebra.md b/media/docs/cpp/cute/02_layout_algebra.md index 03578313..79a59abf 100644 --- a/media/docs/cpp/cute/02_layout_algebra.md +++ b/media/docs/cpp/cute/02_layout_algebra.md @@ -142,8 +142,8 @@ Put into words, `A o B = A o s:d`, for integral `s` and `d` means that we want ( * `(6,2) / 3 => (2,2)` * `(6,2) / 6 => (1,2)` * `(6,2) / 12 => (1,1)` -* `(3,6,2,8) / 3 => (1,6,2,8)` -* `(3,6,2,8) / 6 => (1,3,2,8)` +* `(3,6,2,8) / 3 => (1,3,2,8)` +* `(3,6,2,8) / 6 => (1,6,2,8)` * `(3,6,2,8) / 9 => (1,2,2,8)` * `(3,6,2,8) / 72 => (1,1,1,4)` diff --git a/media/docs/cpp/cute/0y_predication.md b/media/docs/cpp/cute/0y_predication.md index faa15156..303b167e 100644 --- a/media/docs/cpp/cute/0y_predication.md +++ b/media/docs/cpp/cute/0y_predication.md @@ -10,53 +10,58 @@ For example, we might want to tile a 41 x 55 matrix into 4 x 8 tiles, but 41 / 4 is 10 remainder 1, and 55 / 8 is 6 remainder 7. What do we do with those "leftover" parts of the matrix? -Another way to say this, is that `logical_divide` +To start, we note that `logical_divide` (CuTe's way of tiling layouts) "rounds up." -For example, if `N` is the layout (1000, 1) and `B` is the layout (128, 1), -then `logical_divide(N, B)` is the layout ((128, 8), (1, 128)). -This effectively rounds up the original shape N = 1000 -into an 128 x 8 matrix (as if N = 1024). +For example, if `N` is the layout `1000:1` and `B` is the layout `128:1`, +then `logical_divide(N, B)` is the layout `(128, 8):(1, 128)`. +This effectively rounds up the original shape `N = 1000` +into an `128 x 8` matrix (as if `N = 1024`). What about those last 24 elements, -that aren't part of the original data? +that aren't part of the original data? How is the last tile handled and how do we avoid indexing out-of-bounds? -The idiomatic CuTe way to solve this problem is through "predication." -Rather than trying to reason about the "remainder tiles," -CuTe instead rounds up, but only tries to access data in each tile -that are part of the matrix. +Like other introductions to CUDA programming, the idiomatic CuTe way to address these issues is through "predication." +Rather than attempting to reason about the "remainder tiles" by trying to represent "7 tiles of size-128 and 1 tile of size-104," +CuTe instead rounds up to "8 tiles of size-128" and constructs predicates so that the kernel +only tries to access data in each tile that are valid within the matrix. This corresponds well with how our GPUs optimize: branches without warp divergence are relatively fast. It also matches the usual CUDA idiom when dividing N work items in 1-D fashion over B thread blocks: first test if "my thread" is out of bounds before doing work. -There are a few ways to figure out -which elements need to be predicated. -In-kernel GEMMs like to do this in the following way. +Consider a generic tiling wherein a size-1000 vector is tiled into size-128 chunks. Then a predication tensor can be constructed as follows: ```c++ -// Create the predicate tensor -Layout idA = make_layout(shape(A)); // e.g. 1000:1 -Layout idAB = logical_divide(idA, B); // e.g. (128,8):(1,128) +Tensor gmem = ... // e.g. size 1000 +Tensor smem = ... // e.g. size 128 -Tensor pred = make_tensor(shape(idAB)); +// Tile the gmem for smem +Tensor gmem_tiled = logical_divide(gmem, size(smem)); // e.g. (128,8) + +// Create an identity layout for gmem and tile it similarly +Layout id_layout = make_layout(shape(gmem)); // e.g. 1000:1, explicitly constructed as identity function +Layout id_tiled = logical_divide(id_layout, size(smem)); // e.g. (128,8):(1,128), but many elements aren't "valid" + +// Create a predicate tensor +Tensor pred = make_tensor(shape(id_tiled)); // e.g. (128,8) for (int i = 0; i < size(pred); ++i) { - pred(i) = idAB(i) < size(A); + pred(i) = id_tiled(i) < size(id_layout); // Predicate: Is the offset within the original shape? } // ... intervening code ... -// Use the predicate tensor. c is some coordinate. -// This code would likely live inside some algorithm. -if (pred(c)) { copy(idAB(c), smem(c)); } +// Note that gmem_tiled, id_tiled, and pred tensors are all congruent +// For tile tile_i, determine if element value_j is in-bounds and copy to smem +if (pred(value_j,tile_i)) { smem(value_j) = gmem_tiled(value_j,tile_i); } ``` The general procedure is that we -1. create an "identity" layout (`Layout idA = make_layout(shape(A))`, +1. create an "identity" layout (`Layout id_layout = make_layout(shape(gmem))`, in the above example) with the same shape as our original data; 2. repeat the same tiling/partitioning/slicing (possibly rounding up) - on that identity layout (`Layout idAB = logical_divide(idA, B)`); + on that identity layout (`Layout id_tiled = logical_divide(id_layout, size(smem));`); 3. create a "predicate tensor" by comparing the coordinates of that reference layout with the bounds of the original layout; @@ -64,19 +69,119 @@ The general procedure is that we 4. use the predicate tensor to mask off accesses to out-of-bounds elements. -For example, suppose that we've partitioned A and B tiles -across threads as follows. +As a relatively simple example, consider predicating the epilogue of a GEMM. +Suppose that we've partitioned `mC` into cta tiles and across threads of an mma as follows. -```c++ -Tensor tAgA = local_partition(gA, tA, thread_idx); // (THR_M,THR_K,k) -Tensor tAsA = local_partition(sA, tA, thread_idx); // (THR_M,THR_K,PIPE) +```cpp +// CTA partitioning +auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k) +Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) -Tensor tBgB = local_partition(gB, tB, thread_idx); // (THR_N,THR_K,k) -Tensor tBsB = local_partition(sB, tB, thread_idx); // (THR_N,THR_K,PIPE) +// Thread partitioning +auto thr_mma = mma.get_slice(threadIdx.x); +Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N) +Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N) + +// ... Compute gemms and accumulate into tCrC ... + +// axpby epilogue +for (int i = 0; i < size(tCgC); ++i) { + tCgC(i) = alpha * tCrC(i) + beta * tCgC(i); +} ``` -`tAgA` and `tBgB` partition the global A resp. B matrices over threads, -and `tAsA` and `tBsB` partition the shared memory tiles of A resp. B over threads. +Then, following the predication procedure is straightforward, + +```cpp +// A coordinate tensor the same shape as mC: (m,n) -> (m,n) +Tensor cC = make_identity_tensor(shape(mC)); + +// Repeat partitioning steps applied to mC to our coordinate tensor cC +// CTA partitioning +Tensor cta_cC = local_tile(cC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) -> (m,n) +// Thread partitioning +Tensor tCcC = thr_mma.partition_C(cta_cC); // (MMA,MMA_M,MMA_N) -> (m,n) + +// Predicated axpby epilogue +for (int i = 0; i < size(tCgC); ++i) { + if (elem_less(tCcC(i), shape(mC))) { // if coord is in-bounds + tCgC(i) = alpha * tCrC(i) + beta * tCgC(i); + } +} +``` + +Above, the cta is responsible for tiling/partitioning `mC` and the mma is responsible for tiling/partitioning `gC`, +so both steps are also applied to the identity tensor. +The coordinate tensor `tCcC` is congruent with the register fragment `tCrC` and the partitioned global memory tensor `tCgC`, which are this threads' subtensors of the tile of data. However, the `tCcC` tensor retains it's original codomain when evaluated: a global coordinate into the original tensor `mC`. This global coordinate is compared to the shape of `mC` to determine validity of the operation. + +Advantages of this "reference identity tensor" or "coordinate tensor" approach include: + +1. There is no dependence on the layout/strides of the tensor + being predicated, just the logical bounds imposed. + +2. The partitioning stage(s) can be anything. A CTA tiling, a thread partitioning, a TiledMMA, and a TiledCopy can all be applied to any tensor, including a coordinate tensor. + +3. It naturally extends to any-dimensional predication. + +4. It's a natural generalization of a typical CUDA 1-D + parallel vector access pattern, + which computes an access index `idx` and predicates access to the vector's `idx`-th element, determining if `idx` is in-bounds. +```cpp +int idx = blockDim.x * blockIdx.x + threadIdx.x; +if (idx < N) // idx is a "coord" into gmem and N is the "bound" + gmem_ptr[idx] = ...; +``` + +In a SIMT programming model, the tensor extents should not be modified so that loops don't overrun. +Instead, predication is a general method to query the original coordinate and determine if that coordinate overruns. +This avoids variable/dynamic loop bounds in favor of instruction-level predication, preservation of thread coherence, and maintaining load balance. +It's also general enough to extend to all ranks, all layouts of threads and data, and all tiling/partitioning patterns. +Assumptions can be built into the coordinate tensors or the predicate tensors to account for special cases. + +As another slightly more complex example, consider the m- and n-predication of A and B loads in a GEMM. Suppose that we've partitioned A and B tiles across ctas and threads as follows. + +```c++ +// CTA partitioning +auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k) +Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) +Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + +Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K) +Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K) + +// Thread partitioning +Tensor tAgA = local_partition(gA, tA, thread_idx); // (THR_M,THR_K,k) +Tensor tAsA = local_partition(sA, tA, thread_idx); // (THR_M,THR_K) + +Tensor tBgB = local_partition(gB, tB, thread_idx); // (THR_N,THR_K,k) +Tensor tBsB = local_partition(sB, tB, thread_idx); // (THR_N,THR_K) +``` + +`gA` and `gB` are tiles of `mA` resp. `mB` according to `cta_tiler` and the `cta_coord`. +`tAgA` and `tBgB` are partitions of `gA` resp. `gB` according the the thread-layouts `tA` and `tB` +and `thread_idx`. + +The following code creates "identity tensors" that map coordinates `(m,k) -> (m,k)` and `(n,k) -> (n,k)`. + +```c++ +// Coordinate tensors +Tensor cA = make_identity_tensor(shape(mA)); // (m,k) -> (m,k) +Tensor cB = make_identity_tensor(shape(mB)); // (n,k) -> (n,k) +``` + +Then, the reference tensors are tiled and partitioned +in exactly the same way the `mA` and `mB` tensors were tiled and partitioned +into `tAgA` and `tBgB`. + +```c++ +// CTA partitioning +Tensor cta_cA = local_tile(cA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) -> (m,k) +Tensor cta_cB = local_tile(cB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) -> (n,k) + +// Thread partitioning +Tensor tAcA = local_partition(cta_cA, tA, thread_idx); // (THR_M,THR_K,k) -> (m,k) +Tensor tBcB = local_partition(cta_cB, tB, thread_idx); // (THR_N,THR_K,k) -> (m,k) +``` The following code creates predicate tensors corresponding to `tAgA` and `tBgB`. @@ -84,166 +189,35 @@ They will be computed once in the prologue. and will be used to mask off instructions in the inner loop. ```c++ -Tensor tApA = make_tensor(make_shape (size<0>(tAgA), size<1>(tAgA)), +Tensor tApA = make_tensor(make_shape (size<0>(tAcA), size<1>(tAcA)), make_stride( Int<1>{}, Int<0>{})); -Tensor tBpB = make_tensor(make_shape (size<0>(tBgB), size<1>(tBgB)), +Tensor tBpB = make_tensor(make_shape (size<0>(tBcB), size<1>(tBcB)), make_stride( Int<1>{}, Int<0>{})); ``` -We're only thread-parallelizing over the leftmost (row) dimension, -so we only need to predicate over the leftmost dimension. -Thus, we can make the rightmost (column) stride zero, -since we will never actually address the rightmost dimension. - -The following code creates "two-dimensional identity tensors" -that map coordinates (m,k) -> (m,k) -for the tile of data within the thread block. +Here, we make a few assumptions: we're only interested in predicates for one tile of data at a time and we're only interested in predicates for the m- and n-modes and will handle the k-mode predicates differently. +The m- and n- predicates will be considered constant across every tile and will be reused in every iteration of the mainloop. +Thus, we only store the predicates for the m- and n-modes and broadcast them across the k-mode. +When populating the tensors, we carry the same assumption through: ```c++ -Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) -Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) -``` - -The following lines then tile and partition -the two reference tensors -in exactly the same way the data were tiled and partitioned -into `tAsA` and `tBsB`. - -```c++ -Tensor tAcA = local_partition(cA, tA, thread_idx); -Tensor tBcB = local_partition(cB, tB, thread_idx); -``` - -Tiling and partitioning affect the offset and domain, -but not the codomain of the tensors, -so we're left with tensors that map `(thr_m,thr_k) -> (m,k)` -where `(thr_m,thr_k)` is this particular thread's subtensor of the tile -and `(m,k)` is the original codomain: a coordinate into the original tile. - -The unrolled loops in the code below then compare -the m- and n-coordinates of those tensors with our known maximums -to mask off elements we are not allowed to access. - -```c++ -Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) -Tensor tAcA = local_partition(cA, tA, thread_idx); - -Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) -Tensor tBcB = local_partition(cB, tB, thread_idx); - -// Populate +// Populate the m- and n-predicates CUTE_UNROLL for (int m = 0; m < size<0>(tApA); ++m) { - tApA(m,0) = get<0>(tAcA(m,0)) < m_max_coord; + tApA(m,0) = elem_less(get<0>(tAcA(m,0,0)), shape<0>(mA)); // Compare the m-coordinate } CUTE_UNROLL for (int n = 0; n < size<0>(tBpB); ++n) { - tBpB(n,0) = get<0>(tBcB(n,0)) < n_max_coord; + tBpB(n,0) = elem_less(get<0>(tBcB(n,0,0)), shape<0>(mB)); // Compare the n-coordinate } ``` -Those last `for` loops fill in the two predicate tensors. -In this case, we only need to predicate over the leftmost dimension, -so we only address `(m,0)` resp. `(n,0)`. +and only compare the m- and n-coordinates of the 0th k-tile and 0th k-block. The stride-0 broadcasting mode still allows us to treat this data as a predicate tensor for each and every element of the tile to be loaded. -We can then use the predicate tensors in `copy_if` -to copy only the elements for which the corresponding -predicate tensor elements are nonzero. +Finally, we can then use the predicate tensors in `copy_if` to copy only the elements for which the corresponding predicate tensor elements are `true`. ```c++ -// Prefetch k_tile=0, gate these on k_residue as well -CUTE_UNROLL -for (int k = 0; k < size<1>(tAsA); ++k) { - if (get<1>(tAcA(0,k)) >= -k_residue) { // some other condition on the column index - copy_if(tApA, tAgA(_,k,0), tAsA(_,k,0)); - } -} - -CUTE_UNROLL -for (int k = 0; k < size<1>(tBsB); ++k) { - if (get<1>(tBcB(0,k)) >= -k_residue) { // some other condition on the column index - copy_if(tBpB, tBgB(_,k,0), tBsB(_,k,0)); - } -} -``` - -Here are some advantages of this "reference tensor" approach. - -1. It doesn't depend on the layout/strides of the tensor - being predicated, just the logical bounds being imposed. - -2. The partitioning stage can be anything. - -3. It naturally extends to any-dimensional predication. - -4. It's a natural generalization of a typical CUDA 1-D - parallel vector access pattern, - which computes an access index `k` - (e.g., as `blockDim.x * blockIdx.x + threadIdx.x`) - and then predicates access to the vector's `k`-th element - on whether `k` is in bounds. - -As an example of (3), the epilogue predication does exactly the same thing, - -```c++ -// Repeat with a tensor of coordinates for predication -Tensor cC = make_identity_tensor(make_shape(size<0>(gC), size<1>(gC))); -Tensor tCcC = thr_mma.partition_C(cC); - -const bool isBetaZero = (beta == 0); - -CUTE_UNROLL -for (int i = 0; i < size(tCrC); ++i) { - if (elem_less(tCcC(i), make_coord(m_max_coord,n_max_coord))) { - tCgC(i) = isBetaZero ? alpha * tCrC(i) : alpha * tCrC(i) + beta * tCgC(i); - } -} -``` - -but with the mma responsible for the tiling/partitioning `tCcC` -so that the reference subtensor matches the accumulator's subtensor. -Then, the reference subtensor is predicated against the `if` bounds -(in both m- and n-coordinates) inside the `for` loop. - -Another way to explain this is that we don't modify the tiles -to give you the "right" extents so that you never overrun. -Instead, we let you query the original coordinate -to see if that coordinate overruns. -This avoids all branching and variable/dynamic loop bounds -(thus maintaining load balance and synchronicity, -both very important in-kernel) in favor of predication. -It's also general enough to extend to all ranks, -all layouts of threads and data, -and all tiling/partitioning patterns. - -## Copyright - -Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -SPDX-License-Identifier: BSD-3-Clause - -``` - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - - 3. Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// Copy a k_tile from global memory to shared memory +copy_if(tApA, tAgA(_,_,k_tile), tAsA); +copy_if(tBpB, tBgB(_,_,k_tile), tBsB); ``` diff --git a/media/docs/pythonDSL/cute_dsl.rst b/media/docs/pythonDSL/cute_dsl.rst index 71fa4f7f..108837a0 100644 --- a/media/docs/pythonDSL/cute_dsl.rst +++ b/media/docs/pythonDSL/cute_dsl.rst @@ -6,12 +6,12 @@ CuTe DSL .. toctree:: :maxdepth: 1 - DSL Introduction - DSL Code Generation - DSL Control Flow - DSL JIT Argument Generation - DSL JIT Argument: Layouts - DSL JIT Caching + Introduction + Code Generation + Control Flow + JIT Argument Generation + JIT Argument: Layouts + JIT Caching Integration with Frameworks Debugging with the DSL Autotuning with the DSL diff --git a/media/docs/pythonDSL/cute_dsl_general/autotuning_gemm.rst b/media/docs/pythonDSL/cute_dsl_general/autotuning_gemm.rst index db76c8a7..64235811 100644 --- a/media/docs/pythonDSL/cute_dsl_general/autotuning_gemm.rst +++ b/media/docs/pythonDSL/cute_dsl_general/autotuning_gemm.rst @@ -3,10 +3,6 @@ Guidance for Auto-Tuning ============================= -.. contents:: Table of Contents - :depth: 2 - :local: - Numerous GEMM kernel code examples are offered within our codebase. When integrating these kernels into frameworks, auto-tuning becomes essential for achieving optimal performance. This involves selecting the appropriate diff --git a/media/docs/pythonDSL/cute_dsl_general/debugging.rst b/media/docs/pythonDSL/cute_dsl_general/debugging.rst index 649aa608..6302100b 100644 --- a/media/docs/pythonDSL/cute_dsl_general/debugging.rst +++ b/media/docs/pythonDSL/cute_dsl_general/debugging.rst @@ -3,10 +3,6 @@ Debugging ========= -.. contents:: Table of Contents - :depth: 2 - :local: - This page provides an overview of debugging techniques and tools for CuTe DSL programs. diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_code_generation.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_code_generation.rst index b4b463d4..67180c4d 100644 --- a/media/docs/pythonDSL/cute_dsl_general/dsl_code_generation.rst +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_code_generation.rst @@ -6,10 +6,6 @@ End-to-End Code Generation ========================== -.. contents:: - :depth: 2 - :local: - 1. Techniques for Turning Python into |IR| ------------------------------------------ diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_control_flow.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_control_flow.rst index a16c79c3..18fd2528 100644 --- a/media/docs/pythonDSL/cute_dsl_general/dsl_control_flow.rst +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_control_flow.rst @@ -4,11 +4,8 @@ .. |DSL| replace:: CuTe DSL .. |Constexpr| replace:: **Constexpr** (compile-time Python value) -|DSL| Control Flow +Control Flow ================== -.. contents:: - :depth: 2 - :local: Overview diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_dynamic_layout.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_dynamic_layout.rst index 9c5cca7d..e8dd7701 100644 --- a/media/docs/pythonDSL/cute_dsl_general/dsl_dynamic_layout.rst +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_dynamic_layout.rst @@ -3,10 +3,6 @@ .. |SLAY| replace:: static layout .. |DLAY| replace:: dynamic layout -.. contents:: Table of Contents - :depth: 2 - :local: - Static vs Dynamic layouts ========================= diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.rst index ca409771..af487563 100644 --- a/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.rst +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.rst @@ -4,12 +4,9 @@ .. |DSL| replace:: CuTe DSL -|DSL| +Introduction ====================== -.. contents:: Table of Contents - :depth: 2 - :local: Overview -------- diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_jit_arg_generation.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_jit_arg_generation.rst index a7c46003..18970012 100644 --- a/media/docs/pythonDSL/cute_dsl_general/dsl_jit_arg_generation.rst +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_jit_arg_generation.rst @@ -2,12 +2,9 @@ .. |DSL| replace:: CuTe DSL .. |CUSTOM_TYPES| replace:: customized types -|DSL| JIT Function Argument Generation +JIT Function Argument Generation ======================================= -.. contents:: Table of Contents - :depth: 2 - :local: In a nutshell -------------- @@ -39,7 +36,7 @@ By default, |DSL| assumes dynamic arguments and tries to infer the argument type import cutlass.cute as cute @cute.jit - def foo(x: cutlass.Int32, y: cute.Constexpr): + def foo(x: cutlass.Int32, y: cutlass.Constexpr): print("x = ", x) # Prints x = ? print("y = ", y) # Prints y = 2 cute.printf("x: {}", x) # Prints x: 2 diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_jit_caching.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_jit_caching.rst index 30d07377..ecaea52b 100644 --- a/media/docs/pythonDSL/cute_dsl_general/dsl_jit_caching.rst +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_jit_caching.rst @@ -3,11 +3,9 @@ .. _JIT_Caching: -|DSL| JIT Caching +JIT Caching ==================== -.. contents:: Table of Contents - :depth: 2 - :local: + Zero Compile and JIT Executor ----------------------------- diff --git a/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst b/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst index 5abba902..269c6602 100644 --- a/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst +++ b/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst @@ -4,10 +4,6 @@ Integration with Frameworks ============================= -.. contents:: Table of Contents - :depth: 2 - :local: - In order to facilitate the integration of CUTLASS Python with popular frameworks, we leverage the `DLPack protocol `_ and transform tensors originating from these frameworks to CuTe tensors. The present page documents the conventions, the API available to the @@ -257,8 +253,7 @@ layouts. The full signature of ``mark_compact_shape_dynamic`` is as follows: The ``mode`` parameter determines which shape dimension becomes dynamic. After calling this function, the specific shape dimension given by ``mode`` is marked as dynamic immediately. The stride will be -updated accordingly but this process is delayed until the C ABI of the tensor is constructed. -For modes that have a shape of size 1, their stride are canonicalized to 0. +updated accordingly. For modes that have a shape of size 1, their stride are canonicalized to 0. The ``stride_order`` parameter specifies the ordering of strides in the tensor. It is consistent with ``torch.Tensor.dim_order()`` and defaults to ``None``. The parameter indicates the order of @@ -322,10 +317,6 @@ The following example demonstrates how to use ``mark_compact_shape_dynamic`` to import torch from cutlass.cute.runtime import from_dlpack - @cute.jit - def kernel(t: cute.Tensor): - pass - # (8,4,16,2):(2,16,64,1) a = torch.empty(16, 4, 8, 2).permute(2, 1, 0, 3) # (1,4,1,32,1):(4,1,4,4,4) => torch tensor when dimension has shape 1, its stride is degenerated to 1, @@ -337,14 +328,12 @@ The following example demonstrates how to use ``mark_compact_shape_dynamic`` to t0 = from_dlpack(a).mark_compact_shape_dynamic( mode=0, divisibility=2 ) - kernel(t0) # (?{div=2},4,16,2):(2,?{div=4},?{div=16},1) print(t0) t1 = from_dlpack(a).mark_compact_shape_dynamic( mode=1, divisibility=2 ) - kernel(t1) # (8,?{div=2},16,2):(2,16,?{div=32},1) print(t1) @@ -353,21 +342,18 @@ The following example demonstrates how to use ``mark_compact_shape_dynamic`` to ).mark_compact_shape_dynamic( mode=3, divisibility=2 ) - kernel(t2) # (8,?{div=2},16,?{div=2}):(?{div=2},?{div=16},?{div=32},1) print(t2) t3 = from_dlpack(b).mark_compact_shape_dynamic( mode=2, divisibility=1, stride_order=(3, 0, 2, 4, 1) ) - kernel(t3) # (1,4,?,32,1):(0,1,4,?{div=4},0) print(t3) t4 = from_dlpack(b).mark_compact_shape_dynamic( mode=2, divisibility=1, stride_order=(2, 3, 4, 0, 1) ) - kernel(t4) # (1,4,?,32,1):(0,1,128,4,0) print(t4) diff --git a/media/docs/pythonDSL/faqs.rst b/media/docs/pythonDSL/faqs.rst index e8cce741..90348d8d 100644 --- a/media/docs/pythonDSL/faqs.rst +++ b/media/docs/pythonDSL/faqs.rst @@ -124,7 +124,8 @@ Technical License --------------------- -**Q:What is the license for CuTe DSL and the associated GitHub samples?** +**What is the license for CuTe DSL and the associated GitHub samples?** + CuTe DSL components available `on Github `__ and via the nvidia-cutlass-dsl Python pip wheel are released under the `"NVIDIA Software End User License Agreement (EULA)" `__. Because the pip package includes a compiler that shares several components with the CUDA Toolkit, diff --git a/media/docs/pythonDSL/limitations.rst b/media/docs/pythonDSL/limitations.rst index 73d23a25..7be5b051 100644 --- a/media/docs/pythonDSL/limitations.rst +++ b/media/docs/pythonDSL/limitations.rst @@ -3,9 +3,6 @@ Limitations ==================== -.. contents:: - :depth: 2 - :local: Overview --------------------- diff --git a/media/docs/pythonDSL/overview.rst b/media/docs/pythonDSL/overview.rst index 07abfb09..fbd3abd8 100644 --- a/media/docs/pythonDSL/overview.rst +++ b/media/docs/pythonDSL/overview.rst @@ -42,7 +42,7 @@ Core CuTe DSL Abstractions - **Atoms** – Represent fundamental hardware operations like matrix multiply-accumulate (MMA) or memory copy. - **Tiled Operations** – Define how atoms are applied across thread blocks and warps (e.g., ``TiledMma``, ``TiledCopy``). -For more on CuTe abstractions, refer to the `CuTe C++ library documentation `__. +For more on CuTe abstractions, refer to the `CuTe C++ library documentation `__. **Pythonic Kernel Expression** diff --git a/media/docs/pythonDSL/quick_start.rst b/media/docs/pythonDSL/quick_start.rst index 0c7fb505..18569b17 100644 --- a/media/docs/pythonDSL/quick_start.rst +++ b/media/docs/pythonDSL/quick_start.rst @@ -29,3 +29,12 @@ To run examples and begin development, we recommend installing: .. code-block:: bash pip install torch jupyter + +Recommended Python environment variables for jupyter notebooks +-------------------------------------------------------------- + +We recommend setting the following environment variable when running jupyter notebooks. + +.. code-block:: bash + + export PYTHONUNBUFFERED=1 diff --git a/python/CuTeDSL/base_dsl/_mlir_helpers/arith.py b/python/CuTeDSL/base_dsl/_mlir_helpers/arith.py index d515113b..8109d3c2 100644 --- a/python/CuTeDSL/base_dsl/_mlir_helpers/arith.py +++ b/python/CuTeDSL/base_dsl/_mlir_helpers/arith.py @@ -578,7 +578,7 @@ class ArithValue(ir.Value): # Unary operators def __invert__(self, *, loc=None, ip=None) -> "ArithValue": - return arith.xori(self, arith.const(self.type, -1)) + return arith.xori(self, arith.constant(self.type, -1)) # Bitwise operations @_dispatch_to_rhs_r_op diff --git a/python/CuTeDSL/base_dsl/ast_helpers.py b/python/CuTeDSL/base_dsl/ast_helpers.py index e8796cff..cc5cadb6 100644 --- a/python/CuTeDSL/base_dsl/ast_helpers.py +++ b/python/CuTeDSL/base_dsl/ast_helpers.py @@ -95,7 +95,7 @@ class Executor: unroll=bool, unroll_full=int, ): - log().info("start [%s] stop [%s] step [%s]", start, stop, step) + log().debug("start [%s] stop [%s] step [%s]", start, stop, step) return self._loop_execute_range_dynamic( func, start, @@ -117,7 +117,7 @@ class Executor: used_args: list, iter_args: list, ): - log().info("start [%s] stop [%s] step [%s]", start, stop, step) + log().debug("start [%s] stop [%s] step [%s]", start, stop, step) loop_results = iter_args log().debug("iter_args [%s]", iter_args) for i in range(start, stop, step): @@ -374,7 +374,7 @@ def loop_selector( unroll_full=False, constexpr=None, ): - log().info( + log().debug( "start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] constexpr [%s]", start, stop, @@ -415,7 +415,7 @@ def loop_selector( def if_selector(pred, used_args=[], yield_args=[]): - log().info("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args) + log().debug("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args) # Handle Numeric types here? from .typing import Numeric diff --git a/python/CuTeDSL/base_dsl/ast_preprocessor.py b/python/CuTeDSL/base_dsl/ast_preprocessor.py index e165c1db..f1e1c635 100644 --- a/python/CuTeDSL/base_dsl/ast_preprocessor.py +++ b/python/CuTeDSL/base_dsl/ast_preprocessor.py @@ -248,39 +248,55 @@ class DSLPreprocessor(ast.NodeTransformer): # Step 3. Return the transformed tree return combined_body - def check_early_exit(self, tree): + def check_early_exit(self, tree, kind): """ Checks if a given region or scope in the provided Python code has early exits. """ class EarlyExitChecker(ast.NodeVisitor): - def __init__(self): + def __init__(self, kind): self.has_early_exit = False self.early_exit_node = None self.early_exit_type = None + self.kind = kind + self.loop_nest_level = 0 + # Early exit is not allowed in any level of dynamic control flow def visit_Return(self, node): self.has_early_exit = True self.early_exit_node = node self.early_exit_type = "return" - def visit_Break(self, node): - self.has_early_exit = True - self.early_exit_node = node - self.early_exit_type = "break" - - def visit_Continue(self, node): - self.has_early_exit = True - self.early_exit_node = node - self.early_exit_type = "continue" - def visit_Raise(self, node): self.has_early_exit = True self.early_exit_node = node self.early_exit_type = "raise" - checker = EarlyExitChecker() - checker.visit(tree) + def visit_Break(self, node): + # For break/continue in inner loops, we don't consider it as early exit + if self.loop_nest_level == 0 and self.kind != "if": + self.has_early_exit = True + self.early_exit_node = node + self.early_exit_type = "break" + + def visit_Continue(self, node): + if self.loop_nest_level == 0 and self.kind != "if": + self.has_early_exit = True + self.early_exit_node = node + self.early_exit_type = "continue" + + def visit_For(self, node): + self.loop_nest_level += 1 + self.generic_visit(node) + self.loop_nest_level -= 1 + + def visit_While(self, node): + self.loop_nest_level += 1 + self.generic_visit(node) + self.loop_nest_level -= 1 + + checker = EarlyExitChecker(kind) + checker.generic_visit(tree) if not checker.has_early_exit: return raise DSLAstPreprocessorError( @@ -591,7 +607,7 @@ class DSLPreprocessor(ast.NodeTransformer): if self.is_supported_range_call(node): constexpr_val = self.get_loop_constexpr(node) # Check for early exit and raise exception - self.check_early_exit(node) + self.check_early_exit(node, "for") start, stop, step = self.extract_range_args(node.iter) unroll, unroll_full = self.extract_unroll_args(node.iter) used_args, iter_args, flat_args = self.analyze_region_variables( @@ -659,37 +675,42 @@ class DSLPreprocessor(ast.NodeTransformer): snippet=ast.unparse(node), ) - test = ast.BoolOp( - op=ast.And(), - values=[ - ast.Compare( - left=ast.Call( - func=ast.Name(id="type", ctx=ast.Load()), - args=[node.values[0]], - keywords=[], + def short_circuit_eval(value, short_circuit_value): + return ast.BoolOp( + op=ast.And(), + values=[ + ast.Compare( + left=ast.Call( + func=ast.Name(id="type", ctx=ast.Load()), + args=[value], + keywords=[], + ), + ops=[ast.Eq()], + comparators=[ast.Name(id="bool", ctx=ast.Load())], ), - ops=[ast.Eq()], - comparators=[ast.Name(id="bool", ctx=ast.Load())], - ), - ast.Compare( - left=node.values[0], - ops=[ast.Eq()], - comparators=[short_circuit_value], - ), - ], - ) - return ast.copy_location( - ast.IfExp( + ast.Compare( + left=value, + ops=[ast.Eq()], + comparators=[short_circuit_value], + ), + ], + ) + + lhs = node.values[0] + + for i in range(1, len(node.values)): + test = short_circuit_eval(lhs, short_circuit_value) + lhs = ast.IfExp( test=test, - body=node.values[0], + body=lhs, orelse=ast.Call( func=helper_func, - args=node.values, + args=[lhs, node.values[i]], keywords=[], ), - ), - node, - ) + ) + + return ast.copy_location(lhs, node) def visit_UnaryOp(self, node): # Visit child nodes first @@ -916,7 +937,7 @@ class DSLPreprocessor(ast.NodeTransformer): return node # Check for early exit and raise exception - self.check_early_exit(node) + self.check_early_exit(node, "while") used_args, yield_args, flat_args = self.analyze_region_variables( node, active_symbols @@ -1021,7 +1042,7 @@ class DSLPreprocessor(ast.NodeTransformer): return node # Check for early exit and raise exception - self.check_early_exit(node) + self.check_early_exit(node, "if") used_args, yield_args, flat_args = self.analyze_region_variables( node, active_symbols diff --git a/python/CuTeDSL/base_dsl/dsl.py b/python/CuTeDSL/base_dsl/dsl.py index 619ed4c8..4870cdae 100644 --- a/python/CuTeDSL/base_dsl/dsl.py +++ b/python/CuTeDSL/base_dsl/dsl.py @@ -566,7 +566,9 @@ class BaseDSL: log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, arg_spec) # Implicit cast to NumericMeta - if isinstance(arg_spec, t.NumericMeta): + if isinstance(arg_spec, t.NumericMeta) and not isinstance( + arg, arg_spec + ): arg = t.cast(arg, arg_spec) ir_arg, iv_block_args = ( @@ -589,15 +591,17 @@ class BaseDSL: self.log_additions(ir_arg) ir_args.extend(ir_arg) - return ir_args + return ir_args, iv_block_args fop_args = list(fop.regions[0].blocks[0].arguments) - ir_args = gen_exec_args(args, args_spec.args, args_spec.annotations, fop_args) - ir_kwargs = gen_exec_args( + ir_args, iv_block_args = gen_exec_args( + args, args_spec.args, args_spec.annotations, fop_args + ) + ir_kwargs, _ = gen_exec_args( [kwargs[arg] for arg in args_spec.kwonlyargs], args_spec.kwonlyargs, args_spec.annotations, - fop_args[len(ir_args) :], + fop_args[iv_block_args:], ) ir_kwargs = {k: v for k, v in zip(args_spec.kwonlyargs, ir_kwargs)} @@ -716,8 +720,10 @@ class BaseDSL: assert len(args) == len(args_spec.args) and len(kwargs) == len( args_spec.kwonlyargs - ), f"Input args {len(args)=} and kwargs {len(kwargs)=} must match arg_spec.args " - f"{len(args_spec.args)=} and arg_spec.kwonlyargs {len(args_spec.kwonlyargs)=}" + ), ( + f"Input args {len(args)=} and kwargs {len(kwargs)=} must match arg_spec.args " + f"{len(args_spec.args)=} and arg_spec.kwonlyargs {len(args_spec.kwonlyargs)=}" + ) jit_arg_types, jit_arg_attrs, jit_exec_args = [], [], [] default_attr = ir.DictAttr.get({}) @@ -729,7 +735,7 @@ class BaseDSL: log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, spec_ty) # Implicitly convert into Numeric type if possible - if isinstance(spec_ty, t.NumericMeta): + if isinstance(spec_ty, t.NumericMeta) and not isinstance(arg, spec_ty): arg = t.cast(arg, spec_ty) # Type safety check diff --git a/python/CuTeDSL/base_dsl/jit_executor.py b/python/CuTeDSL/base_dsl/jit_executor.py index 2c997be3..3ed9282b 100644 --- a/python/CuTeDSL/base_dsl/jit_executor.py +++ b/python/CuTeDSL/base_dsl/jit_executor.py @@ -141,33 +141,60 @@ class JitExecutor: to get rid of mlir context. """ + # Process positional arguments with defaults + rectified_args = list(args) + if args_spec.defaults and len(args) < len(args_spec.args): + rectified_args.extend(args_spec.defaults[len(args) - len(args_spec.args) :]) + for k, v in kwargs.items(): + if k in args_spec.args: + idx = args_spec.args.index(k) + if idx < len(rectified_args): + rectified_args[idx] = v + else: + rectified_args.append(v) + + # Process keyword arguments + rectified_kwargs = {k: v for k, v in kwargs.items() if k not in args_spec.args} + if args_spec.kwonlydefaults and len(rectified_kwargs) < len( + args_spec.kwonlyargs + ): + rectified_kwargs.update(args_spec.kwonlydefaults) + # args/kwargs must match arg_specs - # No canonicalization of args/kwargs to avoid extra latency - if len(args) != len(args_spec.args) or len(kwargs) != len(args_spec.kwonlyargs): + if len(rectified_args) != len(args_spec.args) or len(rectified_kwargs) != len( + args_spec.kwonlyargs + ): raise DSLRuntimeError( "input args/kwargs length does not match runtime function signature!", context={ - "input args length": len(args), - "input kwargs length": len(kwargs), + "input args length": len(rectified_args), + "input kwargs length": len(rectified_kwargs), "function signature args length": len(args_spec.args), "function signature kwonlyargs length": len(args_spec.kwonlyargs), }, ) exe_args = [] - input_args = [*args, *kwargs.values()] - input_arg_names = [*args_spec.args, *args_spec.kwonlyargs] - for i, arg in enumerate(input_args): - arg_type = args_spec.annotations.get(input_arg_names[i], None) + input_args = rectified_args + list(rectified_kwargs.values()) + input_arg_names = args_spec.args + args_spec.kwonlyargs + for arg, arg_name in zip(input_args, input_arg_names): + # short-cut for args already converted + if hasattr(arg, "__c_pointers__"): + exe_args.extend(arg.__c_pointers__()) + continue + + arg_type = args_spec.annotations.get(arg_name, None) # Implicit cast to NumericMeta if isinstance(arg_type, t.NumericMeta): arg = t.cast(arg, arg_type) + else: + # If not any known type, try registered adapter to do the conversion + adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg)) + if adapter: + arg = adapter(arg) - # If not any known type, try registered adapter to do the conversion - adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg)) - adapted_arg = adapter(arg) if adapter else arg - exe_args.extend(get_c_pointers(adapted_arg)) + exe_args.extend(get_c_pointers(arg)) return exe_args diff --git a/python/CuTeDSL/base_dsl/runtime/cuda.py b/python/CuTeDSL/base_dsl/runtime/cuda.py index c4f88b58..278a5118 100644 --- a/python/CuTeDSL/base_dsl/runtime/cuda.py +++ b/python/CuTeDSL/base_dsl/runtime/cuda.py @@ -457,7 +457,7 @@ class StreamAdapter: def __init__(self, arg): self._arg = arg - self._c_pointer = ctypes.cast(self._arg.getPtr(), ctypes.c_void_p) + self._c_pointer = self._arg.getPtr() def __new_from_mlir_values__(self, values): assert len(values) == 1 diff --git a/python/CuTeDSL/base_dsl/typing.py b/python/CuTeDSL/base_dsl/typing.py index 7fc2b4d7..6554db61 100644 --- a/python/CuTeDSL/base_dsl/typing.py +++ b/python/CuTeDSL/base_dsl/typing.py @@ -629,7 +629,7 @@ def _binary_op_type_promote(a, b, promote_bool: bool = False): b_type = b.dtype # Early return for same types (except when they're bools that need promotion) - if a_type == b_type and not (promote_bool and a_type.width == 1): + if a_type == b_type and not (promote_bool and a_type is Boolean): return a, b, a_type # Handle floating point promotions @@ -1315,10 +1315,7 @@ class Integer(Numeric, metaclass=IntegerMeta, mlir_type=T.i32, is_abstract=True) def __invert__(self, *, loc=None, ip=None): res_type = type(self) - # Create a constant of -1 (all bits set to 1) of the same type as value - all_ones = arith.constant(res_type.mlir_type, -1) - # XOR with -1 gives us bitwise NOT - return res_type(arith.xori(self.ir_value(), all_ones, loc=loc, ip=ip)) + return res_type(self.ir_value(loc=loc, ip=ip).__invert__(loc=loc, ip=ip)) def __lshift__(self, other, *, loc=None, ip=None): return _binary_op(operator.lshift)(self, other, loc=loc, ip=ip) @@ -1457,18 +1454,14 @@ class Boolean(Integer, metaclass=IntegerMeta, width=1, signed=True, mlir_type=T. - Converted using Python's bool() function - Example: Boolean(1) -> True, Boolean(0) -> False - 2. Boolean: - - Direct value assignment - - Example: Boolean(Boolean(True)) -> True + 2. Numeric: + - Uses the Numeric.value to construct Boolean recursively - 3. Numeric: - - Uses the __dsl_bool__ method of the Numeric type - - 4. MLIR Value with IntegerType: + 3. MLIR Value with IntegerType: - If width is 1: Direct assignment - Otherwise: Compares with 0 using arith.cmpi - 5. MLIR Value with FloatType: + 4. MLIR Value with FloatType: - Compares with 0.0 using arith.cmpf - Uses unordered comparison to handle NaN values """ @@ -1479,19 +1472,35 @@ class Boolean(Integer, metaclass=IntegerMeta, width=1, signed=True, mlir_type=T. value = None if isinstance(a, (bool, int, float)): value = bool(a) - elif isinstance(a, Boolean): - value = a.value elif isinstance(a, Numeric): - value = a.__dsl_bool__(loc=loc, ip=ip) + Boolean.__init__(self, a.value, loc=loc, ip=ip) + return elif isinstance(a, ArithValue): if a.type == T.bool(): value = a else: - value = a != arith_helper.const(0, a.type) - + value = a != arith_helper.const(0, a.type, loc=loc, ip=ip) if value is None: raise DSLRuntimeError(f"Cannot convert {a} to Boolean") super().__init__(value, loc=loc, ip=ip) + self._value_int8 = None + + def ir_value_int8(self, *, loc=None, ip=None): + """ + Returns int8 ir value of Boolean. + When we need to store Boolean tensor element, use ir_value_int8(). + + :param loc: Source location information, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point for MLIR operations, defaults to None + :type ip: Optional[InsertionPoint], optional + :return: The int8 value of this Boolean + :rtype: ir.Value + """ + if self._value_int8 is not None: + return self._value_int8 + self._value_int8 = Int8(self.value, loc=loc, ip=ip).ir_value() + return self._value_int8 def __neg__(self, *, loc=None, ip=None): """Negation operator is not supported for boolean type. diff --git a/python/CuTeDSL/cutlass/cute/core.py b/python/CuTeDSL/cutlass/cute/core.py index 6af262cd..3f67a411 100644 --- a/python/CuTeDSL/cutlass/cute/core.py +++ b/python/CuTeDSL/cutlass/cute/core.py @@ -37,6 +37,7 @@ from cutlass.cutlass_dsl import ( ) from cutlass._mlir import ir +from cutlass._mlir.dialects._ods_common import get_op_result_or_op_results from cutlass._mlir.dialects import cute as _cute_ir from cutlass._mlir.dialects.cute import ( ScaledBasis as _ScaledBasis, @@ -962,6 +963,9 @@ class _Pointer(Pointer): # Cut off the MLIR type's string for making pretty_str more concise return self.type.__str__()[6:] + def __get_mlir_types__(self): + return [self.value.type] + def __extract_mlir_values__(self): return [self.value] @@ -979,7 +983,7 @@ class _Pointer(Pointer): @property @lru_cache_ir() - def value_type(self) -> Type[Numeric]: + def dtype(self) -> Type[Numeric]: return Numeric.from_mlir_type(self.value.type.value_type) @property @@ -993,7 +997,7 @@ class _Pointer(Pointer): @property @lru_cache_ir() def memspace(self) -> AddressSpace: - return self.type.address_space + return AddressSpace(self.type.address_space) # Make it behave as if it inherited from ir.Value @property @@ -1015,7 +1019,7 @@ class _Pointer(Pointer): :return: The LLVM pointer representation :rtype: ir.Value """ - llvm_ptr_ty = llvm.PointerType.get(self.type.address_space) + llvm_ptr_ty = llvm.PointerType.get(self.memspace.value) return builtin.unrealized_conversion_cast( [llvm_ptr_ty], [self.value], loc=loc, ip=ip ) @@ -1034,10 +1038,7 @@ class _Pointer(Pointer): @dsl_user_op def toint(self, *, loc=None, ip=None): - if self.type.address_space in ( - _cute_ir.AddressSpace.gmem, - _cute_ir.AddressSpace.generic, - ): + if self.memspace in (AddressSpace.gmem, AddressSpace.generic): res_type = Int64 else: res_type = Int32 @@ -1067,25 +1068,26 @@ class _Pointer(Pointer): raise ValueError("Alignment must be a power of 2") assert isinstance(self.type, _cute_ir.PtrType) - if self.type.address_space is AddressSpace.tmem: + if self.memspace is AddressSpace.tmem: raise ValueError("aligning a TMEM pointer is not supported") if min_align <= self.alignment: return self - else: - # Convert pointer to integer - address_int = self.toint(loc=loc, ip=ip) - # Align the address - aligned_address = (address_int + min_align - 1) & ~(min_align - 1) - # Create and return the aligned pointer - return make_ptr( - Numeric.from_mlir_type(self.type.value_type), - aligned_address, - self.type.address_space, - assumed_align=min_align, - loc=loc, - ip=ip, - ) + + dtype = Numeric.from_mlir_type(self.type.value_type) + # Convert pointer to integer + address_int = self.toint(loc=loc, ip=ip) + # Align the address + aligned_address = (address_int + min_align - 1) & ~(min_align - 1) + + return make_ptr( + dtype, + aligned_address, + self.memspace, + assumed_align=min_align, + loc=loc, + ip=ip, + ) @ir.register_value_caster(_cute_ir.MemRefType.get_static_typeid(), replace=True) @@ -1138,8 +1140,34 @@ class _Tensor(Tensor): self._dtype = dtype if isinstance(value, ir.Value): self.value = value + elif isinstance(value, _Tensor): + self.value = value.value else: - raise TypeError(f"Expected ir.Value, got {type(value)}") + raise TypeError(f"Expected ir.Value or core._Tensor, got {type(value)}") + + # Set iterator + iter_val = _cute_ir.get_iter(self.value) + if isinstance(iter_val, Pointer): + self._iterator = iter_val + elif isinstance(iter_val.type, _cute_ir.IntTupleType): + self._iterator = _unpack_x_tuple(iter_val) + elif isinstance(iter_val, ir.Value): + # Example: SMEM descriptor iterator, not well supported today + self._iterator = iter_val + else: + raise TypeError(f"unsupported iterator type, got {type(iter_val)}") + + # Set dtype + if self._dtype is None: + if is_int_tuple(self.iterator): + self._dtype = IntTuple + elif isinstance(self.iterator, Pointer): + self._dtype = self.iterator.value_type + elif isinstance(self.type, _cute_nvgpu_ir.SmemDescViewType): + # SmemDescViewType do not need dtype + self._dtype = None + else: + raise TypeError(f"unsupported iterator type, got {type(self.iterator)}") def __str__(self): return f"tensor<{pretty_str(self.iterator)} o {pretty_str(self.layout)}>" @@ -1157,7 +1185,7 @@ class _Tensor(Tensor): ), f"Expected _Tensor or ir.Value, but got {type(values[0])}" return _Tensor( values[0] if isinstance(values[0], ir.Value) else values[0].value, - self._dtype, + dtype=self.element_type, ) # Cheat to let `Type(_Tensor())` to return cute.Tensor @@ -1252,9 +1280,6 @@ class _Tensor(Tensor): return self.element_type(data_val) def _cvt_to_dest(self, data: Union["TensorSSA", Numeric], *, loc=None, ip=None): - if data.dtype is self.element_type: - return data.ir_value(loc=loc, ip=ip) - orig_dtype = data.dtype # Implicit upcast to wider type if ( @@ -1269,11 +1294,11 @@ class _Tensor(Tensor): f"to Tensor with element type {self.element_type}" ) - val = data.ir_value(loc=loc, ip=ip) - if isinstance(data.dtype, (Int8, Boolean)) and (self.element_type is Boolean): - zero = Int8(0).ir_value(loc=loc, ip=ip) - val = arith.cmpi(arith.CmpIPredicate.ne, val, zero, loc=loc, ip=ip) - + if data.dtype is Boolean and self.element_type is Boolean: + # Boolean Numeric and Boolean TensorSSA both hold i1 value, but we need int8 value store to memory + val = data.ir_value_int8() + else: + val = data.ir_value() return val @dsl_user_op @@ -1340,7 +1365,7 @@ class _Tensor(Tensor): # Implicit upcast to wider type val = self._cvt_to_dest(data, loc=loc, ip=ip) - if val.type != self.element_type.mlir_type: + if val.type != self.type.value_type: raise ValueError( f"type mismatch, store {val.type} to {self.element_type}" ) @@ -1365,16 +1390,7 @@ class _Tensor(Tensor): @property def iterator(self) -> Union[Pointer, IntTuple]: - res = _cute_ir.get_iter(self.value) - if isinstance(res, Pointer): - return res - elif isinstance(res.type, _cute_ir.IntTupleType): - return _unpack_x_tuple(res) - elif isinstance(res, ir.Value): - # Example: SMEM descriptor iterator, not well supported today - return res - else: - raise TypeError(f"unsupported iterator type, got {type(res)}") + return self._iterator @property def layout(self) -> Layout: @@ -1405,12 +1421,7 @@ class _Tensor(Tensor): @property @lru_cache_ir() def element_type(self) -> Union[Type[Numeric], Type[IntTuple]]: - if is_integer(self.iterator) or isinstance(self.iterator, tuple): - return IntTuple - elif isinstance(self.iterator, Pointer): - return self.iterator.value_type - else: - raise TypeError(f"unsupported iterator type, got {type(self.iterator)}") + return self._dtype @property @lru_cache_ir() @@ -1443,7 +1454,14 @@ class _Tensor(Tensor): self._check_can_load_store() res_vect = _cute_ir.memref_load_vec(self.value, row_major=True, loc=loc, ip=ip) - + if self.element_type is Boolean: + assert ( + res_vect.type.element_type == T.i8() + ), f"Boolean tensor must be stored as i8 in memory, but got {res_vect.type.element_type}" + zeros = full_like(self, 0, Int8, loc=loc, ip=ip) + res_vect = arith.cmpi( + arith.CmpIPredicate.ne, res_vect, zeros, loc=loc, ip=ip + ) return TensorSSA(res_vect, self.shape, self.element_type) @dsl_user_op @@ -1532,9 +1550,7 @@ class _Tensor(Tensor): self[None] = full(self.shape, fill_value=value, dtype=dst_type, loc=loc, ip=ip) def _check_can_load_store(self): - if not isinstance( - self.type, _cute_ir.MemRefType - ) or not self.type.address_space in ( + if not isinstance(self.type, _cute_ir.MemRefType) or not self.memspace in ( AddressSpace.rmem, AddressSpace.smem, AddressSpace.gmem, @@ -1734,10 +1750,6 @@ def printf(*args, loc=None, ip=None) -> None: arg0 = arg.value if isinstance(arg, Numeric) else arg if isinstance(arg0, ir.Value): - if isinstance(arg0.type, ir.FloatType) and (arg0.type != T.f32()): - raise TypeError( - f"cute.printf only supports 32-bit floating-point type, but got {arg0.type}" - ) return arg0 elif isinstance(arg0, bool): return const(arg0, Boolean) @@ -2212,11 +2224,13 @@ def group_modes(input, begin: int, end: int = -1, *, loc=None, ip=None): shape = make_shape(2, 3, 4, 5) grouped_shape = group_modes(shape, 0, 2) # Shape ((2, 3), 4, 5) """ - if depth(input) == 0: + if depth(input) == 0 and is_integer(input): return (input,) if isinstance(input, tuple): return (*input[:begin], (input[begin:end]), *input[end:]) - return _cute_ir.group_modes(input.value, begin, end, loc=loc, ip=ip) + return _cute_ir.group_modes( + input.value if isinstance(input, Tensor) else input, begin, end, loc=loc, ip=ip + ) @overload @@ -2315,10 +2329,13 @@ def slice_(src, coord: Coord, *, loc=None, ip=None): else: return () + res_type = None if isinstance(src, Tensor): + res_type = src.element_type src = src.value coord_val = _pack_coord(coord, loc=loc, ip=ip) - return _cute_ir.slice(input=src, coord=coord_val, loc=loc, ip=ip) + res = _cute_ir.slice(input=src, coord=coord_val, loc=loc, ip=ip) + return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res @overload @@ -2751,7 +2768,8 @@ def filter_zeros(input, *, target_profile=None, loc=None, ip=None): """Filter out zeros from a layout or tensor. This function removes zero-stride dimensions from a layout or tensor. - See Section 3.3 in the CuTe Whitepaper for more details on layout operations. + Refer to https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/02_layout_algebra.md + for more layout algebra operations. :param input: The input layout or tensor to filter :type input: Layout or Tensor @@ -2913,7 +2931,8 @@ def size( Computes the size (number of elements) in the domain of a layout or tensor. For layouts, this corresponds to the shape of the coordinate space. - See Section 3.2 in the CuTe Whitepaper for more details on layout domains. + See https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/01_layout.md + for more details on layout domains. :param a: The input object whose size to compute :type a: IntTuple, Shape, Layout, ComposedLayout or Tensor @@ -3177,7 +3196,7 @@ def make_composed_layout( ) -> ComposedLayout: """Create a composed layout by composing an inner transformation with an outer layout. - As described in the CuTe whitepaper, a composed layout applies a sequence of transformations + A composed layout applies a sequence of transformations to coordinates. The composition is defined as (inner ∘ offset ∘ outer), where the operations are applied from right to left. @@ -3416,12 +3435,7 @@ def recast_ptr( value_type = ptr.type.value_type if dtype is None else dtype swizzle = swizzle_.type.attribute if swizzle_ is not None else None - res_ty = _cute_ir.PtrType.get( - value_type, - AddressSpace(ptr.type.address_space), - ptr.alignment, - swizzle, - ) + res_ty = _cute_ir.PtrType.get(value_type, ptr.memspace, ptr.alignment, swizzle) return _cute_ir.recast_iter(res_ty, ptr.value, loc=loc, ip=ip) @@ -3438,8 +3452,15 @@ def make_ptr( if dtype is None or not isinstance(dtype, NumericMeta): raise TypeError(f"expects dtype to be a type of Numeric, but got {dtype}") + if not isinstance(mem_space, AddressSpace): + raise TypeError(f"expects mem_space to be an AddressSpace, but got {mem_space}") + + if isinstance(value, ir.Value) and llvm.PointerType.isinstance(value.type): + value = llvm.ptrtoint(T.i64(), value) + if not is_integer(value): raise TypeError(f"expects integer value, but got {type(value)}") + value = Int32(value) if mem_space == AddressSpace.tmem else Int64(value) bytes_per_elt = max(1, dtype.width // 8) if assumed_align is None: @@ -3450,13 +3471,11 @@ def make_ptr( f"{bytes_per_elt=} is not a multiple of {assumed_align=} and vice versa." ) - value = Int32(value) if mem_space == AddressSpace.tmem else Int64(value) aligned_ty = _cute_ir.ConstrainedIntType.get(assumed_align, type(value).width) aligned_intptr = _cute_ir.assume(aligned_ty, value.ir_value(), loc=loc, ip=ip) - ptr_ty = _cute_ir.PtrType.get( - T.i8() if dtype is None else dtype.mlir_type, mem_space, assumed_align - ) + data_ty = T.i8() if dtype is None else dtype.mlir_type + ptr_ty = _cute_ir.PtrType.get(data_ty, mem_space, assumed_align) return _cute_ir.inttoptr(ptr_ty, aligned_intptr, loc=loc, ip=ip) @@ -3582,7 +3601,7 @@ def make_fragment( ) -> Tensor: if not issubclass(dtype, Numeric): raise TypeError(f"value_type must be a type of Numeric, but got {type(dtype)}") - elem_ty = dtype.mlir_type + elem_ty = dtype.mlir_type if dtype is not Boolean else T.i8() # Alignment for register memory is useless(?), pick-up large enough number # to allow .128 (> 16B) load store @@ -3691,16 +3710,12 @@ def make_fragment_like(src, dtype=None, *, loc=None, ip=None): ) return make_fragment(new_layout, dtype, loc=loc, ip=ip) else: - if dtype is None: - ty = src.element_type.mlir_type - else: - ty = dtype.mlir_type + dtype = src.element_type if dtype is None else dtype + ty = dtype.mlir_type if dtype is not Boolean else T.i8() new_tensor = _cute_ir.make_fragment_like( src.value, elem_type=ty, loc=loc, ip=ip ) - return _Tensor( - new_tensor.value, dtype if dtype is not None else src.element_type - ) + return _Tensor(new_tensor.value, dtype) else: raise TypeError( f"src must be a Layout or ComposedLayout or tensor, got {type(src)}" @@ -3958,11 +3973,14 @@ def logical_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor @dsl_user_op def logical_divide(target, tiler: Tiler, *, loc=None, ip=None): + res_type = None if isinstance(target, _Tensor): + res_type = target.element_type target = target.value if isinstance(tiler, tuple): tiler = _pack_tile(tiler, loc=loc, ip=ip) - return _cute_ir.logical_divide(input=target, tiler=tiler, loc=loc, ip=ip) + res = _cute_ir.logical_divide(input=target, tiler=tiler, loc=loc, ip=ip) + return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res @overload @@ -3973,11 +3991,14 @@ def zipped_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: @dsl_user_op def zipped_divide(target, tiler: Tiler, *, loc=None, ip=None): + res_type = None if isinstance(target, _Tensor): + res_type = target.element_type target = target.value if isinstance(tiler, tuple): tiler = _pack_tile(tiler, loc=loc, ip=ip) - return _cute_ir.zipped_divide(input=target, tiler=tiler, loc=loc, ip=ip) + res = _cute_ir.zipped_divide(input=target, tiler=tiler, loc=loc, ip=ip) + return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res @overload @@ -3988,11 +4009,14 @@ def tiled_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: @dsl_user_op def tiled_divide(target, tiler: Tiler, *, loc=None, ip=None): + res_type = None if isinstance(target, _Tensor): + res_type = target.element_type target = target.value if isinstance(tiler, tuple): tiler = _pack_tile(tiler, loc=loc, ip=ip) - return _cute_ir.tiled_divide(input=target, tiler=tiler, loc=loc, ip=ip) + res = _cute_ir.tiled_divide(input=target, tiler=tiler, loc=loc, ip=ip) + return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res @overload @@ -4003,11 +4027,14 @@ def flat_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: . @dsl_user_op def flat_divide(target, tiler: Tiler, *, loc=None, ip=None): + res_type = None if isinstance(target, _Tensor): + res_type = target.element_type target = target.value if isinstance(tiler, tuple): tiler = _pack_tile(tiler, loc=loc, ip=ip) - return _cute_ir.flat_divide(input=target, tiler=tiler, loc=loc, ip=ip) + res = _cute_ir.flat_divide(input=target, tiler=tiler, loc=loc, ip=ip) + return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res # @@ -4075,14 +4102,22 @@ def tile_to_shape( def local_partition( target: Tensor, tiler: Union[Layout, Shape], - index, + index: Union[int, Numeric], proj: XTuple = 1, *, loc=None, ip=None, ) -> Tensor: + if isinstance(index, cutlass_arith.ArithValue): + index_val = index + else: + index_val = index.ir_value() + if index_val.type.width > 32: + raise NotImplementedError( + f"Index value should be 32-bit or smaller integer type, but got {index_val.type}" + ) return _cute_ir.local_partition( - input=target.value, tiler=dice(tiler, proj), index=index, loc=loc, ip=ip + input=target.value, tiler=dice(tiler, proj), index=index_val, loc=loc, ip=ip ) @@ -4332,6 +4367,8 @@ class MmaAtom(Atom): def make_fragment_A(self, input, *, loc=None, ip=None): # input could be memref/shape/layout for tmem based fragment if isinstance(input, _Tensor): + if self.op is not None: + self.op._verify_fragment_A(input, loc=loc, ip=ip) input = input.value if isinstance(input, tuple): input = _pack_shape(input, loc=loc, ip=ip) @@ -4343,9 +4380,12 @@ class MmaAtom(Atom): ip=ip, ) + @dsl_user_op def make_fragment_B(self, input, *, loc=None, ip=None): if isinstance(input, _Tensor): + if self.op is not None: + self.op._verify_fragment_B(input, loc=loc, ip=ip) input = input.value return _cute_ir.mma_make_fragment( _cute_ir.MmaOperand.B, @@ -5193,7 +5233,7 @@ def copy( src: Tensor, dst: Tensor, *, - pred: Tensor = None, + pred: Optional[Tensor] = None, loc=None, ip=None, **kwargs, @@ -5334,7 +5374,7 @@ class TensorSSA(cutlass_arith.ArithValue): other = as_numeric(other) # Promote types - lhs, rhs, res_type = _binary_op_type_promote(self, other, True) + lhs, rhs, res_type = _binary_op_type_promote(self, other) # Promote scalar to vector if not isinstance(rhs, TensorSSA): @@ -5827,6 +5867,28 @@ class TensorSSA(cutlass_arith.ArithValue): def ir_value(self, *, loc=None, ip=None): return self + def ir_value_int8(self, *, loc=None, ip=None): + """ + Returns int8 ir value of Boolean tensor. + When we need to store Boolean tensor ssa, use ir_value_int8(). + + :param loc: Source location information, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point for MLIR operations, defaults to None + :type ip: Optional[InsertionPoint], optional + :return: The int8 value of this Boolean + :rtype: ir.Value + """ + assert ( + self.element_type is Boolean + ), f"Only boolean type needs to be converted to int8, got {self.element_type}" + + if not hasattr(self, "_value_int8"): + self._value_int8 = arith.extsi( + T.vector(self.type.shape[0], T.i8()), self, loc=loc, ip=ip + ) + return self._value_int8 + def reduce(self, op, init_val, reduction_profile: Coord, *, loc=None, ip=None): """ Perform reduce on selected modes with given predefined reduction op. diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/common.py b/python/CuTeDSL/cutlass/cute/nvgpu/common.py index c93becad..87a01be9 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/common.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/common.py @@ -84,6 +84,11 @@ class MmaUniversalOp(core.MmaOp): ) return MmaUniversalTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip)) + def _verify_fragment_A(self, input, *, loc=None, ip=None): + pass + + def _verify_fragment_B(self, input, *, loc=None, ip=None): + pass class MmaUniversalTrait(core.Trait): pass diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py index 096a4e12..b5d681cf 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py @@ -20,7 +20,7 @@ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir import ir from ..common import OpError -from ...core import MmaOp, Trait, _pack_shape, rank, depth +from ...core import MmaOp, Trait, _pack_shape, rank, depth, _Tensor from ...typing import ( Shape, Float8E5M2, @@ -34,6 +34,7 @@ from ...typing import ( Uint8, Int32, Numeric, + AddressSpace, ) @@ -212,6 +213,30 @@ class MmaOp(MmaOp): + f"\n Instruction shape MNK = {self.shape_mnk}" ) + def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand A, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand B, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + class MmaTrait(Trait): admissible_fields = [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B] diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py index d7fe3b3b..49df213b 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py @@ -16,8 +16,8 @@ import cutlass._mlir.dialects.cute as _cute_ir import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from ..common import OpError -from ...core import MmaOp, Trait, _pack_shape -from ...typing import Shape, Float16, BFloat16, Float32, Numeric +from ...core import MmaOp, Trait, _pack_shape, _Tensor +from ...typing import Shape, Float16, BFloat16, Float32, Numeric, AddressSpace @dataclass(frozen=True) @@ -73,6 +73,11 @@ class MmaF16BF16Op(MmaOp): + f"\n Instruction shape MNK = {self.shape_mnk}" ) + def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None): + pass + + def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): + pass class MmaF16BF16Trait(Trait): pass diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py index b3749574..ca9177f3 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py @@ -20,7 +20,7 @@ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir import ir from ..common import OpError -from ...core import MmaOp, Trait, _pack_shape, rank, depth +from ...core import MmaOp, Trait, _pack_shape, rank, depth, _Tensor from ...typing import ( Shape, Float16, @@ -30,6 +30,7 @@ from ...typing import ( Float8E5M2, Float8E4M3FN, Numeric, + AddressSpace, ) @@ -167,6 +168,30 @@ class MmaOp(MmaOp): + f"\n Instruction shape MNK = {self.shape_mnk}" ) + def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand A, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand B, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + class MmaTrait(Trait): admissible_fields = [Field.ACCUMULATE] diff --git a/python/CuTeDSL/cutlass/cute/runtime.py b/python/CuTeDSL/cutlass/cute/runtime.py index 47e67b88..ec0ccc5f 100644 --- a/python/CuTeDSL/cutlass/cute/runtime.py +++ b/python/CuTeDSL/cutlass/cute/runtime.py @@ -124,7 +124,7 @@ class _Pointer(Pointer): ) @property - def element_type(self) -> Type[Numeric]: + def dtype(self) -> Type[Numeric]: return self._dtype @property diff --git a/python/CuTeDSL/cutlass/cute/testing.py b/python/CuTeDSL/cutlass/cute/testing.py index 90fb1fb2..2d81a9c7 100644 --- a/python/CuTeDSL/cutlass/cute/testing.py +++ b/python/CuTeDSL/cutlass/cute/testing.py @@ -35,20 +35,7 @@ from inspect import isclass def assert_(cond, msg=None): - if isinstance(cond, ir.Value): - if ir.VectorType.isinstance(cond.type): - assert ( - cond.type.element_type == T.bool() - ), f"only expects vector type with boolean elements, but got {cond.type}" - cond_val = vector.multi_reduction( - vector.CombiningKind.AND, cond, const(True), range(cond.type.rank) - ) - else: - cond_val = cond - else: - cond_val = const(cond, t.Boolean) - - cf.assert_(cond_val, msg if msg else "") + cf.assert_(t.Boolean(cond).ir_value(), msg if msg else "") def _maybe_recast_tensor_from_f4(src: core.Tensor, tv_layout: core.Layout): diff --git a/python/CuTeDSL/cutlass/cute/typing.py b/python/CuTeDSL/cutlass/cute/typing.py index 48ac76c4..a13ebeaf 100644 --- a/python/CuTeDSL/cutlass/cute/typing.py +++ b/python/CuTeDSL/cutlass/cute/typing.py @@ -56,14 +56,23 @@ XTuple = Union[IntTuple, Shape, Stride, Coord, Tile] Tiler = Union[Shape, Layout, Tile] -class Pointer: +class Pointer(ABC): """ Abstract base class for CuTe jit function and runtime _Pointer """ - def __extract_mlir_values__(self): - # Doesn't matter just return a value - return [self] + @property + def value_type(self) -> Type[Numeric]: + return self.dtype + + @property + def dtype(self) -> Type[Numeric]: ... + + def __get_mlir_types__(self) -> List[ir.Type]: ... + + def __extract_mlir_values__(self) -> List[ir.Value]: ... + + def __new_from_mlir_values__(self, values) -> "Pointer": ... class Tensor(ABC): @@ -144,10 +153,13 @@ class Tensor(ABC): def store(self, data: "TensorSSA", *, loc=None, ip=None): ... - def mark_layout_dynamic(self, leading_dim: int|None = None) -> "Tensor": ... + def mark_layout_dynamic(self, leading_dim: int | None = None) -> "Tensor": ... def mark_compact_shape_dynamic( - self, mode: int, stride_order: tuple[int, ...]|None = None, divisibility: int = 1 + self, + mode: int, + stride_order: tuple[int, ...] | None = None, + divisibility: int = 1, ) -> "Tensor": ... @abstractmethod diff --git a/python/CuTeDSL/cutlass/utils/pipeline.py b/python/CuTeDSL/cutlass/utils/pipeline.py index a339a3e3..eb104278 100644 --- a/python/CuTeDSL/cutlass/utils/pipeline.py +++ b/python/CuTeDSL/cutlass/utils/pipeline.py @@ -30,6 +30,7 @@ class Agent(enum.Enum): """ Agent indicates what is participating in the pipeline synchronization. """ + # Arbitrary grouping of N threads Thread = enum.auto() # Same as AsyncThread, but includes all threads in the block @@ -42,6 +43,7 @@ class CooperativeGroup: """ CooperativeGroup contains size and alignment restrictions for an Agent. """ + def __init__(self, agent: Agent, size: int = 1, alignment: int = 1): if agent is Agent.Thread: assert size > 0 @@ -76,6 +78,7 @@ class _PipelineOp(enum.Enum): """ PipelineOp assigns an operation to an agent corresponding to a specific hardware feature. """ + # async-threads AsyncThread = enum.auto() # Blackwell (SM100a) MMA instruction @@ -140,12 +143,8 @@ class MbarrierArray(SyncObjectArray): "Error: Mbarrier tx count must be greater than 0 for TMA ops." ) - # Using a tensor to store mbarrier i64 ptrs - self.mbarrier_array = cute.make_fragment(cute.make_layout(num_stages), Int64) - for i in range(num_stages): - self.mbarrier_array[i] = _cute_ir.ptrtoint( - T.i64(), (self.barrier_storage + i).value - ) + # Store mbarrier base pointer + self.mbarrier_base = self.barrier_storage # Mbarrier initialization in constructor self.mbarrier_init() @@ -155,10 +154,11 @@ class MbarrierArray(SyncObjectArray): """ Initializes an array of mbarriers using warp 0. """ + def then_body(): for index in range(self.num_stages): cute.arch.mbarrier_init_arrive_cnt( - _mbarrier_i64_to_ptr(self.mbarrier_array[index]), self.arrive_count + self.get_barrier(index), self.arrive_count ) warp_idx = cute.arch.warp_idx() @@ -166,7 +166,12 @@ class MbarrierArray(SyncObjectArray): if_generate(warp_idx == 0, then_body) - def arrive(self, index: int, dst: int): + def arrive( + self, + index: int, + dst: int, + cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None, + ): """ Select the arrive corresponding to this MbarrierArray's PipelineOp :param index: Index of the mbarrier in the array to arrive on @@ -175,55 +180,53 @@ class MbarrierArray(SyncObjectArray): - For TCGen05Mma, dst serves as a multicast mask (e.g., 0b1011 allows arrive signal to be multicast to CTAs in the cluster with rank = 0, 1, and 3). - For AsyncThread, dst serves as a destination cta rank (e.g., 3 means threads will arrive on the mbarrier with rank = 3 in the cluster). :type dst: int | None + :param cta_group: CTA group for TCGen05Mma, defaults to None for other op types + :type cta_group: cute.nvgpu.tcgen05.CtaGroup, optional """ if self.op_type is _PipelineOp.AsyncThread: self.arrive_mbarrier(index, dst) elif self.op_type is _PipelineOp.TCGen05Mma: - self.arrive_tcgen05mma(index, dst) + assert ( + cta_group is not None + ), "Error: CTA group must be provided for TCGen05Mma." + self.arrive_tcgen05mma(index, dst, cta_group) elif self.op_type in [_PipelineOp.TmaLoad]: self.arrive_and_expect_tx(index, self.tx_count) else: - print(_get_pipeline_op(self.op_type)) - assert False, "Error: MbarrierArray is not supported for this PipelineOp." + assert False, f"Error: MbarrierArray is not supported for PipelineOp: {_get_pipeline_op(self.op_type)}." def arrive_mbarrier(self, index: int, dst_rank: int): if dst_rank is None: - cute.arch.mbarrier_arrive(_mbarrier_i64_to_ptr(self.mbarrier_array[index])) + cute.arch.mbarrier_arrive(self.get_barrier(index)) else: - cute.arch.mbarrier_arrive( - _mbarrier_i64_to_ptr(self.mbarrier_array[index]), dst_rank - ) + cute.arch.mbarrier_arrive(self.get_barrier(index), dst_rank) - def arrive_tcgen05mma(self, index: int, mask: int): + def arrive_tcgen05mma( + self, index: int, mask: int, cta_group: cute.nvgpu.tcgen05.CtaGroup + ): if mask is None: with cute.arch.elect_one(): - cute.nvgpu.tcgen05.commit( - _mbarrier_i64_to_ptr(self.mbarrier_array[index]) - ) + cute.nvgpu.tcgen05.commit(self.get_barrier(index)) else: with cute.arch.elect_one(): cute.nvgpu.tcgen05.commit( - _mbarrier_i64_to_ptr(self.mbarrier_array[index]), + self.get_barrier(index), mask, - cute.nvgpu.tcgen05.CtaGroup.TWO, + cta_group, ) def arrive_and_expect_tx(self, index: int, tx_count: int): with cute.arch.elect_one(): - cute.arch.mbarrier_init_tx_bytes( - _mbarrier_i64_to_ptr(self.mbarrier_array[index]), tx_count - ) + cute.arch.mbarrier_init_tx_bytes(self.get_barrier(index), tx_count) def try_wait(self, index: int, phase: int): - return cute.arch.mbarrier_try_wait( - _mbarrier_i64_to_ptr(self.mbarrier_array[index]), phase - ) + return cute.arch.mbarrier_try_wait(self.get_barrier(index), phase) def wait(self, index: int, phase: int): - cute.arch.mbarrier_wait(_mbarrier_i64_to_ptr(self.mbarrier_array[index]), phase) + cute.arch.mbarrier_wait(self.get_barrier(index), phase) def get_barrier(self, index: int) -> cute.Pointer: - return _mbarrier_i64_to_ptr(self.mbarrier_array[index]) + return self.mbarrier_base + index class TmaStoreFence(SyncObjectArray): @@ -390,6 +393,7 @@ class PipelineAsync: PipelineAsync is a generic pipeline class where both the producer and consumer are AsyncThreads. It also serves as a base class for specialized pipeline classes. """ + sync_object_array_full: SyncObjectArray sync_object_array_empty: SyncObjectArray num_stages: Int32 @@ -522,6 +526,7 @@ class PipelineTmaAsync(PipelineAsync): """ PipelineTmaAsync is used for TMA producers and AsyncThread consumers (e.g. Hopper mainloops). """ + is_signalling_thread: bool @staticmethod @@ -628,7 +633,6 @@ class PipelineTmaAsync(PipelineAsync): ) self.sync_object_array_full.arrive(state.index, self.producer_mask) - def producer_commit(self, state: PipelineState): """ TMA producer commit is a NOP. The transaction barrier signals the commit upon completion of the TMA. @@ -646,12 +650,15 @@ class PipelineTmaAsync(PipelineAsync): ), ) + @dataclass(frozen=True) class PipelineTmaUmma(PipelineAsync): """ PipelineTmaUmma is used for TMA producers and UMMA consumers (e.g. Blackwell mainloops). """ + is_leader_cta: bool + cta_group: cute.nvgpu.tcgen05.CtaGroup @staticmethod def _compute_mcast_arrival_mask(cta_layout_vmnk: cute.Layout): @@ -748,6 +755,12 @@ class PipelineTmaUmma(PipelineAsync): producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk) is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk) + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + consumer_mask = producer_mask pipeline_init_wait(cta_layout_vmnk) @@ -759,6 +772,15 @@ class PipelineTmaUmma(PipelineAsync): producer_mask, consumer_mask, is_leader_cta, + cta_group, + ) + + def consumer_release(self, state: PipelineState): + """ + UMMA consumer release buffer empty, cta_group needs to be provided. + """ + self.sync_object_array_empty.arrive( + state.index, self.consumer_mask, self.cta_group ) def producer_acquire( @@ -789,6 +811,8 @@ class PipelineUmmaAsync(PipelineAsync): PipelineTmaUmma is used for UMMA producers and AsyncThread consumers (e.g. Blackwell accumulator pipelines). """ + cta_group: cute.nvgpu.tcgen05.CtaGroup + @staticmethod def _compute_tmem_sync_mask(cta_layout_vmnk: cute.Layout): """ @@ -858,6 +882,12 @@ class PipelineUmmaAsync(PipelineAsync): else: consumer_mask = PipelineUmmaAsync._compute_peer_cta_rank() + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + pipeline_init_wait(cta_layout_vmnk) return PipelineUmmaAsync( @@ -866,6 +896,15 @@ class PipelineUmmaAsync(PipelineAsync): num_stages, producer_mask, consumer_mask, + cta_group, + ) + + def producer_commit(self, state: PipelineState): + """ + UMMA producer commit buffer full, cta_group needs to be provided. + """ + self.sync_object_array_full.arrive( + state.index, self.producer_mask, self.cta_group ) def producer_tail(self, state: PipelineState): diff --git a/python/CuTeDSL/cutlass/utils/smem_allocator.py b/python/CuTeDSL/cutlass/utils/smem_allocator.py index 3e3a4020..a2d72698 100644 --- a/python/CuTeDSL/cutlass/utils/smem_allocator.py +++ b/python/CuTeDSL/cutlass/utils/smem_allocator.py @@ -185,7 +185,7 @@ class SmemAllocator: and isinstance(layout.inner, cute.Swizzle) ) and (swizzle is not None): raise TypeError( - f"iterator swizzle with swizzle layout is currently not supported" + f"Invalid tensor type: cannot be both iterator swizzle (PDSL) and swizzle layout(PISL) at the same time." ) if isinstance(layout, int): diff --git a/python/CuTeDSL/cutlass_dsl/cutlass.py b/python/CuTeDSL/cutlass_dsl/cutlass.py index 1e2f4d1c..3cbd6874 100644 --- a/python/CuTeDSL/cutlass_dsl/cutlass.py +++ b/python/CuTeDSL/cutlass_dsl/cutlass.py @@ -527,13 +527,13 @@ def pack_from_irvalue( """ Packs MLIR values into a list of mixed values. """ - log().info("===--- Values Pack (%d)", len(ir_values)) + log().debug("===--- Values Pack (%d)", len(ir_values)) for idx, packed in enumerate(ir_values): - log().info("[%d]: will-packed: %s", idx, ir_values) + log().debug("[%d]: will-packed: %s", idx, ir_values) for idx, unpacked in indices.items(): - log().info("[%d]: indices: %s", idx, unpacked) + log().debug("[%d]: indices: %s", idx, unpacked) for idx, c in enumerate(class_types): - log().info("[%d]: obj-types: %s", idx, type(c)) + log().debug("[%d]: obj-types: %s", idx, type(c)) mixed_values = [None] * len(indices) for idx, (start, length) in sorted(indices.items()): @@ -552,10 +552,10 @@ def pack_from_irvalue( except DSLRuntimeError as e: mixed_values[idx] = chunk[0] - log().info("------------------ ") + log().debug("------------------ ") for idx, packed in enumerate(mixed_values): - log().info("[%d]: packed: %s", idx, packed) - log().info("------------------ ") + log().debug("[%d]: packed: %s", idx, packed) + log().debug("------------------ ") return mixed_values @@ -571,9 +571,9 @@ def unpack_to_irvalue( class_types = [] current_offset = 0 - log().info("===--- Values UNPack (%d)", len(mixed_values)) + log().debug("===--- Values UNPack (%d)", len(mixed_values)) for idx, packed in enumerate(mixed_values): - log().info("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed) + log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed) for idx, item in enumerate(mixed_values): class_types.append(item) try: @@ -612,16 +612,16 @@ def unpack_to_irvalue( ), ) from e - log().info("------------------ ") + log().debug("------------------ ") for idx, unpacked in enumerate(unpacked_values): - log().info("[%d]: unpacked values: %s", idx, unpacked) + log().debug("[%d]: unpacked values: %s", idx, unpacked) for idx, unpacked in enumerate(ir_values): - log().info("[%d]: unpacked ir_values: %s", idx, unpacked) + log().debug("[%d]: unpacked ir_values: %s", idx, unpacked) for idx, unpacked in indices.items(): - log().info("[%d]: indices: %s", idx, unpacked) + log().debug("[%d]: indices: %s", idx, unpacked) for idx, unpacked in enumerate(class_types): - log().info("[%d]: initial-class-types: %s", idx, unpacked) - log().info("------------------ ") + log().debug("[%d]: initial-class-types: %s", idx, unpacked) + log().debug("------------------ ") return ir_values, unpacked_values, indices, class_types @@ -1302,7 +1302,6 @@ class WhileLoopContext: def __exit__(self, exc_type, exc_value, traceback): self.ipoint_op.__exit__(exc_type, exc_value, traceback) - return True @property def results(self): diff --git a/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py b/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py index ba7b9d76..56262a76 100644 --- a/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py +++ b/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py @@ -331,11 +331,7 @@ def _if_execute_dynamic( # Assume final result types match the dynamic yields result_types = [arg.type for arg in dyn_yield_ops] - pred_ = t.as_numeric(pred) - - if not isinstance(pred_, Boolean): - # Convert to Boolean through comparison - pred_ = pred_ == True + pred_ = Boolean(pred) try: if_op = scf.IfOp( diff --git a/python/cutlass_library/__init__.py b/python/cutlass_library/__init__.py index d164768c..3d95edb5 100644 --- a/python/cutlass_library/__init__.py +++ b/python/cutlass_library/__init__.py @@ -35,6 +35,7 @@ import sys from . import conv2d_operation from . import conv3d_operation +from . import emit_kernel_listing from . import gemm_operation if '-m' not in sys.argv: @@ -53,7 +54,7 @@ from . import trmm_operation from .library import * # Set up `source` to point to the path containing the CUTLASS source. -# Check first if the path cotains a `source` subdirectory -- this will +# Check first if the path contains a `source` subdirectory -- this will # be the case when the package has been installed via pip. Otherwise, # default to the root of CUTLASS. install_source_path = os.path.join(__path__[0], 'source') diff --git a/python/cutlass_library/emit_kernel_listing.py b/python/cutlass_library/emit_kernel_listing.py index 70ba077e..243f5adb 100755 --- a/python/cutlass_library/emit_kernel_listing.py +++ b/python/cutlass_library/emit_kernel_listing.py @@ -253,7 +253,8 @@ def _getInstType(input_precision, accumulate_precision, math_instruction): return inst # TODO: Computes FLOps/Bytes for GEMM - revisit for conv -def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0): +def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups=1): + assert not (batch_count > 1 and num_groups > 1) # TODO: adjust for sparsity gmem_bytes = ( @@ -269,16 +270,15 @@ def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0): gmem_bytes += (DataTypeSize[operation.C.element] * m // 8) * n flops += 2 * m * n - gmem_bytes *= batch_count - flops *= batch_count + multiplier = max(batch_count, num_groups) + gmem_bytes *= multiplier + flops *= multiplier return flops / gmem_bytes def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode ): profiler_reference_computing = "--verification-providers=device --providers=cutlass" - - # beta values for L0 and L1 # TODO: randomize beta values for wider coverage beta_values = [0.5] @@ -303,15 +303,9 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode ] sm100_mma_data_type_runtime_dtype = [ - 'gemm_f4_f4_f32_f32_f32', - 'gemm_f6_f6_f32_f32_f32', - 'gemm_f8_f8_f32_f32_f32', - ] - - sm100_mma_data_type_mergeable = [ - 'gemm_e4m3_e4m3_f32_f32_f32',# mask out one instance for verification - 'gemm_e2m1_e2m1_f32_f32_f32', - 'gemm_e3m2_e3m2_f32_f32_f32', + 'gemm.*f4_f4_f32_f32_f32', + 'gemm.*f6_f6_f32_f32_f32', + 'gemm.*f8_f8_f32_f32_f32', ] sm100_mma_cluster_size = [ @@ -327,9 +321,6 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode ] # regex list must be in kernel procedural name order - mergeable_sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" - mergeable_sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" - sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" @@ -340,25 +331,15 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode # Block Scale Gemm # - block_scaled_data_type_base = [ + block_scaled_data_type = [ # runtime datatypes 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2', - 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2', + 'gemm.*ue4m3xf4_ue4m3xf4_f32_f16_e5m2', 'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2', 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1', 'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2', ] - block_scaled_data_type_mergeable = [ - 'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2', - 'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2', - 'gemm.*ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2', - 'gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1', - 'gemm.*ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2', - ] - - block_scaled_data_type = block_scaled_data_type_base + block_scaled_data_type_mergeable - block_scaled_cluster_size = [ '4x4x1', '2x1x1', '0x0x1' # dynamic cluster @@ -366,27 +347,25 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode block_scaled_layouts = ['tnt'] # regex list must be in kernel procedural name order - mergeable_block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" - mergeable_block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" - block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" - if arch == "100a" or arch == "100f": + if arch in ["100a", "100f"]: kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ f"({sm100_mma_filter_regex_2sm})|" \ f"({sm100_mma_filter_regex_1sm_runtime})|" \ f"({sm100_mma_filter_regex_2sm_runtime})|" \ f"({block_scaled_filter_regex_1sm})|" \ f"({block_scaled_filter_regex_2sm})" - elif arch == "101a" or arch == "101f": + elif arch in ["101a", "101f", + ]: kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ f"({sm100_mma_filter_regex_2sm})|" \ f"({sm100_mma_filter_regex_1sm_runtime})|" \ f"({sm100_mma_filter_regex_2sm_runtime})|" \ f"({block_scaled_filter_regex_1sm})|" \ f"({block_scaled_filter_regex_2sm})" - elif arch == "120a" or arch == "120f": + elif arch in ["120a", "120f"]: # blockscaled sm120_mma kernels blockscaled_sm120_mma_kernel_cta_tiles = [ @@ -403,18 +382,8 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode else: error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm120a, sm120f" raise Exception(error_message) - - # Statically encoded kernels are still added to generated_kernels - # but are filtered out from the testing commands to reduce test duration. - # The mergeable_kernel_filter specifies the kernels that are already covered - # by the runtime datatype tests so that we safely mark them off - # without changing the test coverage. - mergeable_kernel_filter = f"({mergeable_sm100_mma_filter_regex_1sm})|" \ - f"({mergeable_sm100_mma_filter_regex_2sm})|" \ - f"({mergeable_block_scaled_filter_regex_1sm})|" \ - f"({mergeable_block_scaled_filter_regex_2sm})" - elif mode == "functional_L1": + elif mode == "functional_L1": sm100_mma_cluster_size = [ '0x0x1' # dynamic cluster ] @@ -486,10 +455,6 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode audit_file_params_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_params_SM{arch}_cutlass3x_gemm.csv") - if is_runtime_datatype_enabled: - mergeable_kernel_filter_re = re.compile(mergeable_kernel_filter) - - kernel_filter_re = re.compile(kernel_filter) testcase_counter = 0 kernels_emitted = 0 @@ -517,12 +482,6 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode if 'f16_f16_f16_void_f16' not in kernel_name : continue - # Filter out the statically encoded tests which are - # covered by runtime datatype tests to avoid repetition. - if is_runtime_datatype_enabled and len(mergeable_kernel_filter_re.findall(kernel_name)) != 0: - continue - - kernels_emitted += 1 kernel_name_set.add(kernel_name) hashed_kernel_name = hash_cutlass_string(kernel_name) @@ -685,9 +644,18 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode gemm_op = "gemm" profiler_reference_computing_override = profiler_reference_computing + grouped = is_grouped(manifest.operations_by_name[kernel_name].gemm_kind) + num_groups = 1 if "bstensorop" in kernel_name: profiler_reference_computing_override = "--mode=trace" + if grouped: + gemm_op = "grouped_gemm" + num_groups = 3 # small to limit test time in host block-scaled reference kernels + batch_count = 1 + elif "bstensorop" in kernel_name: gemm_op = "block_scaled_gemm" + elif is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind): + gemm_op = "blockwise_gemm" problem_size_category = ['smallK','largeK'][index_k] + '_' + ['beta==0','beta!=0'][bool(beta)] @@ -704,7 +672,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode 'n' : n, 'k' : k, 'beta' : beta, - 'flops_per_byte' : _computeFlopsPerByte(operation, m, n, k, batch_count, beta) + 'flops_per_byte' : _computeFlopsPerByte(operation, m, n, k, batch_count, beta, num_groups) }, "runtime_params": { 'ctas_per_mma_instruction' : ctas_per_mma_instruction, @@ -732,6 +700,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode f" --m={str(m)}" + f" --n={str(n)}" + f" --k={str(k)}" + + (f" --num_groups={str(num_groups)}" if grouped else "") + f" --cluster_m={str(cluster_shape_m)}" + f" --cluster_n={str(cluster_shape_n)}" + f" --cluster_k={str(cluster_shape_k)}" + @@ -739,7 +708,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode f" --cluster_n_fallback={str(cluster_n_fallback)}" + f" --cluster_k_fallback={str(cluster_k_fallback)}" + f" --beta={str(beta)}" + - f" --batch_count={str(batch_count)}" + + ("" if grouped else f" --batch_count={str(batch_count)}") + f" --swizzle_size={str(swizzle_size)}" + f" --verification-required={str(verification_required).lower()}" ] \ @@ -752,7 +721,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode testcase_metadata.append(json.dumps(metadata_dict)) testlist_csv_rows.append(testcase_metadata) testcase_counter += 1 - + alpha = 1.0 if dynamic_datatype: diff --git a/python/cutlass_library/gemm_operation.py b/python/cutlass_library/gemm_operation.py index f85e160f..0d037fe2 100644 --- a/python/cutlass_library/gemm_operation.py +++ b/python/cutlass_library/gemm_operation.py @@ -994,6 +994,12 @@ ${compile_guard_end} element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>' element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>' + alignment_c = get_tma_alignment(operation.C.element) \ + if is_tma_epilogue(operation.epilogue_schedule) and opcode_class_epi != OpcodeClass.Simt \ + else operation.C.alignment + alignment_d = get_tma_alignment(operation.D.element) \ + if is_tma_epilogue(operation.epilogue_schedule) and opcode_class_epi != OpcodeClass.Simt \ + else operation.D.alignment operation_name_str = operation.procedural_name() layout_a_str = LayoutTag[instance_layout_A] @@ -1103,8 +1109,8 @@ using {operation_name_str}_LayoutSFB = decltype({operation_name_str}_ScaleConfig 'stages': stage_count_string, 'align_a': str(operation.A.alignment), 'align_b': str(operation.B.alignment), - 'align_c': str(operation.C.alignment), - 'align_d': str(operation.C.alignment), + 'align_c': str(alignment_c), + 'align_d': str(alignment_d), 'transform_a': ComplexTransformTag[operation.A.complex_transform], 'transform_b': ComplexTransformTag[operation.B.complex_transform], 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index 58b605ad..0cdb2155 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -112,7 +112,6 @@ def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0): cuda_version.append(x) return cuda_version >= [major, minor, patch] - ################################################################################################### ################################################################################################### @@ -6769,8 +6768,9 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): }, ] + thor_sm = 101 min_cc = 100 - max_cc = 101 + max_cc = thor_sm math_instructions_1sm = [ # tf32 -> f32 MathInstruction( @@ -6793,8 +6793,8 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1], [4,4,1] , DynamicClusterShape ] - - if 101 in manifest.compute_capabilities : + + if thor_sm in manifest.compute_capabilities_baseline: cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1] , DynamicClusterShape ] @@ -6847,7 +6847,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1] , DynamicClusterShape ] @@ -6887,8 +6887,9 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]], ] + thor_sm = 101 min_cc = 100 - max_cc = 101 + max_cc = thor_sm grouped = is_grouped(gemm_kind) math_instructions_1sm = [ @@ -6950,7 +6951,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1] , DynamicClusterShape ] @@ -7108,7 +7109,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1] , DynamicClusterShape ] @@ -7152,7 +7153,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm else: epi_schedule = EpilogueScheduleType.ScheduleAuto - kernel_schedule = KernelScheduleType.TmaWarpSpecialized2SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100 + kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized2SmSm100, grouped) CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) @@ -7201,8 +7202,9 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], ] + thor_sm = 101 min_cc = 100 - max_cc = 101 + max_cc = thor_sm epi_type = DataType.f32 grouped = is_grouped(gemm_kind) @@ -7270,7 +7272,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1] , DynamicClusterShape ] @@ -7398,11 +7400,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ ( data_type["d_type"] == DataType.e5m2 ): continue - # don't support runtime data type for grouped yet - if grouped and (data_type["a_type"] == DataType.f8 or data_type["b_type"] == DataType.f8): - continue - kernel_schedule = KernelScheduleType.TmaWarpSpecialized1SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100 - epi_schedule = EpilogueScheduleType.TmaWarpSpecialized1Sm if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm + kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized1SmSm100, grouped) + epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) @@ -7484,7 +7483,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1] , DynamicClusterShape ] @@ -7607,9 +7606,6 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ ( data_type["d_type"] == DataType.e5m2 ): continue - # don't support runtime data type for grouped yet - if grouped and (data_type["a_type"] == DataType.f8 or data_type["b_type"] == DataType.f8): - continue if grouped: epi_schedule = EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm @@ -7617,7 +7613,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm else: epi_schedule = EpilogueScheduleType.ScheduleAuto - kernel_schedule = KernelScheduleType.TmaWarpSpecialized2SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100 + kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized2SmSm100, grouped) CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) @@ -7852,9 +7848,6 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, if (is_runtime_datatype_a != is_runtime_datatype_b): continue - # grouped GEMM does not support runtime data type yet - if grouped and (is_runtime_datatype_a or is_runtime_datatype_b): - continue kernel_schedule = to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100, grouped) epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, @@ -7896,8 +7889,9 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): TileSchedulerType.Default, TileSchedulerType.StreamK ] + thor_sm = 101 min_cc = 100 - max_cc = 101 + max_cc = thor_sm epi_type = DataType.f32 math_instructions_1sm = [] @@ -7949,7 +7943,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_1sm = [ [2,1,1], [1,1,1] @@ -8025,7 +8019,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_2sm = [ [2,1,1] , DynamicClusterShape @@ -8131,8 +8125,9 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud else: return [TileSchedulerType.Default, TileSchedulerType.StreamK] + thor_sm = 101 min_cc = 100 - max_cc = 101 + max_cc = thor_sm epi_type = DataType.f32 math_instructions_1sm = [] @@ -8184,7 +8179,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_1sm = [ [1,1,1], [2,1,1] @@ -8264,7 +8259,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_2sm = [ [2,1,1], [4,1,1] @@ -8372,6 +8367,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio # layouts for ABC and their alignments. layouts = [ [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.ColumnMajor, 0]], ] instruction_sizes_1sm = [ @@ -8400,8 +8396,9 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio else: return [TileSchedulerType.Default, TileSchedulerType.StreamK] + thor_sm = 101 min_cc = 100 - max_cc = 101 + max_cc = thor_sm epi_type = DataType.f32 math_instructions_1sm = [] @@ -8416,10 +8413,6 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio if (is_runtime_datatype_a != is_runtime_datatype_b): continue - # grouped GEMM does not support runtime data type yet - if grouped and (is_runtime_datatype_a or is_runtime_datatype_b): - continue - math_instructions_1sm.append( MathInstruction( instr_size, @@ -8447,10 +8440,6 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio if (is_runtime_datatype_a != is_runtime_datatype_b): continue - # grouped GEMM does not support runtime data type yet - if grouped and (is_runtime_datatype_a or is_runtime_datatype_b): - continue - math_instructions_2sm.append( MathInstruction( instr_size, @@ -8477,7 +8466,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_1sm = [ [1,1,1], [2,1,1] @@ -8575,8 +8564,11 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio for layout in layouts: for data_type in data_types: - if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1): + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + # E2M1 x E2M1, vector size 32, E8 # E2M1 x E2M1, vector size 16, UE4M3 isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 @@ -8604,7 +8596,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_2sm = [ [2,1,1], [4,1,1] @@ -8701,8 +8693,11 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio for layout in layouts: for data_type in data_types: - if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1): + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + # E2M1 x E2M1, vector size 32, E8 isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 @@ -8737,8 +8732,9 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], ] + thor_sm = 101 min_cc = 100 - max_cc = 101 + max_cc = thor_sm epi_type = DataType.f32 @@ -8763,7 +8759,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1] , DynamicClusterShape ] @@ -8867,7 +8863,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1] , DynamicClusterShape ] @@ -8952,8 +8948,9 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version): [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], ] + thor_sm = 101 min_cc = 100 - max_cc = 101 + max_cc = thor_sm tile_schedulers = [ TileSchedulerType.Default, ] @@ -9009,7 +9006,7 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version): for math_inst in math_instructions_1sm: tile_descriptions = [] for cluster_shape in sm100_cluster_shape_1sm: - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : if cluster_shape == [4,4,1] : continue multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape @@ -9040,7 +9037,7 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version): for math_inst in math_instructions_2sm: tile_descriptions = [] for cluster_shape in sm100_cluster_shape_2sm: - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : if cluster_shape == [4,4,1] : continue multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) @@ -9077,8 +9074,9 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version): [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], ] + thor_sm = 101 min_cc = 100 - max_cc = 101 + max_cc = thor_sm tile_schedulers = [ TileSchedulerType.Default, ] @@ -9134,7 +9132,7 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version): for math_inst in math_instructions_1sm: tile_descriptions = [] for cluster_shape in sm100_cluster_shape_1sm: - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : if cluster_shape == [4,4,1] : continue multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape @@ -9165,7 +9163,7 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version): for math_inst in math_instructions_2sm: tile_descriptions = [] for cluster_shape in sm100_cluster_shape_2sm: - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : if cluster_shape == [4,4,1] : continue multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) @@ -9202,8 +9200,9 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version): [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], ] + thor_sm = 101 min_cc = 100 - max_cc = 101 + max_cc = thor_sm tile_schedulers = [ TileSchedulerType.Default, @@ -9259,7 +9258,7 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version): for math_inst in math_instructions_1sm: tile_descriptions = [] for cluster_shape in sm100_cluster_shape_1sm: - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : if cluster_shape == [4,4,1] : continue multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape @@ -9290,7 +9289,7 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version): for math_inst in math_instructions_2sm: tile_descriptions = [] for cluster_shape in sm100_cluster_shape_2sm: - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : if cluster_shape == [4,4,1] : continue multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) @@ -9327,8 +9326,9 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version): [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], ] + thor_sm = 101 min_cc = 100 - max_cc = 101 + max_cc = thor_sm tile_schedulers = [ TileSchedulerType.Default, ] @@ -9389,7 +9389,7 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version): for math_inst in math_instructions_1sm: tile_descriptions = [] for cluster_shape in sm100_cluster_shape_1sm: - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : if cluster_shape == [4,4,1] : continue multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape @@ -9424,7 +9424,7 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version): for math_inst in math_instructions_2sm: tile_descriptions = [] for cluster_shape in sm100_cluster_shape_2sm: - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : if cluster_shape == [4,4,1] : continue multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) @@ -9465,8 +9465,9 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], ] + thor_sm = 101 min_cc = 100 - max_cc = 101 + max_cc = thor_sm tile_schedulers = [ TileSchedulerType.Default, ] @@ -9525,7 +9526,7 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): for math_inst in math_instructions_1sm: tile_descriptions = [] for cluster_shape in sm100_cluster_shape_1sm: - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : if cluster_shape == [4,4,1] : continue multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape @@ -9587,7 +9588,7 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): for math_inst in math_instructions_2sm: tile_descriptions = [] for cluster_shape in sm100_cluster_shape_2sm: - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : if cluster_shape == [4,4,1] : continue multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) @@ -9677,8 +9678,9 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version): } ] + thor_sm = 101 min_cc = 100 - max_cc = 101 + max_cc = thor_sm math_instructions_1sm = [ MathInstruction( [128, 256, 8], @@ -9692,7 +9694,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version): , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_1sm = [ [1,2,1], [1,1,1], [1,4,1] , DynamicClusterShape @@ -9732,7 +9734,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version): , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_2sm = [ [2,1,1], [2,2,1], [2,4,1], [4,1,1] , DynamicClusterShape @@ -9770,8 +9772,9 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version): [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]], ] + thor_sm = 101 min_cc = 100 - max_cc = 101 + max_cc = thor_sm math_instructions_1sm = [ MathInstruction( [128, 256, 16], @@ -9784,7 +9787,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version): , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_1sm = [ [1,2,1], [1,1,1] , DynamicClusterShape @@ -9858,7 +9861,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version): , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_2sm = [ [2,1,1], [2,2,1], [2,4,1], [4,1,1] , DynamicClusterShape @@ -9931,8 +9934,9 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version): [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], ] + thor_sm = 101 min_cc = 100 - max_cc = 101 + max_cc = thor_sm epi_type = DataType.f32 math_instructions_1sm = [ @@ -9947,7 +9951,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version): , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_1sm = [ [1,2,1], [2,1,1], [1,1,1] , DynamicClusterShape @@ -10005,7 +10009,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version): , DynamicClusterShape ] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_2sm = [ [2,1,1], [2,2,1], [2,4,1], [4,1,1] , DynamicClusterShape @@ -10080,8 +10084,9 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version, if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return + thor_sm = 101 minimum_compute_capability = 100 - maximum_compute_capability = 101 + maximum_compute_capability = thor_sm spatial_dims = [2, 3] @@ -10110,7 +10115,7 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version, cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]] # tile_descriptions is a 2-level list. @@ -10176,7 +10181,7 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version, data_types_and_instruction_shapes_2sm) cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]] for math_inst, output_type in math_instructions_w_output_2sm: @@ -10233,8 +10238,9 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version, if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return + thor_sm = 101 minimum_compute_capability = 100 - maximum_compute_capability = 101 + maximum_compute_capability = thor_sm spatial_dims = [2, 3] stages = 0 # zero means "deduce the number of stages automatically" @@ -10258,7 +10264,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version, data_types_and_instruction_shapes_1sm) cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]] for math_inst, output_type in math_instructions_w_output_1sm: @@ -10323,7 +10329,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version, data_types_and_instruction_shapes_2sm) cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]] - if 101 in manifest.compute_capabilities : + if thor_sm in manifest.compute_capabilities_baseline : cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]] for math_inst, output_type in math_instructions_w_output_2sm: @@ -10704,9 +10710,9 @@ def GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version): ab_types_mxf8f6f4 = [ DataType.e2m1, - DataType.e2m3, + #DataType.e2m3, DataType.e3m2, - DataType.e5m2, + #DataType.e5m2, DataType.e4m3, ] @@ -10783,13 +10789,145 @@ def GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version): tile_schedulers = tile_schedulers(kernel_schedule), gemm_kind = GemmKind.SparseUniversal3x) +def GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + layouts = [ + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 16]], + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.ColumnMajor, 16]] + ] + + cooperative_tile_sizes = [ + [128, 128, 128] + ] + pingpong_tile_sizes = [ + [64, 128, 128] + ] + + def get_tile_sizes(kernel_scheduler): + if kernel_scheduler == KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: + return pingpong_tile_sizes + return cooperative_tile_sizes + + def get_warp_count(kernel_scheduler): + if kernel_scheduler == KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: + return [2, 2, 1] + return [4, 2, 1] + + def get_sf_sizes(tile_size): + sf_sizes = [] + for vec_m in [1, 128]: + if tile_size[0] % vec_m > 0: + continue + for vec_n in [1, 128]: + if tile_size[1] % vec_m > 0: + continue + sf_sizes.append( + [vec_m, vec_n, 128] + ) + return sf_sizes + + cluster_shape = [1,1,1] + + acc_types = [ DataType.f32 ] + + instruction_sizes = [ + [16, 8, 32] + ] + + def tile_schedulers(kernel_schedule): + return [TileSchedulerType.Default] + + min_cc = 120 + max_cc = 120 + + kernel_schedulers = [ + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120, + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120 + ] + + ab_types = [ + [DataType.e4m3, DataType.e4m3], + [DataType.e4m3, DataType.e5m2] + ] + + math_instructions = [] + + for instr_size, ab_type, acc_type in product(instruction_sizes, ab_types, acc_types): + a_type, b_type = ab_type + math_instructions.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + # Create gemm operator for mxf8f6f4 + for kernel_schedule in kernel_schedulers: + tile_sizes = get_tile_sizes(kernel_schedule) + warp_count = get_warp_count(kernel_schedule) + for math_inst in math_instructions: + tile_descriptions = [] + for tile_size in tile_sizes: + sf_sizes = get_sf_sizes(tile_size) + for sf_size in sf_sizes: + tile_descriptions.append( + TileDescription(tile_size, 0, warp_count, math_inst, min_cc, max_cc, cluster_shape, + explicit_vector_sizes=sf_size) + ) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + } + ] + + for data_type in data_types: + # Set alignment d based on Destination format + for layout in layouts: + layout[2][1] = int(128 // DataTypeSize[data_type["d_type"]]) + # Create gemm operator + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, EpilogueScheduleType.ScheduleAuto]], + tile_schedulers = tile_schedulers(kernel_schedule), + gemm_kind = gemm_kind) def GenerateSM100(manifest, cuda_version): + arch_family_cc = ['100f', '101f'] # # Dense Gemm # - architectures = manifest.args.architectures.split(';') if len(args.architectures) else ['50',] - GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version) GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version) @@ -10797,7 +10935,7 @@ def GenerateSM100(manifest, cuda_version): GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version) - if '100f' not in architectures and '101f' not in architectures: + if not bool(set(manifest.compute_capabilities_feature_set).intersection(arch_family_cc)): GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version) GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version) @@ -10819,7 +10957,7 @@ def GenerateSM100(manifest, cuda_version): # GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version) GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version) - if '100f' not in architectures and '101f' not in architectures: + if not bool(set(manifest.compute_capabilities_feature_set).intersection(arch_family_cc)): GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version) GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version) GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version) @@ -10849,6 +10987,8 @@ def GenerateSM120(manifest, cuda_version): # Sparse Gemm # GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version) + GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version) + GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockwiseUniversal3x) ################################################################################################### @@ -11328,13 +11468,17 @@ if __name__ == "__main__": GenerateSM80(manifest, args.cuda_version) GenerateSM89(manifest, args.cuda_version) GenerateSM90(manifest, args.cuda_version) - - blackwell_enabled_arch = any(arch in ["100a", "100f", "101a", "101f", "120a", "120f"] for arch in archs) + + blackwell_arch_list = [ + "100a", "100f", + "101a", "101f", + "120a", "120f" + ] + blackwell_enabled_arch = any(arch in blackwell_arch_list for arch in archs) if blackwell_enabled_arch: GenerateSM100(manifest, args.cuda_version) GenerateSM120(manifest, args.cuda_version) - if 'library' in args.generator_target.split(','): manifest.emit(GeneratorTarget.Library) diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index 8e932cb3..17e6c5ce 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -345,6 +345,15 @@ def get_real_from_complex(complex_type): return r return DataType.invalid +# TMA requires an alignment of 128 bits for all data types +def get_tma_alignment(data_type): + if data_type == DataType.void: + return 0 + elif DataTypeSize[data_type] == 6: + return 128 # 96B alignment for 16U6 format + else: + return 128 // DataTypeSize[data_type] + # class ComplexMultiplyOp(enum.Enum): multiply_add = enum_auto() @@ -546,6 +555,9 @@ class KernelScheduleType(enum.Enum): F8f6f4SparseTmaWarpSpecializedCooperativeSm120 = enum_auto() + BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto() + BlockwiseTmaWarpSpecializedPingpongSm120 = enum_auto() + KernelScheduleTag = { KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto', KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage', @@ -614,7 +626,10 @@ KernelScheduleTag = { KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120', KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf4Sm120', - KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelScheduleSparseF8f6f4Sm120' + KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelScheduleSparseF8f6f4Sm120', + + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120', + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120', } # @@ -685,7 +700,10 @@ KernelScheduleSuffixes = { KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs32', KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs32', - KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: '_q' + KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: '_q', + + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: '_cooperative_q', + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: '_pingpong_q' } class EpilogueScheduleType(enum.Enum): @@ -756,6 +774,20 @@ EpilogueFunctor3xTag = { EpilogueFunctor3x.LinearCombinationBlockScaleFactor: 'cutlass::epilogue::fusion::LinCombBlockScaleFactor', } +# TMA epilogues have certain alignment requirements as calculated in get_tma_alignment(data_type) +def is_tma_epilogue(epilogue_schedule_type): + return epilogue_schedule_type in [ + EpilogueScheduleType.ScheduleAuto, + EpilogueScheduleType.TmaWarpSpecialized, + EpilogueScheduleType.TmaWarpSpecializedCooperative, + EpilogueScheduleType.TmaWarpSpecialized1Sm, + EpilogueScheduleType.TmaWarpSpecialized2Sm, + EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm, + EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm, + EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative, + EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong, + ] + def to_grouped_schedule(schedule, grouped): if not grouped: return schedule @@ -771,17 +803,18 @@ def to_grouped_schedule(schedule, grouped): EpilogueScheduleType.TmaWarpSpecializedCooperative : EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized : EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized, # SM100 + KernelScheduleType.TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100, + KernelScheduleType.TmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100, KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100, KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100, KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100, KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100, KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100, KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100, - EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm, - EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm, KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100, KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100, - + EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm, + EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm, } return group_schedule_map[schedule] diff --git a/python/cutlass_library/manifest.py b/python/cutlass_library/manifest.py index d10ec125..74eea09a 100644 --- a/python/cutlass_library/manifest.py +++ b/python/cutlass_library/manifest.py @@ -507,7 +507,8 @@ class Manifest: self.selected_kernels = [] self.ignore_kernel_names = [] self.exclude_kernel_names = [] - self.compute_capabilities = [50,] + self.compute_capabilities_baseline = [50,] + self.compute_capabilities_feature_set = ['50',] self.curr_build_dir = '.' self.filter_by_cc = True @@ -518,21 +519,9 @@ class Manifest: # A common user error is to use commas instead of semicolons. if ',' in args.architectures: raise RuntimeError("The list of architectures (CMake option CUTLASS_NVCC_ARCHS) must be semicolon-delimited.\nDon't use commas to separate the architectures; use semicolons.\nYou specified the list as: " + args.architectures) - architectures = args.architectures.split(';') if len(args.architectures) else ['50',] - - arch_conditional_cc = [ - '90a', - '100a', - '100f', - '101a', - '101f', - '120a', - '120f' - ] - architectures = [x if x not in arch_conditional_cc else x.split('a')[0] for x in architectures] - architectures = [x if x not in arch_conditional_cc else x.split('f')[0] for x in architectures] - - self.compute_capabilities = [int(x) for x in architectures] + + self.compute_capabilities_feature_set = args.architectures.split(';') if len(args.architectures) else ['50',] + self.compute_capabilities_baseline = sorted(set(int(arch.split('a')[0].split('f')[0]) for arch in self.compute_capabilities_feature_set)) if args.filter_by_cc in ['false', 'False', '0']: self.filter_by_cc = False @@ -597,7 +586,7 @@ class Manifest: return default_level - def get_kernel_filters (self, kernelListFile): + def get_kernel_filters(self, kernelListFile): if os.path.isfile(kernelListFile): with open(kernelListFile, 'r') as fileReader: lines = [line.rstrip() for line in fileReader if not line.startswith("#")] @@ -635,7 +624,7 @@ class Manifest: # filter based on compute capability enabled = not (self.filter_by_cc) - for cc in self.compute_capabilities: + for cc in self.compute_capabilities_baseline: if cc >= operation.tile_description.minimum_compute_capability and \ cc <= operation.tile_description.maximum_compute_capability and \ @@ -789,14 +778,14 @@ class Manifest: return name.endswith(".cpp") def get_src_archs_str_given_requested_cuda_archs(archs, source_file): - intersected_archs = archs & set(self.compute_capabilities) + intersected_archs = archs & set(self.compute_capabilities_baseline) if intersected_archs == set(): raise RuntimeError( """ Empty archs set for file {} after taking the intersection of {} (global requested archs) and {} (per file requested archs) - """.format(source_file, set(self.compute_capabilities), archs)) + """.format(source_file, set(self.compute_capabilities_baseline), archs)) else: return " ".join(map(str, intersected_archs)) diff --git a/test/self_contained_includes/CMakeLists.txt b/test/self_contained_includes/CMakeLists.txt index cc151e1a..997ded28 100644 --- a/test/self_contained_includes/CMakeLists.txt +++ b/test/self_contained_includes/CMakeLists.txt @@ -52,7 +52,6 @@ set(header_files_to_check cute/swizzle_layout.hpp cute/tensor.hpp cute/tensor_impl.hpp - cute/tensor_predicate.hpp cute/underscore.hpp # cute/algorithm cute/algorithm/axpby.hpp diff --git a/test/unit/conv/device_3x/dgrad/CMakeLists.txt b/test/unit/conv/device_3x/dgrad/CMakeLists.txt index ede0ad9a..cb7abd30 100644 --- a/test/unit/conv/device_3x/dgrad/CMakeLists.txt +++ b/test/unit/conv/device_3x/dgrad/CMakeLists.txt @@ -30,6 +30,8 @@ add_custom_target( cutlass_test_unit_conv_dgrad_device DEPENDS cutlass_test_unit_conv_dgrad_device_tensorop_sm90 + cutlass_test_unit_conv_dgrad_device_tensorop_sm100 + cutlass_test_unit_conv_dgrad_device_tensorop_sm100_fusion ) cutlass_test_unit_add_executable( @@ -47,3 +49,43 @@ cutlass_test_unit_add_executable( sm90_conv3d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu ) +if (CUTLASS_NVCC_ARCHS MATCHES 100a) + +set(cutlass_test_unit_conv_dgrad_device_tensorop_sm100_kernels + sm100_conv2d_dgrad_implicit_gemm_f8_f8_f8_tensorop_f32.cu + sm100_conv2d_dgrad_implicit_gemm_f8_f8_bf16_tensorop_f32.cu + sm100_conv2d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32.cu + sm100_conv2d_dgrad_implicit_gemm_f8_f8_f32_tensorop_f32.cu + + sm100_conv3d_dgrad_implicit_gemm_f8_f8_f8_tensorop_f32.cu + sm100_conv3d_dgrad_implicit_gemm_f8_f8_bf16_tensorop_f32.cu + sm100_conv3d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32.cu + sm100_conv3d_dgrad_implicit_gemm_f8_f8_f32_tensorop_f32.cu + + sm100_conv1d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu + sm100_conv2d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu + sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu + + sm100_conv1d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu + sm100_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu + sm100_conv3d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu +) + +# Add the executable +cutlass_test_unit_add_executable( + cutlass_test_unit_conv_dgrad_device_tensorop_sm100 + ${cutlass_test_unit_conv_dgrad_device_tensorop_sm100_kernels} +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_conv_dgrad_device_tensorop_sm100_fusion + + sm100_conv2d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32_with_fusion.cu + sm100_conv3d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32_with_fusion.cu + + sm100_conv1d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu + sm100_conv2d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu + sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu +) + +endif() diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv1d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu b/test/unit/conv/device_3x/dgrad/sm100_conv1d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu new file mode 100644 index 00000000..e9286314 --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv1d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / sizeof_bits_v, + ElementOut, cutlass::layout::TensorNWC, 128 / sizeof_bits_v, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / sizeof_bits_v, + ElementOut, cutlass::layout::TensorNWC, 128 / sizeof_bits_v, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / sizeof_bits_v, + ElementOut, cutlass::layout::TensorNWC, 128 / sizeof_bits_v, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / sizeof_bits_v, + ElementOut, cutlass::layout::TensorNWC, 128 / sizeof_bits_v, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / sizeof_bits_v, + ElementOut, cutlass::layout::TensorNWC, 128 / sizeof_bits_v, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv1d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu b/test/unit/conv/device_3x/dgrad/sm100_conv1d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu new file mode 100644 index 00000000..61056ba7 --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv1d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu @@ -0,0 +1,190 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// alpha != 1 && beta != 0 +TEST(SM100_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / sizeof_bits_v, + ElementOut, cutlass::layout::TensorNWC, 128 / sizeof_bits_v, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias +TEST(SM100_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta_bias) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBias< + ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / sizeof_bits_v, + ElementOut, cutlass::layout::TensorNWC, 128 / sizeof_bits_v, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias && relu +TEST(SM100_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta_bias_relu) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / sizeof_bits_v, + ElementOut, cutlass::layout::TensorNWC, 128 / sizeof_bits_v, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv1d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu b/test/unit/conv/device_3x/dgrad/sm100_conv1d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu new file mode 100644 index 00000000..94abb4aa --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv1d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f32, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f32, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f32, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f32, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f32, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_dgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu b/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu new file mode 100644 index 00000000..302e9ea7 --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu b/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu new file mode 100644 index 00000000..48a68f1f --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu @@ -0,0 +1,190 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// alpha != 1 && beta != 0 +TEST(SM100_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias +TEST(SM100_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta_bias) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBias< + ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias && relu +TEST(SM100_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta_bias_relu) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu b/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu new file mode 100644 index 00000000..1ea6c9be --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_bf16_tensorop_f32.cu b/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_bf16_tensorop_f32.cu new file mode 100644 index 00000000..9f36c034 --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_bf16_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_bf16nhwc_tensor_op_f32, 64x64x64_1x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::bfloat16_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_bf16nhwc_tensor_op_f32, 128x64x64_1x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::bfloat16_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_bf16nhwc_tensor_op_f32, 128x128x64_1x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::bfloat16_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_bf16nhwc_tensor_op_f32, 256x64x64_2x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::bfloat16_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_bf16nhwc_tensor_op_f32, 256x128x64_2x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::bfloat16_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_bf16nhwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::bfloat16_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32.cu b/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32.cu new file mode 100644 index 00000000..6f4598c1 --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f16nhwc_tensor_op_f32, 64x64x64_1x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f16nhwc_tensor_op_f32, 128x64x64_1x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f16nhwc_tensor_op_f32, 128x128x64_1x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f16nhwc_tensor_op_f32, 256x64x64_2x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f16nhwc_tensor_op_f32, 256x128x64_2x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f16nhwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32_with_fusion.cu b/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32_with_fusion.cu new file mode 100644 index 00000000..967c61d1 --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32_with_fusion.cu @@ -0,0 +1,237 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// alpha != 1 && beta != 0 +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f16nhwc_tensor_op_f32, 64x64x64_1x1x1_alpha_beta) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f16nhwc_tensor_op_f32, 64x64x64_1x1x1_alpha_beta_bias) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBias< + ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias && relu +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f16nhwc_tensor_op_f32, 64x64x64_1x1x1_alpha_beta_bias_relu) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// per-channel alpha/beta scaling && bias && relu +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f16nhwc_tensor_op_f32, 64x64x64_1x1x1_alpha_beta_scaled_bias_relu) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::PerColLinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_f32_tensorop_f32.cu b/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_f32_tensorop_f32.cu new file mode 100644 index 00000000..8ea8ba40 --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_f32_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f32nhwc_tensor_op_f32, 64x64x64_1x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f32nhwc_tensor_op_f32, 128x64x64_1x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f32nhwc_tensor_op_f32, 128x128x64_1x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f32nhwc_tensor_op_f32, 256x64x64_2x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f32nhwc_tensor_op_f32, 256x128x64_2x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f32nhwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_f8_tensorop_f32.cu b/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_f8_tensorop_f32.cu new file mode 100644 index 00000000..d613c4da --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv2d_dgrad_implicit_gemm_f8_f8_f8_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f32, 64x64x64_1x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::float_e4m3_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f32, 128x64x64_1x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::float_e4m3_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f32, 128x128x64_1x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::float_e4m3_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f32, 256x64x64_2x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::float_e4m3_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f32, 256x128x64_2x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::float_e4m3_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_dgrad_implicitgemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::float_e4m3_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu new file mode 100644 index 00000000..de3a34b1 --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu new file mode 100644 index 00000000..933b8e22 --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu @@ -0,0 +1,143 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// alpha != 1 && beta != 0 +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta_bias) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBias< + ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu new file mode 100644 index 00000000..cd6ebff5 --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_bf16_tensorop_f32.cu b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_bf16_tensorop_f32.cu new file mode 100644 index 00000000..81fe86a9 --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_bf16_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_bf16ndhwc_tensor_op_f32, 64x64x64_1x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::bfloat16_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_bf16ndhwc_tensor_op_f32, 128x64x64_1x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::bfloat16_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_bf16ndhwc_tensor_op_f32, 128x128x64_1x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::bfloat16_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_bf16ndhwc_tensor_op_f32, 256x64x64_2x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::bfloat16_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_bf16ndhwc_tensor_op_f32, 256x128x64_2x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::bfloat16_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_bf16ndhwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::bfloat16_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32.cu b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32.cu new file mode 100644 index 00000000..d456dc90 --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f16ndhwc_tensor_op_f32, 64x64x64_1x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f16ndhwc_tensor_op_f32, 128x64x64_1x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f16ndhwc_tensor_op_f32, 128x128x64_1x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f16ndhwc_tensor_op_f32, 256x64x64_2x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f16ndhwc_tensor_op_f32, 256x128x64_2x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f16ndhwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32_with_fusion.cu b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32_with_fusion.cu new file mode 100644 index 00000000..690fc0fe --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_f16_tensorop_f32_with_fusion.cu @@ -0,0 +1,237 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// alpha != 1 && beta != 0 +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f16ndhwc_tensor_op_f32, 64x64x64_1x1x1_alpha_beta) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f16ndhwc_tensor_op_f32, 64x64x64_1x1x1_alpha_beta_bias) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBias< + ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias && relu +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f16ndhwc_tensor_op_f32, 64x64x64_1x1x1_alpha_beta_bias_relu) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// per-channel alpha/beta scaling && bias && relu +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f16ndhwc_tensor_op_f32, 64x64x64_1x1x1_alpha_beta_scaled_bias_relu) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::half_t; + using ElementAcc = float; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::PerColLinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_f32_tensorop_f32.cu b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_f32_tensorop_f32.cu new file mode 100644 index 00000000..dc633a62 --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_f32_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f32ndhwc_tensor_op_f32, 64x64x64_1x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f32ndhwc_tensor_op_f32, 128x64x64_1x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f32ndhwc_tensor_op_f32, 128x128x64_1x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f32ndhwc_tensor_op_f32, 256x64x64_2x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f32ndhwc_tensor_op_f32, 256x128x64_2x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f32ndhwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_f8_tensorop_f32.cu b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_f8_tensorop_f32.cu new file mode 100644 index 00000000..a90616b9 --- /dev/null +++ b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f8_f8_f8_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f8ndhwc_tensor_op_f32, 64x64x64_1x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::float_e4m3_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f8ndhwc_tensor_op_f32, 128x64x64_1x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::float_e4m3_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f8ndhwc_tensor_op_f32, 128x128x64_1x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::float_e4m3_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f8ndhwc_tensor_op_f32, 256x64x64_2x1x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::float_e4m3_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f8ndhwc_tensor_op_f32, 256x128x64_2x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::float_e4m3_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_dgrad_implicitgemm_f8ndhwc_f8ndhwc_f8ndhwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::float_e4m3_t; + using ElementFlt = cutlass::float_e4m3_t; + using ElementOut = cutlass::float_e4m3_t; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kDgrad, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/CMakeLists.txt b/test/unit/conv/device_3x/fprop/CMakeLists.txt index 90b2cef2..4091595c 100644 --- a/test/unit/conv/device_3x/fprop/CMakeLists.txt +++ b/test/unit/conv/device_3x/fprop/CMakeLists.txt @@ -32,6 +32,8 @@ add_custom_target( cutlass_test_unit_conv1d_fprop_device_tensorop_sm90 cutlass_test_unit_conv2d_fprop_device_tensorop_sm90 cutlass_test_unit_conv3d_fprop_device_tensorop_sm90 + cutlass_test_unit_conv_fprop_device_tensorop_sm100 + cutlass_test_unit_conv_fprop_device_tensorop_sm100_fusion ) cutlass_test_unit_add_executable( @@ -73,3 +75,50 @@ cutlass_test_unit_add_executable( sm90_conv3d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu ) +if (CUTLASS_NVCC_ARCHS MATCHES 100a) + +cutlass_test_unit_add_executable( + cutlass_test_unit_conv_fprop_device_tensorop_sm100 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm100_conv1d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu + sm100_conv2d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu + sm100_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu + + sm100_conv1d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16.cu + sm100_conv2d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16.cu + sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16.cu + + sm100_conv1d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu + sm100_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu + sm100_conv3d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu + + sm100_conv1d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu + sm100_conv2d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu + sm100_conv3d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_conv_fprop_device_tensorop_sm100_fusion + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm100_conv1d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu + sm100_conv2d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu + sm100_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu + + sm100_conv1d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu + sm100_conv2d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu + sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu + + sm100_conv1d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32_with_fusion.cu + sm100_conv2d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32_with_fusion.cu + sm100_conv3d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32_with_fusion.cu +) + +endif() diff --git a/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16.cu b/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16.cu new file mode 100644 index 00000000..3420d279 --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16.cu @@ -0,0 +1,246 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu b/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu new file mode 100644 index 00000000..a39cca01 --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu @@ -0,0 +1,236 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// alpha != 1 && beta != 0 +TEST(SM100_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias +TEST(SM100_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta_bias) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBias< + ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias && relu +TEST(SM100_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta_bias_relu) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias && gelu +TEST(SM100_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta_bias_gelu) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::GELU_taylor, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0f, 1.0f, 0.005f)); +} +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu b/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu new file mode 100644 index 00000000..b344c677 --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu @@ -0,0 +1,292 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f32, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f32, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f32, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f32, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu b/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu new file mode 100644 index 00000000..9cc828dc --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu @@ -0,0 +1,339 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && !defined(CUTLASS_SM100_FAMILY_ARCHS_ENABLED)) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_s8nwc_s8nwc_s32nwc_tensor_op_s32, 64x64x64_1x1x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_s8nwc_s8nwc_s32nwc_tensor_op_s32, 128x64x64_1x1x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_s8nwc_s8nwc_s32nwc_tensor_op_s32, 128x128x64_1x2x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_s8nwc_s8nwc_s32nwc_tensor_op_s32, 256x64x64_2x1x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_s8nwc_s8nwc_s32nwc_tensor_op_s32, 256x128x64_2x2x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_s8nwc_s8nwc_s32nwc_tensor_op_s32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && !defined(CUTLASS_SM100_FAMILY_ARCHS_ENABLED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu b/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu new file mode 100644 index 00000000..1c5a6493 --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu @@ -0,0 +1,378 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && !defined(CUTLASS_SM100_FAMILY_ARCHS_ENABLED)) + +// alpha != 1 && beta != 0 +TEST(SM100_device_conv1d_fprop_implicitgemm_s8nwc_s8nwc_s32nwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias +TEST(SM100_device_conv1d_fprop_implicitgemm_s8nwc_s8nwc_s32nwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBias< + ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias && relu +TEST(SM100_device_conv1d_fprop_implicitgemm_s8nwc_s8nwc_s32nwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias_relu) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// per-channel alpha/beta scaling && bias && relu +TEST(SM100_device_conv1d_fprop_implicitgemm_s8nwc_s8nwc_s32nwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_scaled_bias_relu) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::PerColLinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// alpha != 1 && beta != 0 && bias && gelu +TEST(SM100_device_conv1d_fprop_implicitgemm_s8nwc_s8nwc_s32nwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias_gelu) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::GELU_taylor, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0f, 1.0f, 0.05f)); +} + +// alpha != 1 && beta != 0 && bias && gelu_erf +TEST(SM100_device_conv1d_fprop_implicitgemm_s8nwc_s8nwc_s32nwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias_gelu_erf) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::GELU, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0f, 1.0f, 0.005f)); +} + +// alpha != 1 && beta != 0 && bias && swish +TEST(SM100_device_conv1d_fprop_implicitgemm_s8nwc_s8nwc_s32nwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias_swish) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::SiLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0f, 1.0f, 0.005f)); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && !defined(CUTLASS_SM100_FAMILY_ARCHS_ENABLED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu b/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu new file mode 100644 index 00000000..90242d7e --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x32 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_tf32nwc_tf32nwc_f32nwc_tensor_op_f32, 64x64x32_1x1x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNWC, 4, + float, cutlass::layout::TensorNWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 4, + ElementFlt, cutlass::layout::TensorNWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x32 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_tf32nwc_tf32nwc_f32nwc_tensor_op_f32, 128x64x32_1x1x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNWC, 4, + float, cutlass::layout::TensorNWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 4, + ElementFlt, cutlass::layout::TensorNWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x32 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_tf32nwc_tf32nwc_f32nwc_tensor_op_f32, 128x128x32_1x2x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNWC, 4, + float, cutlass::layout::TensorNWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 4, + ElementFlt, cutlass::layout::TensorNWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x32 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_tf32nwc_tf32nwc_f32nwc_tensor_op_f32, 256x64x32_2x1x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_32>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNWC, 4, + float, cutlass::layout::TensorNWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 4, + ElementFlt, cutlass::layout::TensorNWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x32 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_tf32nwc_tf32nwc_f32nwc_tensor_op_f32, 256x128x32_2x2x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_32>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNWC, 4, + float, cutlass::layout::TensorNWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 4, + ElementFlt, cutlass::layout::TensorNWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x32 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_fprop_implicitgemm_tf32nwc_tf32nwc_f32nwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_32>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32_with_fusion.cu b/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32_with_fusion.cu new file mode 100644 index 00000000..e8ccb681 --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv1d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32_with_fusion.cu @@ -0,0 +1,190 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// alpha != 1 && beta != 0 +TEST(SM100_device_conv1d_fprop_implicitgemm_tf32nwc_tf32nwc_f32nwc_tensor_op_f32, 64x64x32_1x1x1_alpha_beta) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNWC, 4, + float, cutlass::layout::TensorNWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 4, + ElementFlt, cutlass::layout::TensorNWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias +TEST(SM100_device_conv1d_fprop_implicitgemm_tf32nwc_tf32nwc_f32nwc_tensor_op_f32, 64x64x32_1x1x1_alpha_beta_bias) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBias< + ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNWC, 4, + float, cutlass::layout::TensorNWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 4, + ElementFlt, cutlass::layout::TensorNWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias && relu +TEST(SM100_device_conv1d_fprop_implicitgemm_tf32nwc_tf32nwc_f32nwc_tensor_op_f32, 64x64x32_1x1x1_alpha_beta_bias_relu) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNWC, 4, + float, cutlass::layout::TensorNWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNWC, 4, + ElementFlt, cutlass::layout::TensorNWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16.cu b/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16.cu new file mode 100644 index 00000000..6f8a040b --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu b/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu new file mode 100644 index 00000000..79b9d8b6 --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu @@ -0,0 +1,237 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// alpha != 1 && beta != 0 +TEST(SM100_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias +TEST(SM100_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta_bias) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBias< + ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias && relu +TEST(SM100_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta_bias_relu) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias && gelu +TEST(SM100_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta_bias_gelu) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::GELU_taylor, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0f, 1.0f, 0.005f)); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu b/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu new file mode 100644 index 00000000..83ffd92a --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu b/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu new file mode 100644 index 00000000..59980588 --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu @@ -0,0 +1,339 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && !defined(CUTLASS_SM100_FAMILY_ARCHS_ENABLED)) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, 64x64x64_1x1x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, 128x64x64_1x1x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, 128x128x64_1x2x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, 256x64x64_2x1x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, 256x128x64_2x2x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && !defined(CUTLASS_SM100_FAMILY_ARCHS_ENABLED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu b/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu new file mode 100644 index 00000000..ba2dacd0 --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu @@ -0,0 +1,378 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && !defined(CUTLASS_SM100_FAMILY_ARCHS_ENABLED)) + +// alpha != 1 && beta != 0 +TEST(SM100_device_conv2d_fprop_implicitgemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias +TEST(SM100_device_conv2d_fprop_implicitgemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBias< + ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias && relu +TEST(SM100_device_conv2d_fprop_implicitgemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias_relu) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// per-channel alpha/beta scaling && bias && relu +TEST(SM100_device_conv2d_fprop_implicitgemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_scaled_bias_relu) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::PerColLinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// alpha != 1 && beta != 0 && bias && gelu +TEST(SM100_device_conv2d_fprop_implicitgemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias_gelu) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::GELU_taylor, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0f, 1.0f, 0.005f)); +} + +// alpha != 1 && beta != 0 && bias && gelu_erf +TEST(SM100_device_conv2d_fprop_implicitgemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias_gelu_erf) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::GELU, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0f, 1.0f, 0.005f)); +} + +// alpha != 1 && beta != 0 && bias && swish +TEST(SM100_device_conv2d_fprop_implicitgemm_s8nhwc_s8nhwc_s32nhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias_swish) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::SiLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0f, 1.0f, 0.005f)); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && !defined(CUTLASS_SM100_FAMILY_ARCHS_ENABLED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu b/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu new file mode 100644 index 00000000..f496d98b --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x32 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32, 64x64x32_1x1x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNHWC, 4, + float, cutlass::layout::TensorNHWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 4, + ElementFlt, cutlass::layout::TensorNHWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x32 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32, 128x64x32_1x1x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNHWC, 4, + float, cutlass::layout::TensorNHWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 4, + ElementFlt, cutlass::layout::TensorNHWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x32 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32, 128x128x32_1x2x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNHWC, 4, + float, cutlass::layout::TensorNHWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 4, + ElementFlt, cutlass::layout::TensorNHWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x32 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32, 256x64x32_2x1x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_32>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNHWC, 4, + float, cutlass::layout::TensorNHWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 4, + ElementFlt, cutlass::layout::TensorNHWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x32 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32, 256x128x32_2x2x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_32>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNHWC, 4, + float, cutlass::layout::TensorNHWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 4, + ElementFlt, cutlass::layout::TensorNHWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x32 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_fprop_implicitgemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_32>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32_with_fusion.cu b/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32_with_fusion.cu new file mode 100644 index 00000000..49e13acf --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv2d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32_with_fusion.cu @@ -0,0 +1,190 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// alpha != 1 && beta != 0 +TEST(SM100_device_conv2d_fprop_implicitgemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32, 64x64x32_1x1x1_alpha_beta) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNHWC, 4, + float, cutlass::layout::TensorNHWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 4, + ElementFlt, cutlass::layout::TensorNHWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias +TEST(SM100_device_conv2d_fprop_implicitgemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32, 64x64x32_1x1x1_alpha_beta_bias) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBias< + ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNHWC, 4, + float, cutlass::layout::TensorNHWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 4, + ElementFlt, cutlass::layout::TensorNHWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias && relu +TEST(SM100_device_conv2d_fprop_implicitgemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32, 64x64x32_1x1x1_alpha_beta_bias_relu) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNHWC, 4, + float, cutlass::layout::TensorNHWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNHWC, 4, + ElementFlt, cutlass::layout::TensorNHWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16.cu b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16.cu new file mode 100644 index 00000000..6b32e145 --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu new file mode 100644 index 00000000..54570dd1 --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu @@ -0,0 +1,331 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// alpha != 1 && beta != 0 +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta_bias) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBias< + ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias && relu +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta_bias_relu) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias && gelu +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta_bias_gelu) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::GELU_taylor, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0f, 1.0f, 0.005f)); +} + +// alpha != 1 && beta != 0 && bias && HardSwish +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta_bias_hardswish) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::ScaledHardSwish, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0f, 1.0f, 0.005f)); +} + +// alpha != 1 && beta != 0 && bias && leakyrelu +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta_bias_leakyrelu) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using ElementBias = cutlass::half_t; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::LeakyReLU, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0f, 1.0f, 0.005f)); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu new file mode 100644 index 00000000..b98ed8eb --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f32_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu new file mode 100644 index 00000000..3042989e --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s32, 64x64x64_1x1x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s32, 128x64x64_1x1x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s32, 128x128x64_1x2x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s32, 256x64x64_2x1x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s32, 256x128x64_2x2x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && !defined(CUTLASS_SM100_FAMILY_ARCHS_ENABLED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu new file mode 100644 index 00000000..42c789dc --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32_with_fusion.cu @@ -0,0 +1,473 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && !defined(CUTLASS_SM100_FAMILY_ARCHS_ENABLED)) + +// alpha != 1 && beta != 0 +TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias +TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBias< + ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias && relu +TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias_relu) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// per-channel alpha/beta scaling && bias && relu +TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_scaled_bias_relu) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::PerColLinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// alpha != 1 && beta != 0 && bias && gelu +TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias_gelu) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::GELU_taylor, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0f, 1.0f, 0.005f)); +} + +// alpha != 1 && beta != 0 && bias && gelu_erf +TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias_gelu_erf) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::GELU, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0f, 1.0f, 0.005f)); +} + +// alpha != 1 && beta != 0 && bias && swish +TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias_swish) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::SiLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0f, 1.0f, 0.005f)); +} + +// alpha != 1 && beta != 0 && bias && leakyrelu +TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias_leakyrelu) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::LeakyReLU, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0f, 1.0f, 0.005f)); +} + + +// alpha != 1 && beta != 0 && bias && hardswish +TEST(SM100_device_conv3d_fprop_implicitgemm_s8ndhwc_s8ndhwc_s32ndhwc_tensor_op_s32, 64x64x64_1x1x1_alpha_beta_bias_hardswish) { + using ElementAct = int8_t; + using ElementFlt = int8_t; + using ElementOut = int32_t; + using ElementAcc = int32_t; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::ScaledHardSwish, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + int8_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + int32_t, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0f, 1.0f, 0.005f)); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && !defined(CUTLASS_SM100_FAMILY_ARCHS_ENABLED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu new file mode 100644 index 00000000..5145019f --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x32 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, 64x64x32_1x1x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNDHWC, 4, + float, cutlass::layout::TensorNDHWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 4, + ElementFlt, cutlass::layout::TensorNDHWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x32 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, 128x64x32_1x1x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNDHWC, 4, + float, cutlass::layout::TensorNDHWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 4, + ElementFlt, cutlass::layout::TensorNDHWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x32 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, 128x128x32_1x2x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNDHWC, 4, + float, cutlass::layout::TensorNDHWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 4, + ElementFlt, cutlass::layout::TensorNDHWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x32 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, 256x64x32_2x1x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_32>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNDHWC, 4, + float, cutlass::layout::TensorNDHWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 4, + ElementFlt, cutlass::layout::TensorNDHWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x32 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, 256x128x32_2x2x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, _64, Shape<_32>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNDHWC, 4, + float, cutlass::layout::TensorNDHWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 4, + ElementFlt, cutlass::layout::TensorNDHWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x32 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_fprop_implicitgemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_32>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorNDHWC, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementAct), + ElementFlt, cutlass::layout::TensorNDHWC, 16 / sizeof(ElementFlt), + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32_with_fusion.cu b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32_with_fusion.cu new file mode 100644 index 00000000..1bb0e9d4 --- /dev/null +++ b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32_with_fusion.cu @@ -0,0 +1,190 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// alpha != 1 && beta != 0 +TEST(SM100_device_conv3d_fprop_implicitgemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, 64x64x32_1x1x1_alpha_beta) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNDHWC, 4, + float, cutlass::layout::TensorNDHWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 4, + ElementFlt, cutlass::layout::TensorNDHWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias +TEST(SM100_device_conv3d_fprop_implicitgemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, 64x64x32_1x1x1_alpha_beta_bias) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBias< + ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNDHWC, 4, + float, cutlass::layout::TensorNDHWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 4, + ElementFlt, cutlass::layout::TensorNDHWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +// alpha != 1 && beta != 0 && bias && relu +TEST(SM100_device_conv3d_fprop_implicitgemm_tf32ndhwc_tf32ndhwc_f32ndhwc_tensor_op_f32, 64x64x32_1x1x1_alpha_beta_bias_relu) { + using ElementAct = float; + using ElementFlt = float; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using ElementBias = float; + using MmaTileShape = Shape<_64, _64, Shape<_32>>; + using ClusterShape = Shape<_1,_1,_1>; + + using FusionOperation = cutlass::epilogue::fusion::LinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLu, ElementOut, ElementCompute, ElementBias>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + float, cutlass::layout::TensorNDHWC, 4, + float, cutlass::layout::TensorNDHWC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kFprop, + ElementAct, cutlass::layout::TensorNDHWC, 4, + ElementFlt, cutlass::layout::TensorNDHWC, 4, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/testbed_conv.hpp b/test/unit/conv/device_3x/testbed_conv.hpp index ddf8ea40..031e9a91 100644 --- a/test/unit/conv/device_3x/testbed_conv.hpp +++ b/test/unit/conv/device_3x/testbed_conv.hpp @@ -172,6 +172,8 @@ struct ConvTestbed { static constexpr bool IsBiasEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithBias::value && !cute::is_same_v; + static constexpr bool IsPerChannelScaleEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithPerChannelScaled::value; + static constexpr bool DisableSource = cute::is_void_v; using StrideC = typename Conv::ConvKernel::StrideC; @@ -213,10 +215,24 @@ struct ConvTestbed { tensor_D_computed.resize(sizeof(ElementD) * problem_shape.size_C()); tensor_D_reference.resize(sizeof(ElementD) * problem_shape.size_C()); tensor_bias.resize(sizeof(ElementBias) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); + if constexpr (IsPerChannelScaleEnabled) { + tensor_alpha.resize(sizeof(ElementScalar) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); + tensor_beta.resize(sizeof(ElementScalar) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); + } initialize_values(tensor_A, init_A, seed); initialize_values(tensor_B, init_B, seed * 11); initialize_values(tensor_C, init_C, seed * 17); initialize_values(tensor_bias, init_bias, seed * 19); + if constexpr (IsPerChannelScaleEnabled) { + initialize_values(tensor_alpha, init_bias, seed * 23); + if constexpr (DisableSource) { + initialize_values(tensor_beta, init_disable, seed * 27); + } + else { + initialize_values(tensor_beta, init_bias, seed * 27); + } + } + bool flag = true; if constexpr (isSparseEnabled) { flag &= params.initialize(problem_shape, tensor_B, static_cast(seed + 2023)); @@ -314,8 +330,9 @@ struct ConvTestbed { bool run( ProblemShape const& problem_shape, ElementScalar alpha = ElementScalar(1), - ElementScalar beta = ElementScalar(0) - , + ElementScalar beta = ElementScalar(0), + dim3 cluster_shape = dim3(0, 0, 0), + dim3 cluster_shape_fallback = dim3(0, 0, 0), RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, MaxSwizzleSize max_swizzle = MaxSwizzleSize{}, Splits splits = Splits{}, @@ -341,6 +358,9 @@ struct ConvTestbed { cudaGetDevice(&hw_info.device_id); hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.cluster_shape = cluster_shape; + hw_info.cluster_shape_fallback = cluster_shape_fallback; + // configure the operator Conv conv_op; auto stride_C = StrideC{}; @@ -392,6 +412,11 @@ struct ConvTestbed { fusion_args.alpha = alpha; fusion_args.beta = beta; + if constexpr (IsPerChannelScaleEnabled) { + fusion_args.alpha_ptr = tensor_alpha.data().get(); + fusion_args.beta_ptr = tensor_beta.data().get(); + } + if constexpr (IsBiasEnabled) { fusion_args.bias_ptr = tensor_bias.data().get(); } @@ -478,6 +503,11 @@ struct ConvTestbed { epilogue_fusion_params.alpha = alpha; epilogue_fusion_params.beta = beta; + if constexpr (IsPerChannelScaleEnabled) { + epilogue_fusion_params.tensor_alpha = mAlpha; + epilogue_fusion_params.tensor_beta = mBeta; + } + if constexpr (IsBiasEnabled) { epilogue_fusion_params.tensor_bias = mBias; } @@ -638,6 +668,16 @@ struct ConvTestbed { for (size_t i = 0; i < size_t(size(B)); ++i) { printf("[%llu]: B = %f\n", static_cast(i), float(B(i))); } + if constexpr (IsPerChannelScaleEnabled) { + for (size_t i = 0; i < size_t(size(tensor_alpha)); ++i) { + printf("[%llu]: alpha = %f\n", static_cast(i), + float(tensor_alpha(i))); + } + for (size_t i = 0; i < size_t(size(tensor_beta)); ++i) { + printf("[%llu]: beta = %f\n", static_cast(i), + float(tensor_beta(i))); + } + } if constexpr (IsBiasEnabled) { for (size_t i = 0; i < size_t(size(tensor_bias)); ++i) { printf("[%llu]: bias = %f\n", static_cast(i), @@ -657,7 +697,9 @@ struct ConvTestbed { ///////////////////////////////////////////////////////////////////////////////////////////////// template -bool TestAllConv(double alpha = 1.0, double beta = 0.0, float epsilon = 0.0f +bool TestAllConv(double alpha = 1.0, double beta = 0.0, float epsilon = 0.0f, + dim3 cluster_shape = dim3(0, 0, 0), + dim3 cluster_shape_fallback = dim3(0, 0, 0) ) { using ElementScalar = typename Conv::EpilogueOutputOp::ElementScalar; @@ -697,8 +739,10 @@ bool TestAllConv(double alpha = 1.0, double beta = 0.0, float epsilon = 0.0f passed = testbed.run( conv_problem, cutlass::from_real(alpha), - cutlass::from_real(beta) - ,RasterOrderOptions::Heuristic, // raster_order + cutlass::from_real(beta), + cluster_shape, + cluster_shape_fallback, + RasterOrderOptions::Heuristic, // raster_order MaxSwizzleSize(1), splits, decomp_mode diff --git a/test/unit/conv/device_3x/wgrad/CMakeLists.txt b/test/unit/conv/device_3x/wgrad/CMakeLists.txt index 7d7d310c..f1c4041b 100644 --- a/test/unit/conv/device_3x/wgrad/CMakeLists.txt +++ b/test/unit/conv/device_3x/wgrad/CMakeLists.txt @@ -30,6 +30,8 @@ add_custom_target( cutlass_test_unit_conv_wgrad_device DEPENDS cutlass_test_unit_conv_wgrad_device_tensorop_sm90 + cutlass_test_unit_conv_wgrad_device_tensorop_sm100 + cutlass_test_unit_conv_wgrad_device_tensorop_sm100_fusion ) cutlass_test_unit_add_executable( @@ -44,3 +46,26 @@ cutlass_test_unit_add_executable( sm90_conv3d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu ) +if (CUTLASS_NVCC_ARCHS MATCHES 100a) + +cutlass_test_unit_add_executable( + cutlass_test_unit_conv_wgrad_device_tensorop_sm100 + + sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu + sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu + sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu + + sm100_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu + sm100_conv2d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu + #sm100_conv3d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu +) + +cutlass_test_unit_add_executable_split_file( + cutlass_test_unit_conv_wgrad_device_tensorop_sm100_fusion + + sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu + sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu + sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu +) + +endif() diff --git a/test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu b/test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu new file mode 100644 index 00000000..c88f407b --- /dev/null +++ b/test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::TmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu b/test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu new file mode 100644 index 00000000..793b1e12 --- /dev/null +++ b/test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu @@ -0,0 +1,96 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// alpha != 1 && beta != 0 +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu b/test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu new file mode 100644 index 00000000..098a7fb1 --- /dev/null +++ b/test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f32, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::TmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f32, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f32, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f32, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f32, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f32nwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCS, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNWC, 8, + ElementFlt, cutlass::layout::TensorNWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu b/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu new file mode 100644 index 00000000..d5dda5d1 --- /dev/null +++ b/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::TmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu b/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu new file mode 100644 index 00000000..5d17afb3 --- /dev/null +++ b/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu @@ -0,0 +1,96 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// alpha != 1 && beta != 0 +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu b/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu new file mode 100644 index 00000000..79875687 --- /dev/null +++ b/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 128x128x64_1x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::TmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 256x64x64_2x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_2,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSR, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNHWC, 8, + ElementFlt, cutlass::layout::TensorNHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu b/test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu new file mode 100644 index 00000000..042414db --- /dev/null +++ b/test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16.cu @@ -0,0 +1,338 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::TmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_128, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +//TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 128x128x64_1x2x1) { +// using ElementAct = cutlass::half_t; +// using ElementFlt = cutlass::half_t; +// using ElementOut = cutlass::half_t; +// using ElementAcc = cutlass::half_t; +// using ElementCompute = float; +// using MmaTileShape = Shape<_128, Shape<_64>, Shape<_64>>; +// using ClusterShape = Shape<_1,_2,_1>; +// +// using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< +// cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, +// MmaTileShape, ClusterShape, +// cutlass::epilogue::collective::EpilogueTileAuto, +// ElementAcc, ElementCompute, +// ElementAct, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, +// ElementOut, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, +// cutlass::epilogue::NoSmemWarpSpecialized1Sm +// >::CollectiveOp; +// +// using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< +// cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, +// cutlass::conv::Operator::kWgrad, +// ElementAct, cutlass::layout::TensorNDHWC, 8, +// ElementFlt, cutlass::layout::TensorNDHWC, 8, +// ElementAcc, +// MmaTileShape, ClusterShape, +// cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, +// cutlass::conv::collective::KernelScheduleAuto +// >::CollectiveOp; +// +// using ProblemShape=cutlass::conv::ConvProblemShape; +// using ConvKernel = cutlass::conv::kernel::ConvUniversal< ProblemShape, +// ProblemShape, +// CollectiveMainloop, +// CollectiveEpilogue +// >; +// +// using Conv = cutlass::conv::device::ConvUniversalAdapter; +// +// EXPECT_TRUE(test::conv::device::TestAllConv()); +//} +// +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +//TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 256x64x64_2x1x1) { +// using ElementAct = cutlass::half_t; +// using ElementFlt = cutlass::half_t; +// using ElementOut = cutlass::half_t; +// using ElementAcc = cutlass::half_t; +// using ElementCompute = float; +// using MmaTileShape = Shape<_256, Shape<_64>, Shape<_64>>; +// using ClusterShape = Shape<_2,_1,_1>; +// +// using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< +// cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, +// MmaTileShape, ClusterShape, +// cutlass::epilogue::collective::EpilogueTileAuto, +// ElementAcc, ElementCompute, +// ElementAct, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, +// ElementOut, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, +// cutlass::epilogue::NoSmemWarpSpecialized2Sm +// >::CollectiveOp; +// +// using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< +// cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, +// cutlass::conv::Operator::kWgrad, +// ElementAct, cutlass::layout::TensorNDHWC, 8, +// ElementFlt, cutlass::layout::TensorNDHWC, 8, +// ElementAcc, +// MmaTileShape, ClusterShape, +// cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, +// cutlass::conv::collective::KernelScheduleAuto +// >::CollectiveOp; +// +// using ProblemShape=cutlass::conv::ConvProblemShape; +// using ConvKernel = cutlass::conv::kernel::ConvUniversal< ProblemShape, +// ProblemShape, +// CollectiveMainloop, +// CollectiveEpilogue +// >; +// +// using Conv = cutlass::conv::device::ConvUniversalAdapter; +// +// EXPECT_TRUE(test::conv::device::TestAllConv()); +//} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_256, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu b/test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu new file mode 100644 index 00000000..2868f127 --- /dev/null +++ b/test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_with_fusion.cu @@ -0,0 +1,96 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// alpha != 1 && beta != 0 +TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_1x1x1_alpha_beta) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = cutlass::half_t; + using ElementAcc = cutlass::half_t; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ProblemShape=cutlass::conv::ConvProblemShape; + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(2.0, 1.0)); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu b/test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu new file mode 100644 index 00000000..41319425 --- /dev/null +++ b/test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f32.cu @@ -0,0 +1,326 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide CONV interface +*/ + +#include "cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/conv/device/conv_universal_adapter.hpp" +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/conv/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../testbed_conv.hpp" +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Static cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Cluster tile shape 64x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 64x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::TmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x64x64 +// Cluster shape 1x1x1 +// +TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 128x64x64_1x1x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_128, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_1,_1,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +// +// Cluster tile shape 128x128x64 +// Cluster shape 1x2x1 +// +//TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 128x128x64_1x2x1) { +// using ElementAct = cutlass::half_t; +// using ElementFlt = cutlass::half_t; +// using ElementOut = float; +// using ElementAcc = float; +// using ElementCompute = float; +// using MmaTileShape = Shape<_128, Shape<_64>, Shape<_64>>; +// using ClusterShape = Shape<_1,_2,_1>; +// +// using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< +// cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, +// MmaTileShape, ClusterShape, +// cutlass::epilogue::collective::EpilogueTileAuto, +// ElementAcc, ElementCompute, +// ElementAct, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, +// ElementOut, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, +// cutlass::epilogue::NoSmemWarpSpecialized1Sm +// >::CollectiveOp; +// +// using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< +// cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, +// cutlass::conv::Operator::kWgrad, +// ElementAct, cutlass::layout::TensorNDHWC, 8, +// ElementFlt, cutlass::layout::TensorNDHWC, 8, +// ElementAcc, +// MmaTileShape, ClusterShape, +// cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, +// cutlass::conv::collective::KernelScheduleAuto +// >::CollectiveOp; +// +// using ConvKernel = cutlass::conv::kernel::ConvUniversal< +// CollectiveMainloop, +// CollectiveEpilogue +// >; +// +// using Conv = cutlass::conv::device::ConvUniversalAdapter; +// +// EXPECT_TRUE(test::conv::device::TestAllConv()); +//} +// +// +// Cluster tile shape 256x64x64 +// Cluster shape 2x1x1 +// +//TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 256x64x64_2x1x1) { +// using ElementAct = cutlass::half_t; +// using ElementFlt = cutlass::half_t; +// using ElementOut = float; +// using ElementAcc = float; +// using ElementCompute = float; +// using MmaTileShape = Shape<_256, Shape<_64>, Shape<_64>>; +// using ClusterShape = Shape<_2,_1,_1>; +// +// using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< +// cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, +// MmaTileShape, ClusterShape, +// cutlass::epilogue::collective::EpilogueTileAuto, +// ElementAcc, ElementCompute, +// ElementAct, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, +// ElementOut, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, +// cutlass::epilogue::NoSmemWarpSpecialized2Sm +// >::CollectiveOp; +// +// using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< +// cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, +// cutlass::conv::Operator::kWgrad, +// ElementAct, cutlass::layout::TensorNDHWC, 8, +// ElementFlt, cutlass::layout::TensorNDHWC, 8, +// ElementAcc, +// MmaTileShape, ClusterShape, +// cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, +// cutlass::conv::collective::KernelScheduleAuto +// >::CollectiveOp; +// +// using ConvKernel = cutlass::conv::kernel::ConvUniversal< +// CollectiveMainloop, +// CollectiveEpilogue +// >; +// +// using Conv = cutlass::conv::device::ConvUniversalAdapter; +// +// EXPECT_TRUE(test::conv::device::TestAllConv()); +//} + +// +// Cluster tile shape 256x128x64 +// Cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 256x128x64_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_256, Shape<_64>, Shape<_64>>; + using ClusterShape = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv()); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic cluster +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CTA tile shape 64x64x64 +// preferred cluster shape 2x4x1 +// fallback cluster shape 2x2x1 +// +TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32, 64x64x64_preferred_2x4x1_fallback_2x2x1) { + using ElementAct = cutlass::half_t; + using ElementFlt = cutlass::half_t; + using ElementOut = float; + using ElementAcc = float; + using ElementCompute = float; + using MmaTileShape = Shape<_64, Shape<_64>, Shape<_64>>; + using ClusterShape = decltype(make_shape(int(0), int(0), Int<1>{})); + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, ElementCompute, + ElementAct, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + ElementOut, cutlass::layout::TensorKCSRT, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::conv::Operator::kWgrad, + ElementAct, cutlass::layout::TensorNDHWC, 8, + ElementFlt, cutlass::layout::TensorNDHWC, 8, + ElementAcc, + MmaTileShape, ClusterShape, + cutlass::conv::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::conv::collective::KernelScheduleAuto + >::CollectiveOp; + + using ConvKernel = cutlass::conv::kernel::ConvUniversal< + CollectiveMainloop, + CollectiveEpilogue + >; + + using Conv = cutlass::conv::device::ConvUniversalAdapter; + + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/cute/core/CMakeLists.txt b/test/unit/cute/core/CMakeLists.txt index 4469f43e..71286671 100644 --- a/test/unit/cute/core/CMakeLists.txt +++ b/test/unit/cute/core/CMakeLists.txt @@ -50,6 +50,6 @@ cutlass_test_unit_add_executable( pointer.cpp reverse.cpp swizzle_layout.cpp - transform.cpp + tensor_algs.cpp tuple.cpp ) diff --git a/test/unit/cute/core/tensor_algs.cpp b/test/unit/cute/core/tensor_algs.cpp new file mode 100644 index 00000000..20f4666c --- /dev/null +++ b/test/unit/cute/core/tensor_algs.cpp @@ -0,0 +1,200 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include +#include + +TEST(CuTe_algorithm, TensorTransform) { + using namespace cute; + complex array[4] = {{0,0}, {1,0}, {0,1}, {1,1}}; + complex correct[4] = {{0,0}, {1,0}, {0,-1}, {1,-1}}; + Tensor tensor = make_tensor(static_cast*>(array), make_layout(make_shape(4))); + conjugate conj; + transform(tensor, conj); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(tensor(i), correct[i]); + } +} + +TEST(CuTe_algorithm, TensorBatchReduce) { + using namespace cute; + + int src_vals[16] = {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}; + Tensor src_tensor = make_tensor(static_cast(src_vals), + make_layout(make_shape (make_shape (2,2), make_shape (2,2)), + make_stride(make_stride(2,8), make_stride(1,4)))); + + array dst_vals; + fill(dst_vals, 0); + Tensor dst_tensor = make_tensor(dst_vals.begin(), make_shape(2,2)); + + batch_reduce(src_tensor, dst_tensor); + + int correct[4] = {20,24,36,40}; + for (int i = 0; i < 4; ++i) { + //printf("%d %d\n", dst_tensor(i), correct[i]); + EXPECT_EQ(dst_tensor(i), correct[i]); + } +} + + +TEST(CuTe_algorithm, TensorLogicalReduce) { + using namespace cute; + + { // Reduce each column of a matrix + Tensor src_tensor = make_tensor(counting_iterator{0}, + Layout>, + Stride< _1, Stride<_64,_1>>>{}); + auto slicer = make_coord(0_c, _); + Tensor dst_tensor = make_tensor_like(src_tensor(slicer)); + + logical_reduce(src_tensor, dst_tensor, slicer); + + for (int i = 0; i < size(dst_tensor); ++i) { + EXPECT_EQ(dst_tensor(i), reduce(src_tensor(_,i), int(0))); + } + } + + { // Reduce each row of a matrix + Tensor src_tensor = make_tensor(counting_iterator{0}, + Layout>, + Stride< _1, Stride<_64,_1>>>{}); + auto slicer = make_coord(_, 0_c); + Tensor dst_tensor = make_tensor_like(src_tensor(slicer)); + + logical_reduce(src_tensor, dst_tensor, slicer); + + for (int i = 0; i < size(dst_tensor); ++i) { + EXPECT_EQ(dst_tensor(i), reduce(src_tensor(i,_), int(0))); + } + } + + { // 1 profile + Tensor src_tensor = make_tensor(counting_iterator{0}, + Layout, Stride<_1>>{}); + array dst_vals; + fill(dst_vals, 0); + Tensor dst_tensor = make_tensor(dst_vals.begin(), Layout<_1,_0>{}); + + logical_reduce(src_tensor, dst_tensor, 1); + + for (int i = 0; i < size(dst_tensor); ++i) { + EXPECT_EQ(dst_tensor(i), reduce(src_tensor, int(0))); + } + } + + { // _ profile + Tensor src_tensor = make_tensor(counting_iterator{0}, + Layout, Stride<_1>>{}); + auto slicer = _; + Tensor dst_tensor = make_tensor_like(src_tensor(slicer)); + + logical_reduce(src_tensor, dst_tensor, slicer); + + for (int i = 0; i < size(dst_tensor); ++i) { + EXPECT_EQ(dst_tensor(i), src_tensor(i)); + } + } + + { // (1,1) profile + Tensor src_tensor = make_tensor(counting_iterator{0}, + Layout>, + Stride< _1, Stride<_192,_32>>>{}); + auto slicer = make_coord(1, 1); + array dst_vals; + fill(dst_vals, 0); + Tensor dst_tensor = make_tensor(dst_vals.begin(), Layout<_1,_0>{}); + + logical_reduce(src_tensor, dst_tensor, slicer); + + for (int i = 0; i < size(dst_tensor); ++i) { + EXPECT_EQ(dst_tensor(i), reduce(src_tensor, int(0))); + } + } + + { // (_,_) profile + Tensor src_tensor = make_tensor(counting_iterator{0}, + Layout>, + Stride< _1, Stride<_192,_32>>>{}); + auto slicer = make_coord(_,_); + Tensor dst_tensor = make_tensor_like(src_tensor(slicer)); + + logical_reduce(src_tensor, dst_tensor, slicer); + + for (int i = 0; i < size(dst_tensor); ++i) { + EXPECT_EQ(dst_tensor(i), src_tensor(i)); + } + } + + { + Tensor src_tensor = make_tensor(counting_iterator{0}, + make_layout(make_shape (2,2,2,2), + make_stride(1,2,4,8))); + + array dst_vals; + fill(dst_vals, 0); + Tensor dst_tensor = make_tensor(dst_vals.begin(), make_shape(2,2)); + + auto target_profile = make_coord(_,1,_,1); + + logical_reduce(src_tensor, dst_tensor, target_profile); + + int correct[4] = {20,24,36,40}; + for (int i = 0; i < 4; ++i) { + //printf("%d %d\n", dst_tensor(i), correct[i]); + EXPECT_EQ(dst_tensor(i), correct[i]); + } + } + + { + Tensor src_tensor = make_tensor(counting_iterator{0}, + make_layout(make_shape (2,make_shape (2,2),2), + make_stride(1,make_stride(2,4),8))); + + array dst_vals; + fill(dst_vals, 0); + Tensor dst_tensor = make_tensor(dst_vals.begin(), make_shape(2,2)); + + auto target_profile = make_coord(_,make_coord(1,_),1); + + logical_reduce(src_tensor, dst_tensor, target_profile); + + int correct[4] = {20,24,36,40}; + for (int i = 0; i < 4; ++i) { + //printf("%d %d\n", dst_tensor(i), correct[i]); + EXPECT_EQ(dst_tensor(i), correct[i]); + } + } + +} diff --git a/test/unit/cute/core/transform.cpp b/test/unit/cute/core/transform.cpp deleted file mode 100644 index 81801d30..00000000 --- a/test/unit/cute/core/transform.cpp +++ /dev/null @@ -1,49 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#include "cutlass_unit_test.h" - -#include -#include -#include - -TEST(CuTe_core, Transform) { - using namespace cute; - complex array[4] = {{0,0}, {1,0}, {0,1}, {1,1}}; - complex correct[4] = {{0,0}, {1,0}, {0,-1}, {1,-1}}; - auto tensor = make_tensor(static_cast*>(array), make_layout(make_shape(4))); - conjugate conj; - transform(tensor, conj); - for (int i = 0; i < 4; ++i) - { - EXPECT_EQ(tensor(i), correct[i]); - } -} diff --git a/test/unit/gemm/device/sm100_gemm_i8_i8_i8_tensor_op_s32_ptr_array.cu b/test/unit/gemm/device/sm100_gemm_i8_i8_i8_tensor_op_s32_ptr_array.cu index a0629fd2..8c86493d 100644 --- a/test/unit/gemm/device/sm100_gemm_i8_i8_i8_tensor_op_s32_ptr_array.cu +++ b/test/unit/gemm/device/sm100_gemm_i8_i8_i8_tensor_op_s32_ptr_array.cu @@ -54,7 +54,7 @@ using namespace cute; -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) +#if (defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && !defined(CUTLASS_SM100_FAMILY_ARCHS_ENABLED)) /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////// 128x64x128 Cluster1x1x1 TMEM 4x1 //////////////////////////////////////////// @@ -263,5 +263,5 @@ TEST(SM100_Device_Gemm_s8t_s8n_s8n_tensorop_2cta_s32_ptr_array, 128x1024x128_2x4 EXPECT_TRUE(pass); } -#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) && !defined(CUTLASS_SM100_FAMILY_ARCHS_ENABLED) diff --git a/test/unit/nvrtc/thread/nvrtc_contraction.cu b/test/unit/nvrtc/thread/nvrtc_contraction.cu index fa079434..bd742c69 100644 --- a/test/unit/nvrtc/thread/nvrtc_contraction.cu +++ b/test/unit/nvrtc/thread/nvrtc_contraction.cu @@ -51,6 +51,9 @@ TEST(SM90_nvrtc_kernel, Contraction) { "-std=c++17", "-arch=sm_90", "-I" CUDA_INCLUDE_DIR, +#if (__CUDACC_VER_MAJOR__ >= 13) + "-I" CUDA_INCLUDE_DIR "/cccl", +#endif // __CUDACC_VER_MAJOR__ >= 13 }; EXPECT_TRUE(test::nvrtc::thread::TestbedKernel::compile( @@ -60,7 +63,7 @@ TEST(SM90_nvrtc_kernel, Contraction) { "cute::Shape," "true, true," "10, 10, 10, 10>::Kernel", - { nvrtc_opts, nvrtc_opts + 5 } + { std::begin(nvrtc_opts), std::end(nvrtc_opts) } )); } #endif diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index ceb420f2..6764d9a6 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -576,6 +576,11 @@ struct GemmGroupedArguments { gemm::GemmCoord cluster_shape{}; gemm::GemmCoord cluster_shape_fallback{}; + library::RasterOrder raster_order{}; + library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; + library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; + int swizzle_size{1}; + // these should really be in the configuration but staying consistent with GEMM int sm_count{0}; int max_active_clusters{0}; diff --git a/tools/library/src/grouped_gemm_operation_3x.hpp b/tools/library/src/grouped_gemm_operation_3x.hpp index 7c67ae3a..91f618d4 100644 --- a/tools/library/src/grouped_gemm_operation_3x.hpp +++ b/tools/library/src/grouped_gemm_operation_3x.hpp @@ -64,6 +64,13 @@ public: using CollectiveEpilogue = typename Operator::CollectiveEpilogue; using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + GroupedGemmOperation3xBase(char const* name = "unknown_gemm") : GemmOperation3xBase(name, GemmKind::kGrouped) { this->description_.kind = OperationKind::kGroupedGemm; @@ -152,8 +159,65 @@ protected: arguments.problem_sizes_3x, arguments.pointer_mode == ScalarPointerMode::kHost ? arguments.problem_sizes_3x_host : nullptr}; - operator_args.mainloop.ptr_A = static_cast(arguments.ptr_A); - operator_args.mainloop.ptr_B = static_cast(arguments.ptr_B); + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments.ptr_A); + operator_args.mainloop.ptr_B = static_cast(arguments.ptr_B); + + using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA; + using RuntimeDataTypeB = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeB; + + static_assert(cute::is_same_v, + "RuntimeDataTypeA/B should be identical, either MXF8F6F4Format or MXF4Format"); + using RuntimeDatatypeArg = RuntimeDataTypeA; + + auto mapping = [](RuntimeDatatype type) { + if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE5M2) { + return cute::UMMA::MXF8F6F4Format::E5M2; + } + else if (type == RuntimeDatatype::kE4M3) { + return cute::UMMA::MXF8F6F4Format::E4M3; + } + else if (type == RuntimeDatatype::kE3M2) { + return cute::UMMA::MXF8F6F4Format::E3M2; + } + else if (type == RuntimeDatatype::kE2M3) { + return cute::UMMA::MXF8F6F4Format::E2M3; + } + else if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF8F6F4Format::E2M1; + } + else { + #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && CUTLASS_DEBUG_TRACE_LEVEL >= 1 + std::cerr << "Invalid input datatype specified. Running with e4m3." << std::endl; + #endif + return cute::UMMA::MXF8F6F4Format::E4M3; + } + } + else if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF4Format::E2M1; + } + else { + #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && CUTLASS_DEBUG_TRACE_LEVEL >= 1 + std::cerr << "Invalid input datatype specified. Running with e2m1." << std::endl; + #endif + return cute::UMMA::MXF4Format::E2M1; + } + } + // BlockScaled kernels receive either MXF4Format or MXF8F6F4Format runtime datatype + CUTE_GCC_UNREACHABLE; + }; + operator_args.mainloop.runtime_data_type_a = mapping(arguments.runtime_input_datatype_a); + operator_args.mainloop.runtime_data_type_b = mapping(arguments.runtime_input_datatype_b); + } + else { + operator_args.mainloop.ptr_A = static_cast(arguments.ptr_A); + operator_args.mainloop.ptr_B = static_cast(arguments.ptr_B); + } operator_args.epilogue.ptr_C = static_cast(arguments.ptr_C); operator_args.epilogue.ptr_D = static_cast(arguments.ptr_D); @@ -166,10 +230,29 @@ protected: operator_args.epilogue.dD = static_cast(this->strideD_device.data()); + /* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */ operator_args.hw_info.sm_count = arguments.sm_count; if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { operator_args.hw_info.max_active_clusters = arguments.max_active_clusters; } + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments.swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments.raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { operator_args.hw_info.cluster_shape = dim3(arguments.cluster_shape.m(), arguments.cluster_shape.n(), arguments.cluster_shape.k()); @@ -330,7 +413,6 @@ public: return status; } - // Set arguments that should only be set once before verifying or profiling the kernel. // This should encompass any expensive operations that don't vary from run to run // (e.g., max_active_clusters). @@ -363,9 +445,10 @@ public: cluster_dims, threads_per_block, kernel_ptr); - + if (args->max_active_clusters == 0) { - return Status::kErrorInternal; + std::cerr << "Max Active Clusters could not be queried. " + << "Falling back to heuristics mode (static cluster shape) or preferred cluster mode.\n"; } return Status::kSuccess; diff --git a/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h b/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h index 6a4803c7..72360487 100644 --- a/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h @@ -69,6 +69,12 @@ public: std::vector problem_sizes; std::vector> problem_sizes_3x; + /// For exploration purposes + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; + int cluster_m{1}; int cluster_n{1}; int cluster_k{1}; @@ -83,6 +89,14 @@ public: std::vector alpha; std::vector beta; + cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + + cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; + cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; + + bool use_pdl{false}; + /// Parses the problem Status parse( library::GroupedGemmDescription const& operation_desc, @@ -190,7 +204,7 @@ private: gemm_workspace_.arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}; /* Query device SM count to pass onto the kernel as an argument, where needed */ - arguments.sm_count = options.device.properties[0].multiProcessorCount; + arguments.sm_count = options.device.get_sm_count(0); if (is_block_scaled) { auto& block_scaled_ws = gemm_workspace_.block_scales.value(); arguments.SFA = block_scaled_ws.SFA_ptr_array_device[0]->data(); @@ -272,6 +286,15 @@ protected: library::GroupedGemmDescription const& operation_desc, ProblemSpace const& problem_space); + /// Update performance result configuration for exploration parameters + void update_result_( + PerformanceResult &result, + ProblemSpace const &problem_space, + cutlass::library::RasterOrder const &raster_order, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + int swizzle_size); + /// Verifies CUTLASS against host and device references bool verify_with_reference_( Options const& options, @@ -292,6 +315,12 @@ protected: void* host_workspace, void* device_workspace) override; + /// Method to profile a CUTLASS Operation for the best configuration for a fixed shape + bool profile_cutlass_for_fixed_shape_( + Options const& options, + library::Operation const* operation, + ProblemSpace const& problem_space); + }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/profiler/include/cutlass/profiler/options.h b/tools/profiler/include/cutlass/profiler/options.h index 28b1e2f8..0800e440 100644 --- a/tools/profiler/include/cutlass/profiler/options.h +++ b/tools/profiler/include/cutlass/profiler/options.h @@ -94,10 +94,15 @@ public: /// Total memory allocation on each device size_t maximum_capacity; + private: + /// SM Count + /// Limits the number of SMs to use on each device + int sm_count; + // // Methods // - + public: explicit Device(CommandLine const &cmdline); void print_usage(std::ostream &out) const; @@ -107,7 +112,10 @@ public: /// Returns the device ID from a device index int device_id(size_t device_index) const; - /// Returns the compute capability of the listed devices (e.g. 61, 60, 70, 75) + /// Returns the sm_count if set, otherwise returns the number of SMs on the device + int get_sm_count(int device_index) const; + + /// Returns the compute capability of the listed devices (e.g. 70, 75, 80, etc.) int compute_capability(int device_index) const; }; diff --git a/tools/profiler/src/block_scaled_gemm_operation_profiler.cu b/tools/profiler/src/block_scaled_gemm_operation_profiler.cu index bb8940b4..078acc96 100644 --- a/tools/profiler/src/block_scaled_gemm_operation_profiler.cu +++ b/tools/profiler/src/block_scaled_gemm_operation_profiler.cu @@ -907,7 +907,7 @@ Status BlockScaledGemmOperationProfiler::initialize_workspace( gemm_workspace_.arguments.use_pdl = problem_.use_pdl; /* Query device SM count to pass onto the kernel as an argument, where needed */ - gemm_workspace_.arguments.sm_count = options.device.properties[0].multiProcessorCount; + gemm_workspace_.arguments.sm_count = options.device.get_sm_count(0); } // diff --git a/tools/profiler/src/blockwise_gemm_operation_profiler.cu b/tools/profiler/src/blockwise_gemm_operation_profiler.cu index 5716869d..4a8e2543 100644 --- a/tools/profiler/src/blockwise_gemm_operation_profiler.cu +++ b/tools/profiler/src/blockwise_gemm_operation_profiler.cu @@ -749,7 +749,7 @@ Status BlockwiseGemmOperationProfiler::initialize_workspace( gemm_workspace_.arguments.use_pdl = problem_.use_pdl; /* Query device SM count to pass onto the kernel as an argument, where needed */ - gemm_workspace_.arguments.sm_count = options.device.properties[0].multiProcessorCount; + gemm_workspace_.arguments.sm_count = options.device.get_sm_count(0); } // diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index b0a72553..60088075 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -977,7 +977,7 @@ Status GemmOperationProfiler::initialize_workspace( gemm_workspace_[i].arguments.batch_stride_D = gemm_workspace_[i].Computed->batch_stride(); /* Query device SM count to pass onto the kernel as an argument, where needed */ - gemm_workspace_[i].arguments.sm_count = options.device.properties[i].multiProcessorCount; + gemm_workspace_[i].arguments.sm_count = options.device.get_sm_count(i); gemm_workspace_[i].arguments.device_index = static_cast(i); } } diff --git a/tools/profiler/src/grouped_gemm_operation_profiler.cu b/tools/profiler/src/grouped_gemm_operation_profiler.cu index ff702f6a..e7b02215 100644 --- a/tools/profiler/src/grouped_gemm_operation_profiler.cu +++ b/tools/profiler/src/grouped_gemm_operation_profiler.cu @@ -108,6 +108,14 @@ GroupedGemmOperationProfiler::GroupedGemmOperationProfiler(Options const& option {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta (applied to all GEMMs in group)."}, + {ArgumentTypeID::kEnumerated, {"runtime_input_datatype_a", "runtime-input-datatype::a"}, + "Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"}, + {ArgumentTypeID::kEnumerated, {"runtime_input_datatype_b", "runtime-input-datatype::b"}, + "Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"}, + {ArgumentTypeID::kEnumerated, {"raster_order", "raster-order"}, + "Raster order (heuristic, along_n, along_m)"}, + {ArgumentTypeID::kInteger, {"swizzle_size", "swizzle-size"}, "Size to swizzle"}, + {ArgumentTypeID::kEnumerated, {"use_pdl", "use_pdl"}, "Use PDL (true, false)"}, {ArgumentTypeID::kScalar, {"problem-sizes"}, "MxNxK Problem sizes for the grouped GEMM, where a group is enclosed by `[]`. E.g. " @@ -236,6 +244,9 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse( if (!file.good()) { throw std::runtime_error("Failed to open file: " + problem_file); } + // clear the problem sizes and 3x problem sizes from previous operation + problem_sizes.clear(); + problem_sizes_3x.clear(); for (std::string line; std::getline(file, line);) { std::istringstream iss(line); @@ -257,7 +268,7 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse( if (!arg_as_int(this->cluster_m, "cluster_m", problem_space, problem)) { // default value - this->cluster_m = 1; + this->cluster_m = std::string(operation_desc.gemm.name).find("_2sm") != std::string::npos ? 2 : 1; } if (!arg_as_int(this->cluster_n, "cluster_n", problem_space, problem)) { @@ -272,17 +283,17 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse( if (!arg_as_int(this->cluster_m_fallback, "cluster_m_fallback", problem_space, problem)) { // default value - this->cluster_m_fallback = 0; + this->cluster_m_fallback = std::string(operation_desc.gemm.name).find("_2sm") != std::string::npos ? 2 : 1; } if (!arg_as_int(this->cluster_n_fallback, "cluster_n_fallback", problem_space, problem)) { // default value - this->cluster_n_fallback = 0; + this->cluster_n_fallback = 1; } if (!arg_as_int(this->cluster_k_fallback, "cluster_k_fallback", problem_space, problem)) { // default value - this->cluster_k_fallback = 0; + this->cluster_k_fallback = 1; } this->mode = library::GemmUniversalMode::kGrouped; @@ -303,6 +314,31 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse( return Status::kErrorInvalidProblem; } + if (!arg_as_bool(this->use_pdl, "use_pdl", problem_space, problem)) { + // default value + this->use_pdl = false; + } + + if (!arg_as_RuntimeDatatype(this->runtime_input_datatype_a, "runtime_input_datatype_a", problem_space, problem)) { + // default value + this->runtime_input_datatype_a = cutlass::library::RuntimeDatatype::kStatic; + } + + if (!arg_as_RuntimeDatatype(this->runtime_input_datatype_b, "runtime_input_datatype_b", problem_space, problem)) { + // default value + this->runtime_input_datatype_b = cutlass::library::RuntimeDatatype::kStatic; + } + + if (!arg_as_int(this->swizzle_size, "swizzle_size", problem_space, problem)) { + // default value + this->swizzle_size = 1; + } + + if (!arg_as_RasterOrder(this->raster_order, "raster_order", problem_space, problem)) { + // default value + this->raster_order = library::RasterOrder::kHeuristic; + } + if (!arg_as_scalar( this->alpha, operation_desc.gemm.element_epilogue, @@ -348,6 +384,19 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse( .front(); } + // instantiation for exploration profiling + this->raster_orders = { + cutlass::library::RasterOrder::kAlongN, + cutlass::library::RasterOrder::kAlongM + }; + this->swizzle_sizes = {1, 2, 4, 8}; + this->preferred_clusters = { + {1, 1, 1}, {2, 1, 1}, {2, 2, 1}, {4, 1, 1}, {4, 2, 1}, {4, 4, 1}, {8, 2, 1} + }; + this->fallback_clusters = { + {1, 1, 1}, {2, 1, 1}, {2, 2, 1} + }; + return Status::kSuccess; } @@ -469,6 +518,13 @@ void GroupedGemmOperationProfiler::GroupedGemmProblem::initialize_result( set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback); set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback); + set_argument(result, "raster_order", problem_space, library::to_string(raster_order)); + set_argument(result, "swizzle_size", problem_space, swizzle_size); + set_argument(result, "use_pdl", problem_space, library::to_string(use_pdl)); + + set_argument(result, "runtime_input_datatype_a", problem_space, library::to_string(runtime_input_datatype_a)); + set_argument(result, "runtime_input_datatype_b", problem_space, library::to_string(runtime_input_datatype_b)); + set_argument( result, "alpha", @@ -482,6 +538,25 @@ void GroupedGemmOperationProfiler::GroupedGemmProblem::initialize_result( library::lexical_cast(beta, operation_desc.gemm.element_epilogue)); } +void GroupedGemmOperationProfiler::update_result_( + PerformanceResult &result, + ProblemSpace const &problem_space, + cutlass::library::RasterOrder const &raster_order, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + int swizzle_size +) { + set_argument(result, "raster_order", problem_space, library::to_string(raster_order)); + set_argument(result, "swizzle_size", problem_space, swizzle_size); + + set_argument(result, "cluster_m", problem_space, preferred_cluster[0]); + set_argument(result, "cluster_n", problem_space, preferred_cluster[1]); + set_argument(result, "cluster_k", problem_space, preferred_cluster[2]); + set_argument(result, "cluster_m_fallback", problem_space, fallback_cluster[0]); + set_argument(result, "cluster_n_fallback", problem_space, fallback_cluster[1]); + set_argument(result, "cluster_k_fallback", problem_space, fallback_cluster[2]); +} + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Extracts the problem dimensions @@ -506,7 +581,6 @@ Status GroupedGemmOperationProfiler::initialize_configuration( std::string blockwise_regex_string = sf_tuple + "(" + datatypes_regex + ")x(" + datatypes_regex + ")_" + sf_tuple + "(" + datatypes_regex + ")x(" + datatypes_regex + ")"; - if (std::string(operation_desc.gemm.name).find("bstensor") != std::string::npos) { is_block_scaled = true; @@ -538,6 +612,14 @@ Status GroupedGemmOperationProfiler::initialize_configuration( config.ldc = problem_.ldc.data(); config.problem_sizes_3x_host = problem_.problem_sizes_3x.data(); + gemm_workspace_.arguments.swizzle_size = problem_.swizzle_size; + gemm_workspace_.arguments.raster_order = problem_.raster_order; + + gemm_workspace_.arguments.runtime_input_datatype_a = problem_.runtime_input_datatype_a; + gemm_workspace_.arguments.runtime_input_datatype_b = problem_.runtime_input_datatype_b; + + gemm_workspace_.arguments.use_pdl = problem_.use_pdl; + initialize_result_(this->model_result_, options, operation_desc, problem_space); return status; @@ -1000,6 +1082,25 @@ bool GroupedGemmOperationProfiler::verify_cutlass( auto const& desc = static_cast(operation->description()); + cutlass::library::RuntimeDatatype runtime_datatype_a = gemm_workspace_.arguments.runtime_input_datatype_a; + cutlass::library::RuntimeDatatype runtime_datatype_b = gemm_workspace_.arguments.runtime_input_datatype_b; + + bool is_runtime_datatype_a = runtime_datatype_a != cutlass::library::RuntimeDatatype::kStatic; + bool is_runtime_datatype_b = runtime_datatype_b != cutlass::library::RuntimeDatatype::kStatic; + + assert(is_runtime_datatype_a == is_runtime_datatype_b && "runtime datatype should be both dynamic or static."); + + cutlass::library::NumericTypeID element_A = desc.gemm.A.element; + cutlass::library::NumericTypeID element_B = desc.gemm.B.element; + + if (is_runtime_datatype_a) { + element_A = cutlass::library::dynamic_datatype_to_id(runtime_datatype_a); + } + + if (is_runtime_datatype_b) { + element_B = cutlass::library::dynamic_datatype_to_id(runtime_datatype_b); + } + bool verification_status = verify_with_reference_( options, report, @@ -1007,8 +1108,8 @@ bool GroupedGemmOperationProfiler::verify_cutlass( operation, problem_space, problem, - desc.gemm.A.element, - desc.gemm.B.element); + element_A, + element_B); // Update disposition to worst case verification outcome among all // verification providers which are supported @@ -1442,13 +1543,23 @@ bool GroupedGemmOperationProfiler::profile( ProblemSpace::Problem const& problem) { if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { - results_.back().status = profile_cutlass_( - results_.back(), - options, - operation, - &gemm_workspace_.arguments, - gemm_workspace_.host_workspace.data(), - gemm_workspace_.device_workspace.data()); + if (options.profiling.enable_kernel_performance_search) { + std::cerr << "Exhaustive performance search is not available for Grouped GEMMs. " + << "Please use --enable-best-kernel-for-fixed-shape to profile a specific problem size " + << "with --problem-sizes or --problem-sizes-file.\n"; + } + else if (options.profiling.enable_best_kernel_for_fixed_shape) { + return profile_cutlass_for_fixed_shape_(options, operation, problem_space); + } + else { + results_.back().status = profile_cutlass_( + results_.back(), + options, + operation, + &gemm_workspace_.arguments, + gemm_workspace_.host_workspace.data(), + gemm_workspace_.device_workspace.data()); + } } return true; } @@ -1463,7 +1574,6 @@ Status GroupedGemmOperationProfiler::profile_cutlass_( void* arguments, void* host_workspace, void* device_workspace) { - library::Operation const* underlying_operation = operation; results_.back().status = underlying_operation->initialize_with_arguments(&gemm_workspace_.arguments); if (results_.back().status != Status::kSuccess) { @@ -1487,6 +1597,97 @@ Status GroupedGemmOperationProfiler::profile_cutlass_( ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Method to profile a CUTLASS Operation for the best configuration for a fixed shape +bool GroupedGemmOperationProfiler::profile_cutlass_for_fixed_shape_( + Options const& options, + library::Operation const* operation, + ProblemSpace const& problem_space) { + library::GroupedGemmDescription const &operation_desc = + static_cast(operation->description()); + + auto min_cc = operation_desc.tile_description.minimum_compute_capability; + + bool is_dynamic_cluster_enabled = (min_cc >= 100); + + // Helper function to test validity of fallback cluster shapes and preferred cluster shapes. + auto is_valid_dynamic_cluster_shape = [](const std::array& preferred_cluster, const std::array& fallback_cluster) { + for (size_t i = 0; i < 3; ++i) { + if (preferred_cluster[i] % fallback_cluster[i] != 0) { + return false; + } + } + return true; + }; + + // Helper function to select the best performance number among a list. + auto select_best_candidate = [&](std::vector &candidates) { + assert(!candidates.empty() && "Candidates vector should not be empty"); + auto best_iter = std::max_element( + candidates.begin(), candidates.end(), + [](PerformanceResult const &a, PerformanceResult const &b) { + return a.gflops_per_sec() < b.gflops_per_sec(); + } + ); + assert(best_iter != candidates.end() && "No candidate found despite non-empty candidates vector"); + results_.push_back(std::move(*best_iter)); + }; + + std::vector candidates; + PerformanceResult result_base = results_.back(); + results_.pop_back(); + + bool dynamic_cluster = int64_t(operation_desc.tile_description.cluster_shape.m()) == 0 || + int64_t(operation_desc.tile_description.cluster_shape.n()) == 0 || + int64_t(operation_desc.tile_description.cluster_shape.k()) == 0; + + std::vector> preferred_clusters; + std::vector> fallback_clusters; + + // Only loop over built-in cluster shape lists for dynamic cluster kernels + // and for kernels that can leverage the dynamic cluster feature. + if (dynamic_cluster && is_dynamic_cluster_enabled) { + preferred_clusters = this->problem_.preferred_clusters; + fallback_clusters = this->problem_.fallback_clusters; + } + else { + preferred_clusters = {{int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)}}; + fallback_clusters = {{int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}}; + } + + for (auto preferred_cluster : preferred_clusters) { + for (auto fallback_cluster : fallback_clusters) { + if (dynamic_cluster && !is_valid_dynamic_cluster_shape(preferred_cluster, fallback_cluster)) { + continue; + } + for (auto swizzle_size : this->problem_.swizzle_sizes) { + for (auto raster_order : this->problem_.raster_orders) { + PerformanceResult curr_result(result_base); + update_result_(curr_result, problem_space, raster_order, preferred_cluster, fallback_cluster, swizzle_size); + curr_result.status = profile_cutlass_( + curr_result, + options, + operation, + &gemm_workspace_.arguments, + gemm_workspace_.host_workspace.data(), + gemm_workspace_.device_workspace.data() + ); + if (curr_result.status == Status::kSuccess) { // Only add valid results + candidates.push_back(curr_result); + } + }// for raster_order + }// for swizzle_size + }// for fallback_cluster + }// for preferred_clusters + + if (candidates.empty()) { + return false; + } + select_best_candidate(candidates); + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace profiler } // namespace cutlass diff --git a/tools/profiler/src/options.cu b/tools/profiler/src/options.cu index 4b27c64f..d711c8f1 100644 --- a/tools/profiler/src/options.cu +++ b/tools/profiler/src/options.cu @@ -141,9 +141,18 @@ Options::Device::Device(cutlass::CommandLine const &cmdline) { } } + // Permit overriding the sm_count + cmdline.get_cmd_line_argument("sm-count", sm_count, 0); } } +int Options::Device::get_sm_count(int device_index) const { + if (sm_count <= 0) { + return properties[device_index].multiProcessorCount; + } + return sm_count; +} + void Options::Device::print_usage(std::ostream &out) const { out << "Device:\n" @@ -185,7 +194,12 @@ void Options::Device::print_usage(std::ostream &out) const { << " --llc-capacity= " << " Capacity of last-level cache in kilobytes. If this is non-zero," << end_of_line << " profiling phases cycle through different input tensors to induce" << end_of_line - << " capacity misses in the L2.\n\n"; + << " capacity misses in the L2.\n\n" + + << " --sm-count= " + << " Override the number of SMs. This is used to limit the number of " << end_of_line + << " during profiling. If this is set, profiling attempts to limit the sm_count " << end_of_line + << " to user-set value. This is not possible on all architectures and all kernel types. \n\n"; }