From c975e2ccbb2dbf13024568b37ffa3498ed0b3aed Mon Sep 17 00:00:00 2001 From: Aditya Atluri Date: Sat, 19 Nov 2022 06:02:15 -0800 Subject: [PATCH] release 2.11 (#703) --- .github/labeler.yml | 18 - .github/workflows/labeler.yml | 3 +- CHANGELOG.md | 27 +- CITATION.cff | 10 +- CMakeLists.txt | 11 +- CONTRIBUTORS.md | 88 +- README.md | 100 +- examples/03_visualize_layout/CMakeLists.txt | 2 - .../03_visualize_layout/register_layout.cu | 10 +- examples/05_batched_gemm/batched_gemm.cu | 4 +- .../ampere_gemm_operand_reduction_fusion.cu | 10 +- examples/30_wgrad_split_k/30_wgrad_split_k.cu | 6 +- .../gemm_with_epilogue_visitor.h | 6 - .../gemm_with_epilogue_visitor.h | 6 - examples/39_gemm_permute/gemm_permute.cu | 4 +- examples/40_cutlass_py/README.md | 245 +- examples/40_cutlass_py/conv2d.py | 317 +-- examples/40_cutlass_py/customizable/README.md | 192 ++ examples/40_cutlass_py/customizable/conv2d.py | 320 +++ examples/40_cutlass_py/customizable/gemm.py | 445 ++++ .../customizable/gemm_grouped.py | 287 +++ .../grouped_gemm_problem_size.csv | 0 examples/40_cutlass_py/gemm.py | 427 +--- examples/40_cutlass_py/gemm_grouped.py | 260 +- examples/40_cutlass_py/util.py | 60 + .../CMakeLists.txt | 44 + .../attention_scaling_coefs_updater.h | 31 + .../debug_utils.h | 31 + .../default_fmha_grouped.h | 284 +++ .../epilogue_pipelined.h | 0 .../epilogue_rescale_output.h | 31 + .../epilogue_thread_apply_logsumexp.h | 0 .../find_default_mma.h | 33 +- .../fmha_grouped.h | 839 +++++++ .../fmha_grouped_problem_visitor.h | 178 ++ ...fused_multihead_attention_fixed_seqlen.cu} | 128 +- ...ed_multihead_attention_variable_seqlen.cu} | 662 +++--- .../gemm/custom_mma.h | 31 + .../gemm/custom_mma_base.h | 0 .../gemm/custom_mma_multistage.h | 0 .../gemm/custom_mma_pipelined.h | 0 .../gemm_kernel_utils.h | 31 + .../epilogue_predicated_tile_iterator.h | 0 .../iterators/make_residual_last.h | 97 + ...cated_tile_access_iterator_residual_last.h | 0 .../predicated_tile_iterator_residual_last.h | 0 .../kernel_forward.h | 31 + .../mma_from_smem.h | 0 .../41_multi_head_attention/gemm_attention.h | 626 ----- .../gemm_grouped_with_softmax_visitor.h | 522 ---- .../CMakeLists.txt | 4 +- .../ampere_tensorop_group_conv.cu | 706 ++++++ .../iterators/make_residual_last.h | 66 - .../CMakeLists.txt | 6 +- .../ell_block_sparse_gemm.cu | 740 ++++++ .../44_multi_gemm_ir_and_codegen/README.md | 63 + .../44_multi_gemm_ir_and_codegen/config.json | 32 + .../default_bias_act_epilogue_tensor_op.h | 154 ++ ...ault_thread_map_tensor_op_for_fused_bias.h | 113 + .../threadblock/fused_bias_act_epilogue.h | 222 ++ .../output_tile_thread_map_for_fused_bias.h | 311 +++ ...sed_bias_act_fragment_iterator_tensor_op.h | 189 ++ ...r_op_fragment_iterator_without_output_op.h | 427 ++++ .../ir_gen/gen_all_code.py | 129 + .../ir_gen/gen_cmake.py | 131 + .../ir_gen/gen_customized_epilogue.py | 120 + .../ir_gen/gen_device.py | 477 ++++ .../ir_gen/gen_ir.py | 249 ++ .../ir_gen/gen_kernel.py | 476 ++++ .../ir_gen/gen_sample.py | 232 ++ .../ir_gen/gen_threadblock.py | 1013 ++++++++ .../ir_gen/gen_turing_and_volta.py | 456 ++++ .../ir_gen/gen_verify.py | 92 + .../ir_gen/generate.sh | 52 + .../ir_gen/helper.py | 135 ++ .../ir_gen/replace_fix_impl_header.py | 67 + .../44_multi_gemm_ir_and_codegen/leaky_bias.h | 292 +++ examples/44_multi_gemm_ir_and_codegen/utils.h | 94 + .../CMakeLists.txt | 2 +- .../device/dual_gemm.h | 0 .../dual_gemm.cu | 0 .../dual_gemm_run.h | 0 .../kernel/dual_gemm.h | 0 .../{43_dual_gemm => 45_dual_gemm}/test_run.h | 0 .../thread/left_silu_and_mul.h
| 0 .../threadblock/dual_epilogue.h | 0 .../threadblock/dual_mma_base.h | 0 .../threadblock/dual_mma_multistage.h | 0 .../CMakeLists.txt | 36 + .../depthwise_simt_conv2dfprop.cu | 672 ++++++ examples/CMakeLists.txt | 8 +- include/cutlass/aligned_buffer.h | 6 +- include/cutlass/arch/arch.h | 4 + include/cutlass/arch/memory.h | 2 +- include/cutlass/arch/mma.h | 2 + include/cutlass/arch/mma_sm75.h | 13 +- include/cutlass/arch/mma_sm80.h | 1 + include/cutlass/arch/mma_sm90.h | 131 + include/cutlass/array.h | 1868 +++++++++++++++ include/cutlass/barrier.h | 201 ++ include/cutlass/bfloat16.h | 7 +- include/cutlass/blas3.h | 1 + include/cutlass/block_striped.h | 259 ++ include/cutlass/complex.h | 166 +- include/cutlass/conv/conv2d_problem_size.h | 35 +- include/cutlass/conv/convolution.h | 38 +- .../cutlass/conv/device/direct_convolution.h | 269 +++ .../conv/device/implicit_gemm_convolution.h | 76 +- .../conv/kernel/default_conv2d_group_fprop.h | 270 ++- .../conv/kernel/default_depthwise_fprop.h | 380 ++- .../cutlass/conv/kernel/direct_convolution.h | 505 ++++ include/cutlass/conv/thread/depthwise_mma.h | 325 +++ ...op_filter_tile_access_iterator_optimized.h | 6 +- .../depthwise_direct_conv_params.h | 230 ++ ...erator_direct_conv_fixed_stride_dilation.h | 314 +++ ...le_access_iterator_direct_conv_optimized.h | 291 +++ .../depthwise_fprop_direct_conv_multistage.h | 551 +++++ ...le_access_iterator_direct_conv_optimized.h | 261 ++ .../conv/threadblock/depthwise_mma_base.h | 229 ++ ...depthwise_mma_core_with_lane_access_size.h | 617 ++++- .../conv/threadblock/threadblock_swizzle.h | 24 +- .../cutlass/conv/warp/mma_depthwise_simt.h | 222 +- .../warp/mma_depthwise_simt_tile_iterator.h | 609 ++++- include/cutlass/coord.h | 13 +- include/cutlass/core_io.h | 5 +- include/cutlass/device_kernel.h | 17 + include/cutlass/epilogue/thread/activation.h | 77 +- .../linear_combination_bias_elementwise.h | 3 + .../thread/linear_combination_bias_relu.h | 3 + .../thread/linear_combination_leaky_relu.h | 1 + .../linear_combination_residual_block.h | 143 +- .../linear_combination_residual_block_v2.h | 197 -- .../threadblock/default_epilogue_simt.h | 98 + .../threadblock/default_epilogue_tensor_op.h | 135 ++ .../default_epilogue_with_broadcast_v2.h | 177 -- .../cutlass/epilogue/threadblock/epilogue.h | 442 ++-- .../threadblock/epilogue_base_streamk.h | 191 ++ .../epilogue/threadblock/epilogue_depthwise.h | 335 +++ .../threadblock/epilogue_direct_store.h | 1 - .../threadblock/epilogue_with_broadcast.h | 740 +++++- .../threadblock/epilogue_with_broadcast_v2.h | 847 ------- .../threadblock/epilogue_with_reduction.h | 4 +- .../threadblock/interleaved_epilogue.h | 103 +- .../threadblock/predicated_tile_iterator.h | 74 + .../predicated_tile_iterator_direct_conv.h | 445 ++++ .../predicated_tile_iterator_params.h | 84 + .../threadblock/predicated_tile_iterator_v2.h | 1023 -------- .../threadblock/shared_load_iterator.h | 5 + .../threadblock/shared_load_iterator_mixed.h | 12 + .../shared_load_iterator_pitch_liner.h | 194 ++ .../epilogue/warp/tile_iterator_simt.h | 293 +++ .../epilogue/warp/tile_iterator_tensor_op.h | 15 + .../warp/tile_iterator_tensor_op_mixed.h | 15 + .../warp/tile_iterator_volta_tensor_op.h | 10 + .../warp/tile_iterator_wmma_tensor_op.h | 6 + include/cutlass/fast_math.h | 65 +- include/cutlass/float8.h | 1213 ++++++++++ include/cutlass/floating_point_nvrtc.h | 65 + include/cutlass/functional.h | 2110 +---------------- include/cutlass/gemm/device/base_grouped.h | 4 +- 
.../gemm/device/default_gemm_configuration.h | 46 + include/cutlass/gemm/device/ell_gemm.h | 848 +++++++ include/cutlass/gemm/device/gemm_batched.h | 2 +- include/cutlass/gemm/device/gemm_universal.h | 5 + .../gemm/device/gemm_universal_adapter.h | 8 +- .../cutlass/gemm/device/gemm_universal_base.h | 521 ++-- .../device/gemm_universal_with_broadcast.h | 12 +- .../gemm/device/gemm_with_k_reduction.h | 3 +- .../cutlass/gemm/kernel/default_ell_gemm.h | 837 +++++++ include/cutlass/gemm/kernel/default_gemm.h | 71 + .../gemm/kernel/default_gemm_complex.h | 60 + .../gemm/kernel/default_gemm_universal.h | 53 +- .../kernel/default_gemm_with_broadcast_v2.h | 242 -- include/cutlass/gemm/kernel/default_rank_2k.h | 78 + .../gemm/kernel/default_rank_2k_complex.h | 164 ++ include/cutlass/gemm/kernel/default_rank_k.h | 62 + .../gemm/kernel/default_rank_k_complex.h | 134 ++ include/cutlass/gemm/kernel/default_symm.h | 96 +- .../gemm/kernel/default_symm_complex.h | 194 +- include/cutlass/gemm/kernel/default_trmm.h | 70 + .../gemm/kernel/default_trmm_complex.h | 68 + include/cutlass/gemm/kernel/ell_gemm.h | 830 +++++++ include/cutlass/gemm/kernel/gemm_grouped.h | 7 - .../kernel/gemm_grouped_problem_visitor.h | 17 +- .../gemm_grouped_softmax_mainloop_fusion.h | 7 - .../kernel/gemm_layernorm_mainloop_fusion.h | 187 +- .../cutlass/gemm/kernel/gemm_planar_complex.h | 179 +- .../gemm/kernel/gemm_planar_complex_array.h | 133 +- include/cutlass/gemm/kernel/gemm_universal.h | 198 +- .../gemm/kernel/gemm_universal_streamk.h | 1126 +++++++++ .../gemm/kernel/gemm_with_fused_epilogue.h | 931 +++++++- .../gemm/kernel/gemm_with_fused_epilogue_v2.h | 854 ------- .../gemm/kernel/gemm_with_k_reduction.h | 206 +- .../gemm/kernel/grouped_problem_visitor.h | 8 +- .../gemm/kernel/params_universal_base.h | 245 ++ include/cutlass/gemm/kernel/rank_2k_grouped.h | 7 - .../kernel/rank_2k_grouped_problem_visitor.h | 8 + include/cutlass/gemm/thread/mma_sm50.h | 183 +- .../gemm/threadblock/default_ell_mma.h | 734 ++++++ .../gemm/threadblock/default_mma_core_sm80.h | 3 + ...default_multistage_mma_complex_core_sm80.h | 108 +- .../gemm/threadblock/ell_mma_multistage.h | 642 +++++ .../gemm/threadblock/ell_mma_pipelined.h | 376 +++ .../cutlass/gemm/threadblock/index_remat.h | 107 + .../cutlass/gemm/threadblock/mma_multistage.h | 518 ++-- .../cutlass/gemm/threadblock/mma_pipelined.h | 262 +- .../gemm/threadblock/threadblock_swizzle.h | 58 +- .../threadblock/threadblock_swizzle_streamk.h | 778 ++++++ .../gemm/warp/default_mma_complex_tensor_op.h | 153 +- .../cutlass/gemm/warp/mma_complex_tensor_op.h | 312 ++- .../warp/mma_complex_tensor_op_fast_f32.h | 4 +- .../warp/mma_gaussian_complex_tensor_op.h | 285 ++- .../warp/mma_tensor_op_tile_iterator_wmma.h | 2 +- include/cutlass/half.h | 57 +- include/cutlass/integer_subbyte.h | 3 +- include/cutlass/layout/permute.h | 5 +- include/cutlass/numeric_conversion.h | 845 ++++++- include/cutlass/numeric_types.h | 1 + include/cutlass/platform/platform.h | 1 - include/cutlass/quaternion.h | 77 +- include/cutlass/semaphore.h | 1 - include/cutlass/tfloat32.h | 17 +- .../transform/pitch_linear_thread_map.h | 21 +- .../transform/threadblock/ell_iterator.h | 199 ++ .../ell_predicated_tile_access_iterator.h | 1350 +++++++++++ .../ell_predicated_tile_iterator.h | 1315 ++++++++++ .../predicated_tile_access_iterator.h | 77 +- .../threadblock/predicated_tile_iterator.h | 100 +- ...access_iterator_pitch_linear_direct_conv.h | 587 +++++ include/cutlass/uint128.h | 6 +- include/cutlass/wmma_array.h | 23 +- 
media/docs/efficient_gemm.md | 134 +- test/unit/common/filter_architecture.cpp | 1 + test/unit/conv/device/CMakeLists.txt | 4 +- test/unit/conv/device/conv2d_problems.h | 23 + test/unit/conv/device/conv2d_testbed.h | 28 +- .../conv/device/conv2d_testbed_interleaved.h | 8 +- .../device/conv2d_with_broadcast_testbed.h | 10 +- .../device/conv2d_with_reduction_testbed.h | 10 +- test/unit/conv/device/conv3d_testbed.h | 14 +- .../depthwise_conv2d_direct_conv_testbed.h | 473 ++++ ...v_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu | 426 ++++ ...n_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu | 522 ++++ ..._f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu} | 2 +- ...nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu | 149 ++ test/unit/core/CMakeLists.txt | 1 + test/unit/core/float8.cu | 103 + test/unit/core/numeric_conversion.cu | 315 ++- test/unit/epilogue/thread/activation.cu | 1 + test/unit/gemm/device/CMakeLists.txt | 61 +- .../gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu | 16 +- .../gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu | 32 +- ...mm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu | 12 +- .../gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu | 16 +- ...mm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu | 12 +- ...mm_bf16t_bf16t_bf16t_tensor_op_f32_sm80.cu | 32 +- ...32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu | 2 +- ...32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu | 2 +- ...cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu | 198 ++ ...mm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu | 252 ++ ...cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu | 197 ++ ...mm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu | 305 +++ .../gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu | 39 + ..._f16n_f16t_tensor_op_f16_broadcast_sm80.cu | 440 ++++ .../gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu | 223 ++ .../gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu | 223 ++ .../gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu | 16 +- .../gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu | 16 +- .../hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu | 135 ++ .../her2k_cf64_cf64_tensor_op_f64_sm90.cu | 149 ++ .../herk_cf64_cf64_tensor_op_f64_sm90.cu | 93 + test/unit/gemm/device/multistage_testbed.h | 2 +- test/unit/gemm/device/simt_cgemm_nt_sm80.cu | 265 +++ test/unit/gemm/device/simt_cgemm_tn_sm80.cu | 269 +++ test/unit/gemm/device/simt_f8gemm_tn_sm50.cu | 87 + test/unit/gemm/device/simt_qgemm_nn_sm50.cu | 54 +- test/unit/gemm/device/simt_qgemm_nt_sm50.cu | 54 +- test/unit/gemm/device/simt_qgemm_tn_sm50.cu | 54 +- test/unit/gemm/device/simt_qgemm_tt_sm50.cu | 54 +- .../symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu | 133 ++ .../device/symm_f64_f64_tensor_op_f64_sm90.cu | 135 ++ .../syr2k_cf64_cf64_tensor_op_f64_sm90.cu | 150 ++ .../syr2k_f64_f64_tensor_op_f64_sm90.cu | 134 ++ .../syrk_cf64_cf64_tensor_op_f64_sm90.cu | 136 ++ .../device/syrk_f64_f64_tensor_op_f64_sm90.cu | 126 + test/unit/gemm/device/testbed.h | 67 +- test/unit/gemm/device/testbed_complex.h | 2 +- .../gemm/device/testbed_gemm_with_broadcast.h | 2 +- .../gemm/device/testbed_gemm_with_reduction.h | 2 +- .../gemm/device/testbed_grouped_scheduler.h | 3 +- test/unit/gemm/device/testbed_interleaved.h | 2 +- .../unit/gemm/device/testbed_planar_complex.h | 2 +- .../gemm/device/testbed_rank2k_universal.h | 2 +- .../gemm/device/testbed_rank_k_universal.h | 2 +- test/unit/gemm/device/testbed_sparse.h | 2 +- test/unit/gemm/device/testbed_splitk.h | 2 +- .../unit/gemm/device/testbed_symm_universal.h | 2 +- .../unit/gemm/device/testbed_trmm_universal.h | 2 +- test/unit/gemm/device/testbed_universal.h | 56 +- .../trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu | 137 ++ .../trmm_f64_f64_f64_tensor_op_f64_sm90.cu 
| 127 + .../mma_multistage_sparse_testbed.h | 2 +- test/unit/gemm/warp/CMakeLists.txt | 2 + test/unit/gemm/warp/gemm_complex_sm80.cu | 8 +- test/unit/gemm/warp/gemm_complex_sm90.cu | 334 +++ test/unit/gemm/warp/gemm_sm90.cu | 206 ++ .../include/cutlass/library/arch_mappings.h | 5 + .../library/include/cutlass/library/library.h | 3 + tools/library/scripts/conv2d_operation.py | 143 +- tools/library/scripts/generator.py | 781 +++++- tools/library/scripts/library.py | 76 +- tools/library/scripts/pycutlass/README.md | 11 +- .../gemm/gemm_universal_with_visitor.h | 90 +- .../pycutlass/src/cpp/include/swizzling.h | 16 +- .../pycutlass/src/pycutlass/c_types.py | 4 +- .../pycutlass/src/pycutlass/gemm_operation.py | 29 +- .../pycutlass/src/pycutlass/library.py | 4 - .../pycutlass/test/example/run_all_example.sh | 2 +- tools/library/src/conv2d_operation.h | 261 +- tools/library/src/conv3d_operation.h | 6 +- tools/library/src/gemm_operation.h | 30 +- tools/profiler/CMakeLists.txt | 1 + .../profiler/src/conv2d_operation_profiler.cu | 29 +- .../profiler/src/conv2d_operation_profiler.h | 13 +- tools/profiler/src/cudnn_helpers.cpp | 7 +- tools/profiler/src/cudnn_helpers.h | 2 +- .../include/cutlass/util/host_uncompress.h | 34 + .../util/reference/host/tensor_elementwise.h | 14 +- .../cutlass/util/reference/host/tensor_fill.h | 38 +- 329 files changed, 47332 insertions(+), 10607 deletions(-) delete mode 100644 .github/labeler.yml create mode 100644 examples/40_cutlass_py/customizable/README.md create mode 100644 examples/40_cutlass_py/customizable/conv2d.py create mode 100644 examples/40_cutlass_py/customizable/gemm.py create mode 100644 examples/40_cutlass_py/customizable/gemm_grouped.py rename examples/40_cutlass_py/{ => customizable}/grouped_gemm_problem_size.csv (100%) create mode 100644 examples/40_cutlass_py/util.py create mode 100644 examples/41_fused_multi_head_attention/CMakeLists.txt rename examples/{42_fused_multi_head_attention => 41_fused_multi_head_attention}/attention_scaling_coefs_updater.h (90%) rename examples/{42_fused_multi_head_attention => 41_fused_multi_head_attention}/debug_utils.h (77%) create mode 100644 examples/41_fused_multi_head_attention/default_fmha_grouped.h rename examples/{42_fused_multi_head_attention => 41_fused_multi_head_attention}/epilogue_pipelined.h (100%) rename examples/{42_fused_multi_head_attention => 41_fused_multi_head_attention}/epilogue_rescale_output.h (80%) rename examples/{42_fused_multi_head_attention => 41_fused_multi_head_attention}/epilogue_thread_apply_logsumexp.h (100%) rename examples/{42_fused_multi_head_attention => 41_fused_multi_head_attention}/find_default_mma.h (72%) create mode 100644 examples/41_fused_multi_head_attention/fmha_grouped.h create mode 100644 examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h rename examples/{42_fused_multi_head_attention/fused_multihead_attention.cu => 41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu} (94%) rename examples/{41_multi_head_attention/fused_multihead_attention.cu => 41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu} (62%) rename examples/{42_fused_multi_head_attention => 41_fused_multi_head_attention}/gemm/custom_mma.h (54%) rename examples/{42_fused_multi_head_attention => 41_fused_multi_head_attention}/gemm/custom_mma_base.h (100%) rename examples/{42_fused_multi_head_attention => 41_fused_multi_head_attention}/gemm/custom_mma_multistage.h (100%) rename examples/{42_fused_multi_head_attention => 
41_fused_multi_head_attention}/gemm/custom_mma_pipelined.h (100%) rename examples/{42_fused_multi_head_attention => 41_fused_multi_head_attention}/gemm_kernel_utils.h (84%) rename examples/{42_fused_multi_head_attention => 41_fused_multi_head_attention}/iterators/epilogue_predicated_tile_iterator.h (100%) create mode 100644 examples/41_fused_multi_head_attention/iterators/make_residual_last.h rename examples/{42_fused_multi_head_attention => 41_fused_multi_head_attention}/iterators/predicated_tile_access_iterator_residual_last.h (100%) rename examples/{42_fused_multi_head_attention => 41_fused_multi_head_attention}/iterators/predicated_tile_iterator_residual_last.h (100%) rename examples/{42_fused_multi_head_attention => 41_fused_multi_head_attention}/kernel_forward.h (95%) rename examples/{42_fused_multi_head_attention => 41_fused_multi_head_attention}/mma_from_smem.h (100%) delete mode 100644 examples/41_multi_head_attention/gemm_attention.h delete mode 100644 examples/41_multi_head_attention/gemm_grouped_with_softmax_visitor.h rename examples/{42_fused_multi_head_attention => 42_ampere_tensorop_group_conv}/CMakeLists.txt (96%) create mode 100644 examples/42_ampere_tensorop_group_conv/ampere_tensorop_group_conv.cu delete mode 100644 examples/42_fused_multi_head_attention/iterators/make_residual_last.h rename examples/{41_multi_head_attention => 43_ell_block_sparse_gemm}/CMakeLists.txt (96%) create mode 100644 examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu create mode 100644 examples/44_multi_gemm_ir_and_codegen/README.md create mode 100644 examples/44_multi_gemm_ir_and_codegen/config.json create mode 100644 examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h create mode 100644 examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h create mode 100644 examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h create mode 100644 examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h create mode 100644 examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h create mode 100644 examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h create mode 100644 examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py create mode 100644 examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py create mode 100644 examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py create mode 100644 examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py create mode 100644 examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py create mode 100644 examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py create mode 100644 examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py create mode 100644 examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py create mode 100644 examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py create mode 100644 examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py create mode 100755 examples/44_multi_gemm_ir_and_codegen/ir_gen/generate.sh create mode 100644 examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py create mode 100644 examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py create mode 100644 examples/44_multi_gemm_ir_and_codegen/leaky_bias.h create mode 100644 
examples/44_multi_gemm_ir_and_codegen/utils.h rename examples/{43_dual_gemm => 45_dual_gemm}/CMakeLists.txt (99%) rename examples/{43_dual_gemm => 45_dual_gemm}/device/dual_gemm.h (100%) rename examples/{43_dual_gemm => 45_dual_gemm}/dual_gemm.cu (100%) rename examples/{43_dual_gemm => 45_dual_gemm}/dual_gemm_run.h (100%) rename examples/{43_dual_gemm => 45_dual_gemm}/kernel/dual_gemm.h (100%) rename examples/{43_dual_gemm => 45_dual_gemm}/test_run.h (100%) rename examples/{43_dual_gemm => 45_dual_gemm}/thread/left_silu_and_mul.h (100%) rename examples/{43_dual_gemm => 45_dual_gemm}/threadblock/dual_epilogue.h (100%) rename examples/{43_dual_gemm => 45_dual_gemm}/threadblock/dual_mma_base.h (100%) rename examples/{43_dual_gemm => 45_dual_gemm}/threadblock/dual_mma_multistage.h (100%) create mode 100644 examples/46_depthwise_simt_conv2dfprop/CMakeLists.txt create mode 100644 examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu create mode 100644 include/cutlass/arch/mma_sm90.h create mode 100644 include/cutlass/barrier.h create mode 100644 include/cutlass/block_striped.h create mode 100644 include/cutlass/conv/device/direct_convolution.h create mode 100644 include/cutlass/conv/kernel/direct_convolution.h create mode 100644 include/cutlass/conv/thread/depthwise_mma.h create mode 100644 include/cutlass/conv/threadblock/depthwise_direct_conv_params.h create mode 100644 include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h create mode 100644 include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h create mode 100644 include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h create mode 100644 include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h create mode 100644 include/cutlass/conv/threadblock/depthwise_mma_base.h delete mode 100644 include/cutlass/epilogue/thread/linear_combination_residual_block_v2.h delete mode 100644 include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast_v2.h create mode 100644 include/cutlass/epilogue/threadblock/epilogue_base_streamk.h create mode 100644 include/cutlass/epilogue/threadblock/epilogue_depthwise.h delete mode 100644 include/cutlass/epilogue/threadblock/epilogue_with_broadcast_v2.h create mode 100644 include/cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h delete mode 100644 include/cutlass/epilogue/threadblock/predicated_tile_iterator_v2.h create mode 100644 include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h create mode 100644 include/cutlass/float8.h create mode 100644 include/cutlass/floating_point_nvrtc.h create mode 100644 include/cutlass/gemm/device/ell_gemm.h create mode 100644 include/cutlass/gemm/kernel/default_ell_gemm.h delete mode 100644 include/cutlass/gemm/kernel/default_gemm_with_broadcast_v2.h create mode 100644 include/cutlass/gemm/kernel/ell_gemm.h create mode 100644 include/cutlass/gemm/kernel/gemm_universal_streamk.h delete mode 100644 include/cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h create mode 100644 include/cutlass/gemm/kernel/params_universal_base.h create mode 100644 include/cutlass/gemm/threadblock/default_ell_mma.h create mode 100644 include/cutlass/gemm/threadblock/ell_mma_multistage.h create mode 100644 include/cutlass/gemm/threadblock/ell_mma_pipelined.h create mode 100644 include/cutlass/gemm/threadblock/index_remat.h create mode 100644 
include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h create mode 100644 include/cutlass/transform/threadblock/ell_iterator.h create mode 100644 include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h create mode 100644 include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h create mode 100644 include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h create mode 100644 test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h create mode 100644 test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu create mode 100644 test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu rename test/unit/conv/device/{depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu => depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu} (99%) create mode 100644 test/unit/core/float8.cu create mode 100644 test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu create mode 100644 test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu create mode 100644 test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu create mode 100644 test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu create mode 100644 test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu create mode 100644 test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu create mode 100644 test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu create mode 100644 test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu create mode 100644 test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu create mode 100644 test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu create mode 100644 test/unit/gemm/device/simt_cgemm_nt_sm80.cu create mode 100644 test/unit/gemm/device/simt_cgemm_tn_sm80.cu create mode 100644 test/unit/gemm/device/simt_f8gemm_tn_sm50.cu create mode 100644 test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu create mode 100644 test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu create mode 100644 test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu create mode 100644 test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu create mode 100644 test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu create mode 100644 test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu create mode 100644 test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu create mode 100644 test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu create mode 100644 test/unit/gemm/warp/gemm_complex_sm90.cu create mode 100644 test/unit/gemm/warp/gemm_sm90.cu mode change 100644 => 100755 tools/library/scripts/pycutlass/test/example/run_all_example.sh diff --git a/.github/labeler.yml b/.github/labeler.yml deleted file mode 100644 index a3a07b4e..00000000 --- a/.github/labeler.yml +++ /dev/null @@ -1,18 +0,0 @@ -# https://github.com/actions/labeler#common-examples - -examples: - - examples/** - -source: - - cmake/** - - include/cutlass/** - -documentation: - - docs/** - - media/** - -testing: - - test/** - -tooling: - - tools/** diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index fc949114..23956a02 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -5,8 +5,7 @@ on: jobs: triage: runs-on: ubuntu-latest - permissions: read-all|write-all steps: - - 
uses: actions/labeler@master + - uses: actions/labeler@main with: repo-token: "${{ secrets.GITHUB_TOKEN }}" diff --git a/CHANGELOG.md b/CHANGELOG.md index 04af9524..ba03351a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,27 @@ # NVIDIA CUTLASS Changelog +## [2.11.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.11.0) (2022-11-19) +* Stream-K, which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one. +* [Fused multi-head attention kernel](/examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for fixed sequence lengths, and the other uses grouped GEMM for variable sequence lengths. Both variants require only one kernel. +* [Dual GEMM](/examples/45_dual_gemm), which can fuse A x B and A x C into one kernel. The two GEMMs have no producer-consumer dependency. +* Hopper improves [double precision matrix multiplication](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8. +* [BLAS3](/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hopper's new double precision matrix multiplication instructions. +* [ELL Block Sparse GEMM](/examples/43_ell_block_sparse_gemm), which uses an [ELL matrix](https://developer.nvidia.com/blog/accelerating-matrix-multiplication-with-block-sparse-format-and-nvidia-tensor-cores/) to describe the sparsity of the A matrix. The B and output matrices are still dense. The block size can be arbitrary. +* Optimized [Group Conv](/examples/42_ampere_tensorop_group_conv) for SingleGroup mode, which requires that the output channel count per group is a multiple of the threadblock tile N. +* [Optimized Depthwise Conv](/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu). Two new modes are added: + * [kOptimized](/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - uses direct convolution to compute instead of implicit GEMM. + * The restrictions are: 1) the input and output channel counts and the group number should be a multiple of (128 / sizeof(input element)); 2) the input filter size should be the same as the template parameter configuration. + * [kFixedStrideDilation](/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - puts the stride and dilation into templates to further improve performance. In this mode, the kernel keeps some inputs persistent in registers to squeeze out more performance, so large filter/stride/dilation values are not recommended. + * The restrictions are: 1) the input and output channel counts and the group number should be a multiple of (128 / sizeof(input element)); 2) the input filter size, stride, and dilation should be the same as the template parameter configuration. +* [Scripts](/examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMMs. Their implementation was discussed in a GTC'22 Spring [talk](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41606/). +* [FP8 data type definition](/include/cutlass/float8.h) and [conversion routines](/include/cutlass/numeric_conversion.h#L1274-2115) (see the conversion sketch below). +* Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers).
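The FP8 entry above points at the new 8-bit floating-point types and the extended conversion routines. A minimal host-side sketch of how such a conversion might look, assuming the `cutlass::float_e4m3_t` type from `include/cutlass/float8.h` and the existing `cutlass::NumericConverter` template from `include/cutlass/numeric_conversion.h`; the names and rounding defaults are assumptions here, not a verified piece of this patch:

```cpp
// Hedged sketch: round-trip a float through the assumed FP8 e4m3 type on the host.
#include <iostream>

#include "cutlass/float8.h"              // assumed to define cutlass::float_e4m3_t
#include "cutlass/numeric_conversion.h"  // assumed to provide NumericConverter for FP8

int main() {
  // Converters in both directions; the default rounding style is round-to-nearest.
  cutlass::NumericConverter<cutlass::float_e4m3_t, float> to_fp8;
  cutlass::NumericConverter<float, cutlass::float_e4m3_t> to_fp32;

  float x = 0.3f;
  cutlass::float_e4m3_t y = to_fp8(x);   // narrow to e4m3 (4 exponent bits, 3 mantissa bits)
  float z = to_fp32(y);                  // widen back to float

  // e4m3 keeps only 3 mantissa bits, so a visible rounding error is expected.
  std::cout << "input: " << x << "  after FP8 round-trip: " << z << std::endl;
  return 0;
}
```

Compiling a host-only snippet like this should require nothing beyond the CUTLASS `include/` directory on the include path.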
+ +* **Deprecation announcement:** CUTLASS plans to deprecate the following: + * Maxwell and Pascal GPU architectures + * Ubuntu 16.04 + * CUDA 10.2 + ## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23) * [CUTLASS Python](/examples/40_cutlass_py) now supports GEMM, CONV, Group GEMM for different data types as well as different epilogue flavours. * Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel. Threadblock scheduling part is improved. Some computation can be moved to the host side if applicable. [Grouped Syr2k](examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added, too. @@ -16,11 +38,6 @@ * Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads) * Updates and bugfixes from the community (thanks!) -* **Deprecation announcement:** CUTLASS plans to deprecate the following: - * Maxwell and Pascal GPU architectures - * Ubuntu 16.04 - * CUDA 10.2 - ## [2.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.9.0) (2022-04-21) * [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment diff --git a/CITATION.cff b/CITATION.cff index 6dbdb1a8..7ae2b4b1 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -73,10 +73,10 @@ abstract: >- keywords: - 'cutlass, tensor cores, cuda' license: BSD-3-Clause -license-url: https://github.com/NVIDIA/cutlass/blob/v2.10.0/LICENSE.txt -version: '2.10.0' -date-released: '2022-09-15' +license-url: https://github.com/NVIDIA/cutlass/blob/v2.11.0/LICENSE.txt +version: '2.11.0' +date-released: '2022-11-19' identifiers: - type: url - value: "https://github.com/NVIDIA/cutlass/tree/v2.10.0" - description: The GitHub release URL of tag 2.10.0 + value: "https://github.com/NVIDIA/cutlass/tree/v2.11.0" + description: The GitHub release URL of tag 2.11.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index 30e261c2..b6eddde6 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,7 @@ endif() message(STATUS "CMake Version: ${CMAKE_VERSION}") -project(CUTLASS VERSION 2.10.0 LANGUAGES CXX) +project(CUTLASS VERSION 2.11.0 LANGUAGES CXX) include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake) if (CUDA_VERSION VERSION_LESS 10.2) @@ -87,6 +87,7 @@ set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable C set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools") set(CUTLASS_ENABLE_LIBRARY ${CUTLASS_ENABLE_LIBRARY_INIT} CACHE BOOL "Enable CUTLASS Library") set(CUTLASS_ENABLE_PROFILER ${CUTLASS_ENABLE_LIBRARY} CACHE BOOL "Enable CUTLASS Profiler") +set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUTLASS Performance") if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME}) set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_LIBRARY}}) @@ -122,6 +123,9 @@ endif() if (NOT CUDA_VERSION VERSION_LESS 11.1 AND NOT CUDA_COMPILER MATCHES "[Cc]lang") list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 86) endif() +if (NOT CUDA_VERSION VERSION_LESS 11.8 AND NOT CUDA_COMPILER MATCHES "[Cc]lang") + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 90) +endif() set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.") set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.") @@ -569,6 +573,9 @@ install(DIRECTORY DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest)
################################################################################ +set(CUTLASS_ENABLE_CUBLAS OFF CACHE BOOL "cuBLAS usage for tests") +set(CUTLASS_ENABLE_CUDNN OFF CACHE BOOL "cuDNN usage for tests") + include(${CMAKE_CURRENT_SOURCE_DIR}/cuBLAS.cmake) if (CUTLASS_ENABLE_CUBLAS) @@ -732,7 +739,7 @@ if (CUTLASS_ENABLE_TOOLS) add_subdirectory(tools) if (CUTLASS_ENABLE_PROFILER) add_dependencies(test_all test_profiler) - endif() + endif() endif() if (CUTLASS_ENABLE_EXAMPLES) add_subdirectory(examples) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 576f5ae1..21357b5f 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -7,10 +7,10 @@ This is the official list of CUTLASS developers and contributors. ## DEVELOPERS -Andrew Kerr -Haicheng Wu -Manish Gupta -Dustyn Blasig +Andrew Kerr +Haicheng Wu +Manish Gupta +Dustyn Blasig Pradeep Ramani Cris Cecka Vijay Thakkar @@ -20,52 +20,50 @@ Ethan Yan Zhaodong Chen Jack Kosaian Yujia Zhai -Naila Farooqui -Piotr Majcher -Paul Springer -Jin Wang -Chinmay Talegaonkar -Shang Zhang -Scott Yokim -Markus Hohnerbach -Aditya Atluri -David Tanner -Manikandan Ananth +Naila Farooqui +Piotr Majcher +Paul Springer +Jin Wang +Chinmay Talegaonkar +Shang Zhang +Scott Yokim +Markus Hohnerbach +Aditya Atluri +David Tanner +Manikandan Ananth ## CUTLASS Product Manager Matthew Nicely ## CONTRIBUTORS -Timothy Costa -Julien Demouth -Brian Fahs -Michael Goldfarb -Mostafa Hagog -Fei Hu -Alan Kaatz -Tina Li -Timmy Liu -Duane Merrill -Kevin Siu -Markus Tavenrath -John Tran -Vicki Wang -Junkai Wu -Fung Xie -Albert Xu -Jack Yang -Xiuxia Zhang -Nick Zhao +Timothy Costa +Julien Demouth +Brian Fahs +Michael Goldfarb +Mostafa Hagog +Fei Hu +Alan Kaatz +Tina Li +Timmy Liu +Duane Merrill +Kevin Siu +Markus Tavenrath +John Tran +Vicki Wang +Junkai Wu +Fung Xie +Albert Xu +Jack Yang +Xiuxia Zhang +Nick Zhao ## ACKNOWLEDGEMENTS -Girish Bharambe -Luke Durant -Olivier Giroux -Stephen Jones -Rishkul Kulkarni -Bryce Lelbach -Joel McCormack -Kyrylo Perelygin - - +Girish Bharambe +Luke Durant +Olivier Giroux +Stephen Jones +Rishkul Kulkarni +Bryce Lelbach +Joel McCormack +Kyrylo Perelygin diff --git a/README.md b/README.md index 1fe73dbd..0fecdee2 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ ![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 2.10 +# CUTLASS 2.11 -_CUTLASS 2.10 - August 2022_ +_CUTLASS 2.11 - November 2022_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-multiplication (GEMM) and related computations at all levels @@ -36,21 +36,21 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly. See the [functionality listing](/media/docs/functionality.md) for the list of operations supported at each level of the execution model hierarchy. -# What's New in CUTLASS 2.10 +# What's New in CUTLASS 2.11 -CUTLASS 2.10 is an update to CUTLASS adding: -- [CUTLASS Python](/examples/40_cutlass_py) now supports GEMM, Convolution and Grouped GEMM for different data types as well as different epilogue flavors. -- Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel. It can move some scheduling into the host side if applicable. -- Optimizations for [GEMM+Softmax](examples/35_gemm_softmax). -- [Grouped GEMM for Multihead Attention](examples/41_multi_head_attention) is a general MHA that does not require equal sequence length in every GEMM. 
-- [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/) can fuse the layernorm into GEMMs before and after. -- [GEMM Epilogue Permutation Fusion](examples/39_gemm_permute) can permute the GEMM output before storing. -- [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized. -- [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution which is also Analytical for now. -- Standalone [Layernorm](/tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](/tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels. -- [Back-to-back GEMM](examples/13_two_tensor_op_fusion) enhancements. -- Updates and bugfixes from the community (thanks!) -- **Deprecation announcement:** CUTLASS plans to deprecate the following: +CUTLASS 2.11 is an update to CUTLASS adding: +- Stream-K, which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one. +- [Fused multi-head attention kernel](/examples/41_fused_multi_head_attention). It has two variants: one for fixed sequence lengths, and another for variable sequence lengths. +- [Dual GEMM](/examples/45_dual_gemm). It can run two GEMMs that share the same left input matrix in one kernel. +- Hopper improves [double precision matrix multiplication](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8. +- [BLAS3](/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hopper's new double precision matrix multiplication instructions. +- [ELL Block Sparse GEMM](/examples/43_ell_block_sparse_gemm). +- [Optimized Group Conv](/examples/42_ampere_tensorop_group_conv). +- [Optimized Depthwise Conv](/examples/46_depthwise_simt_conv2dfprop). +- [Scripts](/examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMMs. +- [FP8 data type definition](/include/cutlass/float8.h) and [conversion routines](/include/cutlass/numeric_conversion.h#L1274-2115). +- Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers). +- **Deprecation announcement:** CUTLASS plans to deprecate the following in the next major release: - Maxwell and Pascal GPU architectures - Ubuntu 16.04 - CUDA 10.2 @@ -80,10 +80,11 @@ as shown in the above figure. Tensor Core operations are still implemented usin # Compatibility -CUTLASS requires a C++11 host compiler and -performs best when built with the [**CUDA 11.6u2 Toolkit**](https://developer.nvidia.com/cuda-toolkit). -It is also compatible with CUDA 11.0, CUDA 11.1, CUDA 11.2, CUDA 11.3, CUDA 11.4, and CUDA 11.5. +CUTLASS requires a C++11 host compiler and performs best when built with the [**CUDA 11.8 Toolkit**](https://developer.nvidia.com/cuda-toolkit). +It is also compatible with CUDA 11.x. + +## Operating Systems We have tested the following environments. |**Operating System** | **Compiler** |
| | Microsoft Visual Studio 2019| | Ubuntu 18.04 | GCC 7.5.0 | | Ubuntu 20.04 | GCC 10.3.0 | -| Ubuntu 21.04 | GCC 11.2.0 | +| Ubuntu 22.04 | GCC 11.2.0 | Additionally, CUTLASS may be built with clang. See [these instructions](media/docs/quickstart.md#clang) for more details. +## Hardware CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on any Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU. @@ -110,9 +112,7 @@ any Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU. |NVIDIA A100|8.0|11.0|11.0| |NVIDIA A10 |8.6|11.1|11.1| |NVIDIA GeForce 3090|8.6|11.1|11.1| - -For all GPUs, we recommend compiling with the [CUDA 11.6u2 Toolkit](https://developer.nvidia.com/cuda-toolkit) -for best performance. +|NVIDIA H100 PCIe|9.0|11.8|11.8| # Documentation @@ -133,9 +133,16 @@ CUTLASS is described in the following documents and the accompanying - [CUTLASS Profiler](media/docs/profiler.md) - command-line driven profiling application - [CUTLASS Utilities](media/docs/utilities.md) - additional templates used to facilate rapid development +# Resources We have also described the structure of an efficient GEMM in our talk at the [GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf). + - [CUTLASS: Software Primitives for Dense Linear Algebra at All Levels and Scales within CUDA](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2018-s8854/) + - [Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100](https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21745/) + - [Accelerating Convolution with Tensor Cores in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring21-s31883/) + - [Accelerating Backward Data Gradient by Increasing Tensor Core Utilization in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41996/) + - [CUTLASS: Python API, Enhancements, and NVIDIA Hopper](https://www.nvidia.com/en-us/on-demand/session/gtcfall22-a41131/) + # Building CUTLASS CUTLASS is a header-only template library and does not need to be built to be used by other @@ -199,6 +206,8 @@ include/ # client applications should target this directory conv/ # code specialized for convolution + epilogue/ # code specialized for the epilogue of gemm/convolution + gemm/ # code specialized for general matrix product computations layout/ # layout definitions for matrices, tensors, and other mathematical objects in memory @@ -206,6 +215,8 @@ include/ # client applications should target this directory platform/ # CUDA-capable Standard Library components reduction/ # bandwidth-limited reduction kernels that do not fit the "gemm" model + + thread/ # simt code that can be performed within a CUDA thread transform/ # code specialized for layout, type, and domain transformations @@ -216,49 +227,6 @@ include/ # client applications should target this directory [CUTLASS SDK examples](/examples) apply CUTLASS templates to implement basic computations. 
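The `include/` layout above is what client applications compile against. As a rough illustration of how those templates are applied (in the spirit of the basic GEMM SDK example), here is a minimal single-precision, device-level GEMM sketch; it assumes column-major operands already resident in device memory and relies on the default template parameters of `cutlass::gemm::device::Gemm`:

```cpp
// Minimal sketch: D = alpha * A * B + beta * C using the device-level GEMM template.
// A, B, C are assumed to be device pointers sized M*K, K*N, and M*N respectively;
// the result is written back over C here for brevity.
#include "cutlass/gemm/device/gemm.h"

cutlass::Status run_sgemm(int M, int N, int K,
                          float alpha, float const *A, int lda,
                          float const *B, int ldb,
                          float beta, float *C, int ldc) {
  using ColumnMajor = cutlass::layout::ColumnMajor;

  // Element type and layout for A, B, and C; all remaining template
  // parameters (tile shapes, epilogue, architecture) take their defaults.
  using Gemm = cutlass::gemm::device::Gemm<float, ColumnMajor,
                                           float, ColumnMajor,
                                           float, ColumnMajor>;

  Gemm gemm_op;

  Gemm::Arguments args({M, N, K},       // problem size
                       {A, lda},        // TensorRef to A
                       {B, ldb},        // TensorRef to B
                       {C, ldc},        // TensorRef to C (source)
                       {C, ldc},        // TensorRef to D (destination)
                       {alpha, beta});  // epilogue scalars

  return gemm_op(args);  // launches on the default CUDA stream
}
```

Callers should check the returned `cutlass::Status`; more elaborate configurations (Tensor Cores, mixed precision, custom tile shapes) are what the SDK examples linked above walk through.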
-``` -examples/ - 00_basic_gemm/ # launches a basic GEMM with single precision inputs and outputs - - 01_cutlass_utilities/ # demonstrates CUTLASS Utilities for allocating and initializing tensors - - 02_dump_reg_smem/ # debugging utilities for printing register and shared memory contents - - 03_visualize_layout/ # utility for visualizing all layout functions in CUTLASS - - 04_tile_iterator/ # example demonstrating an iterator over tiles in memory - - 05_batched_gemm/ # example demonstrating CUTLASS's batched strided GEMM operation - - 06_splitK_gemm/ # exmaple demonstrating CUTLASS's Split-K parallel reduction kernel - - 07_volta_tensorop_gemm/ # example demonstrating mixed precision GEMM using Volta Tensor Cores - - 08_turing_tensorop_gemm/ # example demonstrating integer GEMM using Turing Tensor Cores - - 09_turing_tensorop_conv2dfprop/ # example demonstrating integer implicit GEMM convolution (forward propagation) using Turing Tensor Cores - - 10_planar_complex/ # example demonstrating planar complex GEMM kernels - - 11_planar_complex_array/ # example demonstrating planar complex kernels with batch-specific problem sizes - - 12_gemm_bias_relu/ # example demonstrating GEMM fused with bias and relu - - 13_fused_two_gemms/ # example demonstrating two GEMMs fused in one kernel - - 22_ampere_tensorop_conv2dfprop/ # example demonstrating integer implicit GEMM convolution (forward propagation) using Ampere Tensor Cores - - 31_basic_syrk # example demonstrating Symmetric Rank-K update - - 32_basic_trmm # example demonstrating Triangular Matrix-Matrix multiplication - - 33_ampere_3xtf32_tensorop_symm # example demonstrating Symmetric Matrix-Matrix multiplication with FP32 emulation - - 35_gemm_softmax # example demonstrating GEMM fused with Softmax in mixed precision using Ampere Tensor Cores - - 40_cutlass_py # example demonstrating CUTLASS with CUDA Python -``` - ### Tools ``` diff --git a/examples/03_visualize_layout/CMakeLists.txt b/examples/03_visualize_layout/CMakeLists.txt index 27c38249..bc41312f 100644 --- a/examples/03_visualize_layout/CMakeLists.txt +++ b/examples/03_visualize_layout/CMakeLists.txt @@ -29,7 +29,6 @@ set(TEST_COMMAND_00 RowMajor --extent=16,16) -set(TEST_COMMAND_01 \"ColumnMajorInterleaved<4>\" --extent=32,8 --output-shape=16 --vectorize=4) cutlass_example_add_executable( 03_visualize_layout @@ -37,6 +36,5 @@ cutlass_example_add_executable( register_layout.cu TEST_COMMAND_OPTIONS TEST_COMMAND_00 - TEST_COMMAND_01 ) diff --git a/examples/03_visualize_layout/register_layout.cu b/examples/03_visualize_layout/register_layout.cu index 060abe35..323556bb 100644 --- a/examples/03_visualize_layout/register_layout.cu +++ b/examples/03_visualize_layout/register_layout.cu @@ -64,15 +64,15 @@ void RegisterLayouts(std::map // All Ampere/Turing H/Integer matrix multiply tensor core kernels uses the same swizzling // layout implementation with different templates. 
// - // BMMA 88128 Interleaved-256 - // BMMA 168256 Interleaved-256 + // mma.sync.aligned.m8n8k128.s32.b1.b1.s32 Interleaved-256 + // mma.sync.aligned.m16n8k256.s32.b1.b1.s32 Interleaved-256 {"TensorOpMultiplicand<1,256>", new VisualizeLayout>}, - // BMMA 88128 TN kblock512 - // BMMA 168256 TN kblock512 + // mma.sync.aligned.m8n8k128.s32.b1.b1.s32 TN kblock512 + // mma.sync.aligned.m16n8k256.s32.b1.b1.s32 TN kblock512 {"TensorOpMultiplicand<1,512>", new VisualizeLayout>}, - // BMMA 168256 TN kblock1024 + // mma.sync.aligned.m16n8k256.s32.b1.b1.s32 TN kblock1024 {"TensorOpMultiplicand<1,1024>", new VisualizeLayout>}, // Integer matrix multiply.int4 8832 Interleaved-64 diff --git a/examples/05_batched_gemm/batched_gemm.cu b/examples/05_batched_gemm/batched_gemm.cu index 2ce552c7..e1638e70 100644 --- a/examples/05_batched_gemm/batched_gemm.cu +++ b/examples/05_batched_gemm/batched_gemm.cu @@ -81,7 +81,7 @@ matrix A can be seen as --------------------------------------- batch 0 | batch 1 , where batch size is 2, M is 6 and K is 2 -The stride (batch_stride_B) between the first element of two batches is lda * k +The stride (batch_stride_A) between the first element of two batches is lda * k matrix B can be seen as ----------------------------- @@ -94,7 +94,7 @@ matrix B can be seen as (1,1,0) | (1,1,1) | (1,1,2) | ----------------------------- , where the batch size is 2, N is 3 and K is 2 -The stride (batch_stride_C) between the first element of two batches is k +The stride (batch_stride_B) between the first element of two batches is k */ diff --git a/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu b/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu index 28963203..0d0e6477 100644 --- a/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu +++ b/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu @@ -426,7 +426,7 @@ Result profile(Options const &options) { // Create a tuple of gemm kernel arguments. 
This is later passed as arguments to launch // instantiated CUTLASS kernel - typename Gemm::Arguments arguments{ + typename Gemm::Arguments arguments( mode, options.problem_size, batch_count, @@ -445,8 +445,7 @@ Result profile(Options const &options) { tensor_b.layout().stride(0), tensor_c.layout().stride(0), tensor_d.layout().stride(0), - tensor_reduction.layout().stride(0) - }; + tensor_reduction.layout().stride(0)); // Instantiate CUTLASS kernel depending on templates Gemm gemm_op; @@ -515,15 +514,14 @@ Result profile(Options const &options) { cutlass::TensorRef tensor_nullptr_tensorref(nullptr, splitk_vector_layout); - typename ReduceVectorSplitK::Arguments reduce_vector_splitk_arguments{ + typename ReduceVectorSplitK::Arguments reduce_vector_splitk_arguments( cutlass::MatrixCoord(1, reduce_vector_length), batch_count, size_t(reduce_vector_length), workspace_vector_tensorref, tensor_reduction_tensorref, tensor_nullptr_tensorref, - {1.0f, 0.0f} - }; + {1.0f, 0.0f}); ReduceVectorSplitK reduce_vector_splitk_op; diff --git a/examples/30_wgrad_split_k/30_wgrad_split_k.cu b/examples/30_wgrad_split_k/30_wgrad_split_k.cu index 0c7f32f4..6c56509c 100644 --- a/examples/30_wgrad_split_k/30_wgrad_split_k.cu +++ b/examples/30_wgrad_split_k/30_wgrad_split_k.cu @@ -531,17 +531,17 @@ Result profile_convolution(Options const &options) { // Reduction input { reinterpret_cast (workspace.get()), - ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx]) + ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx]) }, // Destination { tensor_d.device_data(), - ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx]) + ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx]) }, // Source { tensor_c.device_data(), - ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx]) + ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx]) }, {options.alpha, options.beta} ); diff --git a/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h b/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h index 263ed75a..da04df55 100644 --- a/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h +++ b/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h @@ -367,12 +367,6 @@ public: return can_implement(args.problem_size); } - static size_t get_extra_workspace_size(Arguments const &args, - cutlass::gemm::GemmCoord const &grid_tiled_shape) { - - return 0; - } - #define SPLIT_K_ENABLED 1 /// Executes one GEMM diff --git a/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h b/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h index 0323139c..4ac517aa 100644 --- a/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h +++ b/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h @@ -309,12 +309,6 @@ public: return can_implement(args.problem_size); } - static size_t get_extra_workspace_size(Arguments const &args, - cutlass::gemm::GemmCoord const &grid_tiled_shape) { - - return 0; - } - /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage) { diff --git a/examples/39_gemm_permute/gemm_permute.cu b/examples/39_gemm_permute/gemm_permute.cu index 1e32b918..81d5ce27 100644 --- a/examples/39_gemm_permute/gemm_permute.cu +++ b/examples/39_gemm_permute/gemm_permute.cu @@ -690,7 +690,7 @@ public: // Initialize the GEMM object 
GemmBatched gemm; - result.status = gemm.initialize(arguments); + result.status = gemm.initialize(arguments, nullptr); if (result.status != cutlass::Status::kSuccess) { std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl; @@ -854,7 +854,7 @@ public: // Initialize the GEMM object GemmPermute gemm_normal; - result.status = gemm_normal.initialize(arguments); + result.status = gemm_normal.initialize(arguments, nullptr); if (result.status != cutlass::Status::kSuccess) { std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl; diff --git a/examples/40_cutlass_py/README.md b/examples/40_cutlass_py/README.md index 0c4b4f1e..c4556d1c 100644 --- a/examples/40_cutlass_py/README.md +++ b/examples/40_cutlass_py/README.md @@ -1,230 +1,23 @@ -# CUTLASS Python Interface Example +# CUTLASS Python Interface Examples +This directory contains examples of using CUTLASS's Python interface. It consists of two types of examples: +* _Basic examples_: minimal examples that illustrate how to set up GEMMs, convolutions, and grouped GEMM operations +* [_Customizable examples_](customizable): examples that allow one to specify a variety of template parameters for the given kernel -## Using Docker -You can run the PyCUTLASS on NGC PyTorch container. +## Setting up the Python interface Please follow the instructions [here](/tools/library/scripts/pycutlass/README.md#installation) to set up the Python API. + +## Running examples Each of the basic examples can be run as follows: ```shell -docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:22.09-py3 -``` -PyCUTLASS requires additional dependency Boost C++ library, which can be installed with -```bash -apt-get update -apt-get -y install libboost-all-dev +# Run the GEMM example +python gemm.py + +# Run the Conv2d example +python conv2d.py + +# Run the grouped GEMM example +python gemm_grouped.py ``` - -## Install the Python Interface -The source code for python interface is allocated at `tools/library/script/pycutlass`. It requires two environment variables: -* `CUTLASS_PATH`: the root directory of CUTLASS -* `CUDA_INSTALL_PATH`: the directory where cuda toolkit is installed - -After setting these two environment variables, PyCUTLASS can be installed with -```shell -cd $CUTLASS_PATH/tools/library/scripts/pycutlass && bash build.sh -``` -*** - -## Troubleshooting - -### Issue 1: permission denied -Building PyCUTLASS requires installing dependencies to python. So conda could an option if you don't have permission. - -### Issue 2: rmm: module not found -PyCUTLASS manages the device memory with [RMM](https://github.com/rapidsai/rmm). Our `build.sh` automatically pull the [rmm branch-22.08](https://github.com/rapidsai/rmm/tree/branch-22.08) from github and build it from source. The rmm is allocated at `$CUTLASS_PATH/tools/library/scripts/pycutlass/rmm`. It requires `cmake > 3.20.1`. If the build fails, it can be manually fixed with the following steps: -```shell -cd $CUTLASS_PATH/tools/library/scripts/pycutlass/rmm && ./build.sh librmm rmm - -cd $CUTLASS_PATH/tools/library/scripts/pycutlass/rmm/python -python setup.py build_ext --inplace -python setup.py install -``` -To test whether rmm is successfully installed, try `import rmm`. For other issues related to rmm, please check https://github.com/rapidsai/rmm/issues. - -*** -For all the tests, add `--print_cuda` to print the underlying CUDA kernel. Use `-h` or `--help` to display the help message.
-## GEMM Examples -The GEMM examples use numpy to create input tensors and verify the results. -### GEMM F64 Example -Example 1: SM80_Device_Gemm_f64t_f64n_f64n_tensor_op_f64_32x32x16_16x16x16 -```python -python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -``` -Example 2: SM80_Device_Gemm_f64n_f64t_f64n_tensor_op_f64_64x64x16_32x32x16, split_k(2)_serial -```python -python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -``` - -### GEMM F32 Example -Example 1: SM80_Device_Gemm_f32n_f32t_f32n_tensor_op_bf16_f32_128x128x32_64x64x32 -```python -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -``` -Example 2: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_f32_128x128x32_64x64x32, split_k(2)_parallel -```python -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 -``` -Example 3: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_fast_accurate_f32_64x64x32_32x32x32, split_k(4)_serial -```python -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_f32 -op TensorOp -b 64 64 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 4 -``` - -### GEMM F16 Example -Example 1: SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32 -```python -python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -``` -Example 2: SM80_Device_Gemm_f16t_f16t_f16n_tensor_op_f32_128x128x64_64x64x64, split_k(2)_serial -```python -python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -``` -Example 3: SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32_256x128x64_64x64x64, split_k(3)_serial -```python -python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 256 128 64 -s 3 -w 4 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 3 -``` - -### GEMM BF16 Example 
-Example 1: Device_Gemm_bf16t_bf16t_f32n_tensor_op_f32_64x128x64_32x64x64, split_k(5)_parallel -```python -python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -``` - -### GEMM Int8 Example -Example 1: SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32_256x128x128_64x64x128 -```python -python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1 -``` - -### Batched & Array GEMM -Example 1: Batched GEMM -```python -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3 -``` -Example 2: Array GEMM -```python -python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 2 -``` -*** -## GEMM Grouped Examples -The GEMM Grouped examples use numpy to create input tensors and verify the results. - -Example 1: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule -```python -python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device -``` -Example 2: SM80_Device_GemmGrouped_f64n_f64n_f64t_tensor_op_f64_64x64x16_32x32x16, host schedule -```python -python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host -``` -Example 3: SM80_Device_GemmGrouped_f32n_f32n_f32n_simt_f32_128x64x8_64x32x1, device schedule -```python -python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device -``` -Example 4: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule -```python -python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device -``` -*** -## Conv2d Example -The Conv2d examples use pytorch to create input tensors and verify the results. 
Pytorch can be installed following the [official website](https://pytorch.org/#:~:text=Aid%20to%20Ukraine.-,INSTALL%20PYTORCH,-Select%20your%20preferences). -### Conv2d F32 Fprop -Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32 -```python -python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 13 17 8 -krsc 24 3 3 8 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 -``` -Example 2: SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2 -```python -python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 1.0 -beta 1.0 -``` -Example 3: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32 -```python -python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 4 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -co fprop -st Strided -ia analytic -sm Parallel -k 3 -nhwc 1 71 80 32 -krsc 64 5 5 32 -pad 2 2 2 2 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 1.0 -``` -### Conv2d F32 Wgrad -Example 1: Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1 -```python -python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 1 -lb TensorNHWC -ab 1 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 8 8 1 -krsc 1 3 3 1 -pad 1 1 1 1 -stride 1 1 -dilation 1 1 -alpha 1.0 -beta 0.0 -``` -Example 2: Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32 -```python -python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 2 4 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 -``` -### Conv2d F32 Dgrad -Example 1: Device_Conv2d_Dgrad_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32 -```python -python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 -``` - -### Conv2d F16 Fprop -Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32 -```python -python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add 
-op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 -``` -Example 2: SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2 -```python -python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 -``` -Example 3: SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_8 -```python -python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia fixed_channels -sm Serial -k 1 -nhwc 1 8 8 8 -krsc 16 3 3 8 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 -``` -Example 4: SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32_align4 -```python -python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 -``` - -## Epilogue -### Bias -To replace C with a bias vector, add `-bias` flag. 
-### Activation function -Example 1: ReLU -```python -python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias -activ relu -``` -Example 2: leaky ReLU -```python -python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -bias -activ leaky_relu -activ_arg 0.2 -``` -Example 3: tanh (alpha=0 to avoid saturation) -```python -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 -bias -activ tanh -``` -Example 4: sigmoid -```python -python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 0.0 -beta 0.5 -pm Host -bias -activ sigmoid -bias -activ sigmoid -``` -Example 5: SiLU -```python -python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ silu -``` -Example 6: HardSwish -```python -python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ hardswish -``` -Example 7: GELU -```python -python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu -``` -### Epilogue Visitor Tree -Example 1: -```python -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -``` -Example 2: -```python -python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 
0.5 -gm Gemm -k 1 -``` -Example 3: -```python -python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -``` -Example 4: -```python -python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -``` -Example 5: -```python -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3 -``` -Example 6: -```python -python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3 -``` +To run the customizable examples, refer to the README in the [customizable](customizable) directory. diff --git a/examples/40_cutlass_py/conv2d.py b/examples/40_cutlass_py/conv2d.py index 8be1f4d9..50854d38 100644 --- a/examples/40_cutlass_py/conv2d.py +++ b/examples/40_cutlass_py/conv2d.py @@ -29,290 +29,133 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
# ################################################################################ -import pycutlass -from pycutlass import * -from pycutlass.conv2d_operation import * -from pycutlass.utils import reference_model -import torch.nn.functional as F +""" +Basic example of using the CUTLASS Python interface to run a 2d convolution +""" import argparse +import torch +import numpy as np +import sys -# parse the arguments -parser = argparse.ArgumentParser(description="Launch CUTLASS convolution 2d kernels from python") - -# Operation description -# math instruction description -parser.add_argument("-i", "--instruction_shape", - default=[1, 1, 1], nargs=3, type=int, - help="This option describes the size of MMA op") -parser.add_argument("-ta", "--element_a", default="float32", type=str, - choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], - help='Data type of elements in input tensor A') -parser.add_argument("-tb", "--element_b", default="float32", type=str, - choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], - help='Data type of elements in input tensor B') -parser.add_argument("-tc", "--element_c", default="float32", type=str, - choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], - help='Data type of elements in input tensor C and output tensor D') -parser.add_argument("-tacc", "--element_acc", default="float32", type=str, - choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], - help='Data type of accumulator') -parser.add_argument('-m', "--math", default="multiply_add", - type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") -parser.add_argument('-op', "--opcode", default="simt", type=str, - choices=["Simt", 'TensorOp'], - help='This option describes whether you want to use tensor \ - cores (TensorOp) or regular SIMT cores (Simt) on GPU SM') -# tile description -parser.add_argument("-b", "--threadblock_shape", - default=[128, 128, 8], nargs=3, type=int, - help="This option describes the tile size a thread block with compute") -parser.add_argument("-s", "--stages", default=4, - type=int, help="Number of pipelines you want to use") -parser.add_argument("-w", "--warp_count", default=[ - 4, 2, 1], nargs=3, type=int, - help="This option describes the number of warps along M, N, and K of the threadblock") -parser.add_argument("-cc", "--compute_capability", default=80, - type=int, help="This option describes CUDA SM architecture number") -# A -parser.add_argument('-la', "--layout_a", default="TensorNHWC", type=str, choices=[ - "TensorNHWC", "TensorNC32HW32"], - help="Memory layout of input tensor A") -parser.add_argument('-aa', '--alignment_a', default=1, - type=int, help="Memory alignement of input tensor A") -# B -parser.add_argument('-lb', "--layout_b", default="TensorNHWC", type=str, choices=[ - "TensorNHWC", "TensorC32RSK32"], - help="Memory layout of input tensor B") -parser.add_argument('-ab', '--alignment_b', default=1, - type=int, help="Memory alignment of input tensor B") -# C -parser.add_argument('-lc', "--layout_c", default="TensorNHWC", type=str, choices=[ - "TensorNHWC", "TensorNC32HW32"], - help="Memory layout of input tensor C and output tensor D") -parser.add_argument('-ac', '--alignment_c', default=1, - type=int, help="Memory alignment of input tensor C and output tensor D") -# epilogue -parser.add_argument("-te", "--element_epilogue", default="float32", type=str, - choices=['float64', 'float32', 'float16', 'bfloat16'], - help='Data type of 
computation in the epilogue') -parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", - type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], - help="This option describes the epilogue part of the kernel") -# swizzling -parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ - "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", - "HorizontalSwizzle", "StridedDgradIdentitySwizzle1", "StridedDgradIdentitySwizzle4", - "StridedDgradHorizontalSwizzle"], - help="This option describes how thread blocks are scheduled on GPU") -# conv related -parser.add_argument("-co", "--conv_kind", default="fprop", type=str, choices=['fprop', 'dgrad', 'wgrad'], - help="The type of convolution: forward propagation (fprop), \ - gradient of activation (dgrad), gradient of weight (wgrad)") -parser.add_argument("-st", "--stride_support", default="Strided", type=str, choices=["Strided", "Unity"], - ) -parser.add_argument("-ia", "--iterator_algorithm", default="analytic", type=str, - choices=["analytic", "optimized", "fixed_channels", "few_channels"], - help="This option describes iterator algorithm") - -# arguments -parser.add_argument("-sm", "--split_k_mode", default="Serial", type=str, choices=["Serial", "Parallel"], - help="Split K Mode. Serial is used for non-splitK or serial-splitK.\ - Parallel is used for parallel splitK.") -parser.add_argument('-k', '--split_k_slices', default=1, - type=int, help="Number of split-k partitions. (default 1)") -parser.add_argument("-nhwc", "--nhwc", nargs=4, type=int, help="input size (NHWC)") -parser.add_argument("-krsc", "--krsc", nargs=4, type=int, help="filter size (KRSC)") -parser.add_argument("-pad", "--pad", nargs=4, type=int, help="padding (pad_h, _, pad_w, _)") -parser.add_argument("-stride", "--stride", nargs=2, type=int, help="stride (stride_h, stride_w)") -parser.add_argument("-dilation", "--dilation", nargs=2, type=int, help="dilation (dilation_h, dilation_w)") -parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha") -parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta") -parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") -# Activation function -parser.add_argument("-activ", "--activation_function", default="identity", - choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") -parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, - help="addition arguments for activation") +import cutlass +import pycutlass +from pycutlass import * +import util -parser.add_argument('--print_cuda', action="store_true", - help="print the underlying CUDA kernel") +parser = argparse.ArgumentParser( + description=("Launch a 2d convolution kernel from Python. 
" + "See https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#convo-intro for notation.")) +parser.add_argument("--n", default=1, type=int, help="N dimension of the convolution") +parser.add_argument("--c", default=64, type=int, help="C dimension of the convolution") +parser.add_argument("--h", default=32, type=int, help="H dimension of the convolution") +parser.add_argument("--w", default=32, type=int, help="W dimension of the convolution") +parser.add_argument("--k", default=32, type=int, help="N dimension of the convolution") +parser.add_argument("--r", default=3, type=int, help="R dimension of the convolution") +parser.add_argument("--s", default=3, type=int, help="S dimension of the convolution") +parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel") try: args = parser.parse_args() except: sys.exit(0) -pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) +# Check that the device is of a sufficient compute capability +cc = util.get_device_cc() +assert cc >= 70, "The CUTLASS Python Conv2d example requires compute capability greater than or equal to 70." + +alignment = 1 np.random.seed(0) -element_a = getattr(cutlass, args.element_a) -element_b = getattr(cutlass, args.element_b) -element_c = getattr(cutlass, args.element_c) -element_acc = getattr(cutlass, args.element_acc) -math_operation = getattr(MathOperation, args.math) -opclass = getattr(cutlass.OpClass, args.opcode) +# Allocate a pool of device memory to be used by the kernel +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +# Set the compiler to use to NVCC +pycutlass.compiler.nvcc() + +# Set up A, B, C and accumulator +A = TensorDescription(cutlass.float16, cutlass.TensorNHWC, alignment) +B = TensorDescription(cutlass.float16, cutlass.TensorNHWC, alignment) +C = TensorDescription(cutlass.float32, cutlass.TensorNHWC, alignment) +element_acc = cutlass.float32 +element_epilogue = cutlass.float32 math_inst = MathInstruction( - args.instruction_shape, element_a, element_b, - element_acc, opclass, math_operation + [16, 8, 8], # Shape of the Tensor Core instruction + A.element, B.element, element_acc, + cutlass.OpClass.TensorOp, + MathOperation.multiply_add ) tile_description = TileDescription( - args.threadblock_shape, args.stages, args.warp_count, + [128, 128, 32], # Threadblock shape + 2, # Number of stages + [2, 2, 1], # Number of warps within each dimension of the threadblock shape math_inst ) -layout_a = getattr(cutlass, args.layout_a) -layout_b = getattr(cutlass, args.layout_b) -layout_c = getattr(cutlass, args.layout_c) - -A = TensorDescription( - element_a, layout_a, args.alignment_a -) - -B = TensorDescription( - element_b, layout_b, args.alignment_b -) - -C = TensorDescription( - element_c, layout_c, args.alignment_c -) - -element_epilogue = getattr(cutlass, args.element_epilogue) -if (args.activation_function == "identity" - or (args.split_k_mode == "Parallel" and args.split_k_slices > 1)): - # - epilogue_functor = getattr(pycutlass, args.epilogue_functor)( - C.element, C.alignment, math_inst.element_accumulator, element_epilogue) -else: - epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( - getattr(pycutlass, args.activation_function)(element_epilogue), - C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - -iterator_algorithm = getattr(cutlass.conv.IteratorAlgorithm, args.iterator_algorithm) -swizzling_functor = getattr(cutlass, args.swizzling_functor) -stride_support 
= getattr(StrideSupport, args.stride_support) -conv_kind = getattr(cutlass.conv.Operator, args.conv_kind) +epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue) operation = Conv2dOperation( - conv_kind=conv_kind, iterator_algorithm=iterator_algorithm, - arch=args.compute_capability, tile_description=tile_description, - A=A, B=B, C=C, stride_support=stride_support, - epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + conv_kind=cutlass.conv.Operator.fprop, + iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=cc, tile_description=tile_description, + A=A, B=B, C=C, stride_support=StrideSupport.Strided, + epilogue_functor=epilogue_functor ) if args.print_cuda: print(operation.rt_module.emit()) -operations = [operation,] - -if args.split_k_mode == "Parallel" and args.split_k_slices > 1: - if (args.activation_function == "identity"): - epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)( - C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - else: - epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")( - getattr(pycutlass, args.activation_function)(element_epilogue), - C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - reduction_operation = ReductionOperation( - shape=cutlass.MatrixCoord(4, 32 * C.alignment), - C=C, element_accumulator=element_acc, - element_compute=element_epilogue, - epilogue_functor=epilogue_functor_reduction, - count=C.alignment - ) - operations.append(reduction_operation) +operations = [operation, ] +# Compile the operation pycutlass.compiler.add_module(operations) +# Randomly initialize tensors + problem_size = cutlass.conv.Conv2dProblemSize( - cutlass.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]), - cutlass.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]), - cutlass.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]), - cutlass.MatrixCoord(args.stride[0], args.stride[1]), - cutlass.MatrixCoord(args.dilation[0], args.dilation[1]), + cutlass.Tensor4DCoord(args.n, args.h, args.c, args.w), + cutlass.Tensor4DCoord(args.k, args.r, args.s, args.c), + cutlass.Tensor4DCoord(0, 0, 0, 0), # Padding + cutlass.MatrixCoord(1, 1), # Strides + cutlass.MatrixCoord(1, 1), # Dilation cutlass.conv.Mode.cross_correlation, - args.split_k_slices, 1 + 1, # Split k slices + 1 # Groups ) +tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(operation.conv_kind, problem_size) +tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(operation.conv_kind, problem_size) +tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(operation.conv_kind, problem_size) -# User-provide inputs -tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size( - conv_kind, problem_size -) -tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size( - conv_kind, problem_size -) -if args.bias: - tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_extent( - conv_kind, problem_size - ).at(3) -else: - tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size( - conv_kind, problem_size - ) +tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5)) +tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5)) +tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5)) +tensor_D = torch.ones(size=(tensor_C_size,), dtype=torch.float32, 
device="cuda") -tensor_D_size = cutlass.conv.implicit_gemm_tensor_c_size( - conv_kind, problem_size - ) - -if args.element_a != "int8": - tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-8.5, 7.5)) -else: - tensor_A = torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-2, 2) - -if args.element_b != "int8": - tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-8.5, 7.5)) -else: - tensor_B = torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-2, 2) - -if args.element_c != "int8": - tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-8.5, 7.5)) -else: - tensor_C = torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-2, 2) - -tensor_D = torch.ones(size=(tensor_D_size,), dtype=getattr(torch, args.element_c), device="cuda") +alpha = 1. +beta = 0. arguments = Conv2dArguments( - operation=operation, problem_size=problem_size, A=tensor_A, - B=tensor_B, C=tensor_C, D=tensor_D, - output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), - split_k_mode=getattr(cutlass.conv.SplitKMode, args.split_k_mode), - split_k_slices=problem_size.split_k_slices + operation=operation, problem_size=problem_size, + A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, + output_op=operation.epilogue_type(alpha, beta) ) -if args.split_k_mode == "Parallel" and args.split_k_slices > 1: - implicit_gemm_size = cutlass.conv.implicit_gemm_problem_size(conv_kind, arguments.problem_size) - reduction_arguments = ReductionArguments( - reduction_operation, - problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()], - partitions=problem_size.split_k_slices, - workspace=arguments.ptr_D, - destination=tensor_D, - source=tensor_C, - output_op = reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), - bias = arguments.bias - ) - +# Run the operation operation.run(arguments) +arguments.sync() -if args.split_k_mode == "Parallel" and args.split_k_slices > 1: - reduction_operation.run(reduction_arguments) - reduction_arguments.sync() -else: - arguments.sync() - -reference_model = Conv2dReferenceModule(A, B, C, conv_kind) - -tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta, args.bias) -if (args.activation_function != "identity"): - tensor_D_ref = getattr(F, args.activation_function)(*([tensor_D_ref,] + args.activation_args)) +# Run the host reference module and compare to the CUTLASS result +reference = Conv2dReferenceModule(A, B, C, operation.conv_kind) +tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta) try: assert torch.equal(tensor_D, tensor_D_ref) except: assert torch.allclose(tensor_D, tensor_D_ref, rtol=1e-2) + print("Passed.") diff --git a/examples/40_cutlass_py/customizable/README.md b/examples/40_cutlass_py/customizable/README.md new file mode 100644 index 00000000..cd25c69f --- /dev/null +++ b/examples/40_cutlass_py/customizable/README.md @@ -0,0 +1,192 @@ +# Customizable Python Interface Examples +This directory contains examples of using the CUTLASS Python interface with a variety of configurations for kernels. + +For all the tests, add `--print_cuda` to print the underlying CUDA kernel. 
Use `-h` or `--help` to display the help message. + +## GEMM Examples +The GEMM examples use numpy to create input tensors and verify the results. +### GEMM F64 Example +Example 1: SM80_Device_Gemm_f64t_f64n_f64n_tensor_op_f64_32x32x16_16x16x16 +```python +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 +``` +Example 2: SM80_Device_Gemm_f64n_f64t_f64n_tensor_op_f64_64x64x16_32x32x16, split_k(2)_serial +```python +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 +``` + +### GEMM F32 Example +Example 1: SM80_Device_Gemm_f32n_f32t_f32n_tensor_op_bf16_f32_128x128x32_64x64x32 +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 +``` +Example 2: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_f32_128x128x32_64x64x32, split_k(2)_parallel +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 +``` +Example 3: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_fast_accurate_f32_64x64x32_32x32x32, split_k(4)_serial +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_f32 -op TensorOp -b 64 64 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 4 +``` + +### GEMM F16 Example +Example 1: SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32 +```python +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 +``` +Example 2: SM80_Device_Gemm_f16t_f16t_f16n_tensor_op_f32_128x128x64_64x64x64, split_k(2)_serial +```python +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 +``` +Example 3: SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32_256x128x64_64x64x64, split_k(3)_serial +```python +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 256 128 64 -s 3 -w 4 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm 
GemmSplitKParallel -k 3 +``` + +### GEMM BF16 Example +Example 1: Device_Gemm_bf16t_bf16t_f32n_tensor_op_f32_64x128x64_32x64x64, split_k(5)_parallel +```python +python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 5 +``` + +### GEMM Int8 Example +Example 1: SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32_256x128x128_64x64x128 +```python +python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1 +``` + +### Batched & Array GEMM +Example 1: Batched GEMM +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3 +``` +Example 2: Array GEMM +```python +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 2 +``` +*** +## GEMM Grouped Examples +The GEMM Grouped examples use numpy to create input tensors and verify the results. 
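Each problem in the group is verified against a plain numpy reference. Conceptually the check amounts to the following sketch (with a made-up problem size rather than one read from the CSV file):

```python
import numpy as np

# Hypothetical problem size and epilogue scalars
M, N, K = 512, 256, 128
alpha, beta = 1.0, 0.5
A = np.random.rand(M, K)
B = np.random.rand(K, N)
C = np.random.rand(M, N)
D_ref = alpha * (A @ B) + beta * C   # compared elementwise against the CUTLASS result
```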
+ +Example 1: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule +```python +python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device +``` +Example 2: SM80_Device_GemmGrouped_f64n_f64n_f64t_tensor_op_f64_64x64x16_32x32x16, host schedule +```python +python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host +``` +Example 3: SM80_Device_GemmGrouped_f32n_f32n_f32n_simt_f32_128x64x8_64x32x1, device schedule +```python +python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device +``` +Example 4: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule +```python +python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device +``` +*** +## Conv2d Example +The Conv2d examples use pytorch to create input tensors and verify the results. Pytorch can be installed following the [official website](https://pytorch.org/get-started/locally/). 
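As a mental model, the fprop reference check is equivalent to calling PyTorch's convolution on NCHW-permuted views of the NHWC tensors. The sketch below uses sizes matching Example 1 in the next section (input 1x13x17x8 NHWC, filter 24x3x3x8 KRSC, stride 2, no padding); the scripts themselves rely on the pycutlass reference module rather than this exact code.

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 8, 13, 17, device="cuda")   # NCHW view of the NHWC activation
w = torch.rand(24, 8, 3, 3, device="cuda")    # KCRS view of the KRSC filter
y_ref = F.conv2d(x, w, stride=2, padding=0)   # compared against the CUTLASS fprop output
```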
+### Conv2d F32 Fprop +Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32 +```python +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 13 17 8 -krsc 24 3 3 8 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +Example 2: SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2 +```python +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 1.0 -beta 1.0 +``` +Example 3: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32 +```python +python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 4 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -co fprop -st Strided -ia analytic -sm Parallel -k 3 -nhwc 1 71 80 32 -krsc 64 5 5 32 -pad 2 2 2 2 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 1.0 +``` +### Conv2d F32 Wgrad +Example 1: Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1 +```python +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 1 -lb TensorNHWC -ab 1 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 8 8 1 -krsc 1 3 3 1 -pad 1 1 1 1 -stride 1 1 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +Example 2: Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32 +```python +python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 2 4 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +### Conv2d F32 Dgrad +Example 1: Device_Conv2d_Dgrad_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32 +```python +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` + +### Conv2d F16 Fprop +Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32 +```python +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw 
IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +Example 2: SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2 +```python +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +Example 3: SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_8 +```python +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia fixed_channels -sm Serial -k 1 -nhwc 1 8 8 8 -krsc 16 3 3 8 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +Example 4: SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32_align4 +```python +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` + +## Epilogue +### Bias +To replace C with a bias vector, add `-bias` flag. 
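With `-bias`, the source operand C is a vector broadcast across the output (one value per output channel in the conv2d case), so the epilogue conceptually computes the following. This is a numpy sketch only; the exact broadcast layout shown here is an assumption for illustration.

```python
import numpy as np

M, N, K = 512, 256, 128                                # hypothetical GEMM problem
alpha, beta = 1.0, 0.5
accum = np.random.rand(M, K) @ np.random.rand(K, N)    # mainloop accumulator
bias = np.random.rand(N)                               # bias vector in place of a full C matrix
D = alpha * accum + beta * bias                        # broadcast add; an -activ function may follow
```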
+### Activation function +Example 1: ReLU +```python +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias -activ relu +``` +Example 2: leaky ReLU +```python +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -bias -activ leaky_relu -activ_arg 0.2 +``` +Example 3: tanh (alpha=0 to avoid saturation) +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 -bias -activ tanh +``` +Example 4: sigmoid +```python +python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 0.0 -beta 0.5 -pm Host -bias -activ sigmoid -bias -activ sigmoid +``` +Example 5: SiLU +```python +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ silu +``` +Example 6: HardSwish +```python +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ hardswish +``` +Example 7: GELU +```python +python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu +``` +### Epilogue Visitor Tree +Example 1: +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 +``` +Example 2: +```python +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -epv ColumnBroadcast -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 
0.5 -gm Gemm -k 1 +``` +Example 3: +```python +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 +``` +Example 4: +```python +python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnReduction -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 +``` +Example 5: +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv RowReduction -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3 +``` +Example 6: +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -epv ColumnBroadcast -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 3 +``` diff --git a/examples/40_cutlass_py/customizable/conv2d.py b/examples/40_cutlass_py/customizable/conv2d.py new file mode 100644 index 00000000..365ce3e5 --- /dev/null +++ b/examples/40_cutlass_py/customizable/conv2d.py @@ -0,0 +1,320 @@ +################################################################################ +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################ +import numpy as np +import pycutlass +from pycutlass import * +from pycutlass.conv2d_operation import * +from pycutlass.utils import reference_model +import sys +import torch.nn.functional as F + +import argparse + +# parse the arguments +parser = argparse.ArgumentParser(description="Launch CUTLASS convolution 2d kernels from Python") + +# Operation description +# math instruction description +parser.add_argument("-i", "--instruction_shape", + default=[1, 1, 1], nargs=3, type=int, + help="This option describes the size of MMA op") +parser.add_argument("-ta", "--element_a", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor A') +parser.add_argument("-tb", "--element_b", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor B') +parser.add_argument("-tc", "--element_c", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor C and output tensor D') +parser.add_argument("-tacc", "--element_acc", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of accumulator') +parser.add_argument('-m', "--math", default="multiply_add", + type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") +parser.add_argument('-op', "--opcode", default="simt", type=str, + choices=["Simt", 'TensorOp'], + help='This option describes whether you want to use tensor \ + cores (TensorOp) or regular SIMT cores (Simt) on GPU SM') +# tile description +parser.add_argument("-b", "--threadblock_shape", + default=[128, 128, 8], nargs=3, type=int, + help="This option describes the tile size a thread block with compute") +parser.add_argument("-s", "--stages", default=4, + type=int, help="Number of pipelines you want to use") +parser.add_argument("-w", "--warp_count", default=[ + 4, 2, 1], nargs=3, type=int, + help="This option describes the number of warps along M, N, and K of the threadblock") +parser.add_argument("-cc", "--compute_capability", default=80, + type=int, help="This option describes CUDA SM architecture number") +# A +parser.add_argument('-la', "--layout_a", default="TensorNHWC", type=str, choices=[ + "TensorNHWC", "TensorNC32HW32"], + help="Memory layout of input tensor A") +parser.add_argument('-aa', '--alignment_a', default=1, + type=int, help="Memory alignement of input tensor A") +# B +parser.add_argument('-lb', "--layout_b", default="TensorNHWC", type=str, choices=[ + "TensorNHWC", "TensorC32RSK32"], + help="Memory layout of input tensor B") +parser.add_argument('-ab', '--alignment_b', default=1, + type=int, help="Memory alignment of input tensor B") +# C +parser.add_argument('-lc', "--layout_c", default="TensorNHWC", type=str, choices=[ + "TensorNHWC", "TensorNC32HW32"], + help="Memory layout of input tensor C and output tensor D") +parser.add_argument('-ac', '--alignment_c', default=1, + type=int, help="Memory alignment of input tensor C and output tensor D") +# epilogue +parser.add_argument("-te", "--element_epilogue", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16'], + help='Data type of computation in the epilogue') +parser.add_argument("-ep", "--epilogue_functor", 
default="LinearCombination", + type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], + help="This option describes the epilogue part of the kernel") +# swizzling +parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ + "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", + "HorizontalSwizzle", "StridedDgradIdentitySwizzle1", "StridedDgradIdentitySwizzle4", + "StridedDgradHorizontalSwizzle"], + help="This option describes how thread blocks are scheduled on GPU") +# conv related +parser.add_argument("-co", "--conv_kind", default="fprop", type=str, choices=['fprop', 'dgrad', 'wgrad'], + help="The type of convolution: forward propagation (fprop), \ + gradient of activation (dgrad), gradient of weight (wgrad)") +parser.add_argument("-st", "--stride_support", default="Strided", type=str, choices=["Strided", "Unity"], + ) +parser.add_argument("-ia", "--iterator_algorithm", default="analytic", type=str, + choices=["analytic", "optimized", "fixed_channels", "few_channels"], + help="This option describes iterator algorithm") + +# arguments +parser.add_argument("-sm", "--split_k_mode", default="Serial", type=str, choices=["Serial", "Parallel"], + help="Split K Mode. Serial is used for non-splitK or serial-splitK.\ + Parallel is used for parallel splitK.") +parser.add_argument('-k', '--split_k_slices', default=1, + type=int, help="Number of split-k partitions. (default 1)") +parser.add_argument("-nhwc", "--nhwc", nargs=4, type=int, help="input size (NHWC)") +parser.add_argument("-krsc", "--krsc", nargs=4, type=int, help="filter size (KRSC)") +parser.add_argument("-pad", "--pad", nargs=4, type=int, help="padding (pad_h, _, pad_w, _)") +parser.add_argument("-stride", "--stride", nargs=2, type=int, help="stride (stride_h, stride_w)") +parser.add_argument("-dilation", "--dilation", nargs=2, type=int, help="dilation (dilation_h, dilation_w)") +parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha") +parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta") +parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") +# Activation function +parser.add_argument("-activ", "--activation_function", default="identity", + choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") +parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, + help="addition arguments for activation") + + +parser.add_argument('--print_cuda', action="store_true", + help="print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +np.random.seed(0) + +element_a = getattr(cutlass, args.element_a) +element_b = getattr(cutlass, args.element_b) +element_c = getattr(cutlass, args.element_c) +element_acc = getattr(cutlass, args.element_acc) +math_operation = getattr(MathOperation, args.math) +opclass = getattr(cutlass.OpClass, args.opcode) + +math_inst = MathInstruction( + args.instruction_shape, element_a, element_b, + element_acc, opclass, math_operation +) + +tile_description = TileDescription( + args.threadblock_shape, args.stages, args.warp_count, + math_inst +) + +layout_a = getattr(cutlass, args.layout_a) +layout_b = getattr(cutlass, args.layout_b) +layout_c = getattr(cutlass, args.layout_c) + +A = TensorDescription( + element_a, layout_a, args.alignment_a 
+) + +B = TensorDescription( + element_b, layout_b, args.alignment_b +) + +C = TensorDescription( + element_c, layout_c, args.alignment_c +) + +element_epilogue = getattr(cutlass, args.element_epilogue) +if (args.activation_function == "identity" + or (args.split_k_mode == "Parallel" and args.split_k_slices > 1)): + # + epilogue_functor = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) +else: + epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + +iterator_algorithm = getattr(cutlass.conv.IteratorAlgorithm, args.iterator_algorithm) +swizzling_functor = getattr(cutlass, args.swizzling_functor) +stride_support = getattr(StrideSupport, args.stride_support) +conv_kind = getattr(cutlass.conv.Operator, args.conv_kind) + +operation = Conv2dOperation( + conv_kind=conv_kind, iterator_algorithm=iterator_algorithm, + arch=args.compute_capability, tile_description=tile_description, + A=A, B=B, C=C, stride_support=stride_support, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation,] + +if args.split_k_mode == "Parallel" and args.split_k_slices > 1: + if (args.activation_function == "identity"): + epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + else: + epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + reduction_operation = ReductionOperation( + shape=cutlass.MatrixCoord(4, 32 * C.alignment), + C=C, element_accumulator=element_acc, + element_compute=element_epilogue, + epilogue_functor=epilogue_functor_reduction, + count=C.alignment + ) + operations.append(reduction_operation) + +pycutlass.compiler.add_module(operations) + +problem_size = cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]), + cutlass.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]), + cutlass.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]), + cutlass.MatrixCoord(args.stride[0], args.stride[1]), + cutlass.MatrixCoord(args.dilation[0], args.dilation[1]), + cutlass.conv.Mode.cross_correlation, + args.split_k_slices, 1 +) + + +# User-provide inputs +tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size( + conv_kind, problem_size +) +tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size( + conv_kind, problem_size +) +if args.bias: + tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_extent( + conv_kind, problem_size + ).at(3) +else: + tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size( + conv_kind, problem_size + ) + +tensor_D_size = cutlass.conv.implicit_gemm_tensor_c_size( + conv_kind, problem_size + ) + +if args.element_a != "int8": + tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-8.5, 7.5)) +else: + tensor_A = torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-2, 2) + +if args.element_b != "int8": + tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), 
device="cuda").uniform_(-8.5, 7.5)) +else: + tensor_B = torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-2, 2) + +if args.element_c != "int8": + tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-8.5, 7.5)) +else: + tensor_C = torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-2, 2) + +tensor_D = torch.ones(size=(tensor_D_size,), dtype=getattr(torch, args.element_c), device="cuda") + +arguments = Conv2dArguments( + operation=operation, problem_size=problem_size, A=tensor_A, + B=tensor_B, C=tensor_C, D=tensor_D, + output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), + split_k_mode=getattr(cutlass.conv.SplitKMode, args.split_k_mode), + split_k_slices=problem_size.split_k_slices +) + +if args.split_k_mode == "Parallel" and args.split_k_slices > 1: + implicit_gemm_size = cutlass.conv.implicit_gemm_problem_size(conv_kind, arguments.problem_size) + reduction_arguments = ReductionArguments( + reduction_operation, + problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()], + partitions=problem_size.split_k_slices, + workspace=arguments.ptr_D, + destination=tensor_D, + source=tensor_C, + output_op = reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), + bias = arguments.bias + ) + +operation.run(arguments) + +if args.split_k_mode == "Parallel" and args.split_k_slices > 1: + reduction_operation.run(reduction_arguments) + reduction_arguments.sync() +else: + arguments.sync() + +reference_model = Conv2dReferenceModule(A, B, C, conv_kind) + +tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta, args.bias) +if (args.activation_function != "identity"): + tensor_D_ref = getattr(F, args.activation_function)(*([tensor_D_ref,] + args.activation_args)) + +try: + assert torch.equal(tensor_D, tensor_D_ref) +except: + assert torch.allclose(tensor_D, tensor_D_ref, rtol=1e-2) +print("Passed.") diff --git a/examples/40_cutlass_py/customizable/gemm.py b/examples/40_cutlass_py/customizable/gemm.py new file mode 100644 index 00000000..dd6a7a4a --- /dev/null +++ b/examples/40_cutlass_py/customizable/gemm.py @@ -0,0 +1,445 @@ +################################################################################ +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ +import numpy as np +import pycutlass +from pycutlass import * +import cutlass +from bfloat16 import bfloat16 +import sys + +import argparse + + +# parse the arguments +parser = argparse.ArgumentParser(description="Launch CUTLASS GEMM kernels from Python: 'D = alpha * A * B + beta * C'") + +# Operation description +# math instruction description +parser.add_argument("-i", "--instruction_shape", + default=[1, 1, 1], nargs=3, type=int, + help="This option describes the size of MMA op") +parser.add_argument("-ta", "--element_a", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor A') +parser.add_argument("-tb", "--element_b", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor B') +parser.add_argument("-tc", "--element_c", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor C and output tensor D') +parser.add_argument("-tacc", "--element_acc", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of accumulator') +parser.add_argument('-m', "--math", default="multiply_add", + type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") +parser.add_argument('-op', "--opcode", default="simt", type=str, + choices=["Simt", 'TensorOp'], + help="This option describes whether you want to use tensor \ + cores (TensorOp) or regular SIMT cores (Simt) on GPU SM") +# tile description +parser.add_argument("-b", "--threadblock_shape", + default=[128, 128, 8], nargs=3, type=int, + help="This option describes the tile size a thread block with compute") +parser.add_argument("-s", "--stages", default=4, + type=int, help="Number of pipelines you want to use") +parser.add_argument("-w", "--warp_count", default=[4, 2, 1], nargs=3, type=int, + help="This option describes the number of warps along M, N, and K of the threadblock") +parser.add_argument("-cc", "--compute_capability", default=80, + type=int, help="This option describes CUDA SM architecture number") +# A +parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor A") +parser.add_argument('-aa', '--alignment_a', default=1, + type=int, help="Memory alignement of input tensor A") +# B +parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor B") +parser.add_argument('-ab', '--alignment_b', default=1, + type=int, help="Memory alignment of input tensor B") +# C 
+parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor C and output tensor D") +parser.add_argument('-ac', '--alignment_c', default=1, + type=int, help="Memory alignment of input tensor C and output tensor D") +# epilogue +parser.add_argument("-te", "--element_epilogue", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype') +parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", + type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], + help="This option describes the epilogue part of the kernel") +parser.add_argument("-epv", "--epilogue_visitor", default=None, + type=str, choices=['RowReduction', 'ColumnReduction', 'RowBroadcast', 'ColumnBroadcast'], help="epilogue visitor for more complex epilogues") +# swizzling +parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ + "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle", "BatchedIdentitySwizzle"], + help="This option describes how thread blocks are scheduled on GPU") + +# Argument +parser.add_argument("-p", "--problem_size", + default=[128, 128, 128], nargs=3, type=int, + help="GEMM problem size M, N, K") +parser.add_argument("-alpha", "--alpha", default=1.0, type=float, + help="Scaling factor of A * B") +parser.add_argument("-beta", "--beta", default=0.0, type=float, + help="Scaling factor of C") +parser.add_argument("-gm", "--gemm_mode", default="Gemm", type=str, + choices=["Gemm", "GemmSplitKParallel", "Batched", "Array"], + help="GEMM mode. Gemm is used for non-splitK or serial-splitK. \ + GemmSplitKParallel is used for parallel splitK") +parser.add_argument('-k', '--split_k_slices', default=1, + type=int, help="Number of split-k partitions. 
(default 1)") +parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") +parser.add_argument('-batch', '--batch', default=1, type=int, help="batch size for batched GEMM") + +# Activation function +parser.add_argument("-activ", "--activation_function", default="identity", + choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") +parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, + help="addition arguments for activation") +parser.add_argument('--print_cuda', action="store_true", + help="print the underlying CUDA kernel") + + +try: + args = parser.parse_args() +except: + sys.exit(0) + +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) +pycutlass.compiler.nvcc() + +np.random.seed(0) + +element_a = getattr(cutlass, args.element_a) +element_b = getattr(cutlass, args.element_b) +element_c = getattr(cutlass, args.element_c) +element_acc = getattr(cutlass, args.element_acc) +math_operation = getattr(MathOperation, args.math) +opclass = getattr(cutlass.OpClass, args.opcode) + +math_inst = MathInstruction( + args.instruction_shape, element_a, element_b, + element_acc, opclass, math_operation +) + +tile_description = TileDescription( + args.threadblock_shape, args.stages, args.warp_count, + math_inst +) + +layout_a = getattr(cutlass, args.layout_a) +layout_b = getattr(cutlass, args.layout_b) +layout_c = getattr(cutlass, args.layout_c) + +A = TensorDescription( + element_a, layout_a, args.alignment_a +) + +B = TensorDescription( + element_b, layout_b, args.alignment_b +) + +C = TensorDescription( + element_c, layout_c, args.alignment_c +) + +element_epilogue = getattr(cutlass, args.element_epilogue) +if (args.activation_function == "identity" + or (args.gemm_mode == "GemmSplitKParallel" and args.split_k_slices > 1)): + # + epilogue_functor = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) +else: + epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + +swizzling_functor = getattr(cutlass, args.swizzling_functor) + +visitor = args.epilogue_visitor is not None + +if args.epilogue_visitor == "ColumnReduction": + class ColumnReduction_(EpilogueVisitTree): + def __call__( + self, accum: 'tensor', c: 'tensor', + alpha: 'scalar', beta: 'scalar'): + # + D = alpha * accum + beta * c + reduction = reduction_op(D, "column", "Add", args.threadblock_shape[0]) + return D, reduction + epilogue_functor = ColumnReduction_( + epilogue_functor, tile_description, math_inst.element_accumulator, + C.alignment, element_epilogue, C.element) + epilogue_functor.initialize() +elif args.epilogue_visitor == "RowReduction": + class RowReduction_(EpilogueVisitTree): + def __call__( + self, accum: 'tensor', c: 'tensor', + alpha: 'scalar', beta: 'scalar'): + # + D = alpha * accum + tanh.numpy(beta * c) + reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1]) + return D, reduction + epilogue_functor = RowReduction_( + epilogue_functor, tile_description, math_inst.element_accumulator, + C.alignment, element_epilogue, C.element) + epilogue_functor.initialize() + +elif args.epilogue_visitor == "RowBroadcast": + class RowBroadcast_(EpilogueVisitTree): + def __call__( + self, accum: 'tensor', c: 'tensor', + vector: 'row', alpha: 'scalar', beta: 'scalar'): + 
# + T = accum + vector + scale_T = alpha * T + Z = relu.numpy(scale_T + beta * c) + return Z, T + epilogue_functor = RowBroadcast_( + epilogue_functor, tile_description, math_inst.element_accumulator, + C.alignment, element_epilogue, C.element) + epilogue_functor.initialize() +elif args.epilogue_visitor == "ColumnBroadcast": + class ColumnBroadcast_(EpilogueVisitTree): + def __call__( + self, accum: 'tensor', c: 'tensor', + vector: 'column', alpha: 'scalar', beta: 'scalar'): + # + T = accum + vector + scale_T = leaky_relu.numpy(alpha * T, 0.2) + Z = scale_T + beta * c + return Z, T + epilogue_functor = ColumnBroadcast_( + epilogue_functor, tile_description, math_inst.element_accumulator, + C.alignment, element_epilogue, C.element) + epilogue_functor.initialize() +else: + epilogue_functor = epilogue_functor + +operation = GemmOperationUniversal( + arch=args.compute_capability, tile_description=tile_description, + A=A, B=B, C=C, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor, + visitor=visitor +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation, ] + +if args.gemm_mode == "GemmSplitKParallel": + if (args.activation_function == "identity"): + epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + else: + epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + + reduction_operation = ReductionOperation( + shape=cutlass.MatrixCoord(4, 32 * C.alignment), + C=C, element_accumulator=element_acc, + element_compute=element_epilogue, + epilogue_functor=epilogue_functor_reduction, + count=C.alignment + ) + operations.append(reduction_operation) + +pycutlass.compiler.add_module(operations) + +# User-provide inputs + +problem_size = cutlass.gemm.GemmCoord( + args.problem_size[0], args.problem_size[1], args.problem_size[2]) + +tensor_a_size = args.batch * problem_size.m() * problem_size.k() +if args.element_a != "int8": + if args.element_a == "bfloat16": + tensor_A = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,)) + ).astype(bfloat16) + else: + tensor_A = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,)) + ).astype(getattr(np, args.element_a)) +else: + tensor_A = np.random.uniform( + low=-2, high=2,size=(tensor_a_size,) + ).astype(getattr(np, args.element_a)) + +tensor_b_size = args.batch * problem_size.k() * problem_size.n() +if args.element_b != "int8": + if args.element_b == "bfloat16": + tensor_B = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,)) + ).astype(bfloat16) + else: + tensor_B = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,)) + ).astype(getattr(np, args.element_b)) +else: + tensor_B = np.random.uniform( + low=-2, high=2, size=(tensor_b_size,) + ).astype(getattr(np, args.element_b)) + +if args.element_c != "int8": + if args.bias: + if args.layout_c == "RowMajor": + tensor_c_size = args.batch * problem_size.n() + elif args.layout_c == "ColumnMajor": + tensor_c_size = args.batch * problem_size.m() + else: + raise ValueError(args.layout_c) + else: + tensor_c_size = args.batch * problem_size.m() * problem_size.n() + if args.element_c == "bfloat16": + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,)) + ).astype(bfloat16) + else: + tensor_C = np.ceil( + 
np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,)) + ).astype(getattr(np, args.element_c)) +else: + tensor_C = np.random.uniform( + low=-2, high=2, size=(args.batch * problem_size.m() * problem_size.n(),) + ).astype(getattr(np, args.element_c)) + +tensor_D = np.zeros( + shape=(args.batch * problem_size.m() * problem_size.n(),) +).astype(getattr(np, args.element_c)) + +if args.epilogue_visitor == "RowReduction": + cta_n = args.threadblock_shape[1] + num_cta_n = (problem_size.n() + cta_n - 1) // cta_n + reduction = np.zeros(shape=(args.batch * problem_size.m() * num_cta_n,), dtype=getattr(np, args.element_c)) + output_op = operation.epilogue_type( + D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()] + ) +elif args.epilogue_visitor == "ColumnReduction": + cta_m = args.threadblock_shape[0] + num_cta_m = (problem_size.m() + cta_m - 1) // cta_m + reduction = np.zeros(shape=(args.batch * problem_size.n() * num_cta_m,), dtype=getattr(np, args.element_c)) + output_op = operation.epilogue_type( + D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()] + ) +elif args.epilogue_visitor == "RowBroadcast": + vector = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(args.batch, 1, problem_size.n())) + ).astype(getattr(np, args.element_c)) + tensor_t = np.empty_like(tensor_D) + output_op = operation.epilogue_type( + c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()] + ) +elif args.epilogue_visitor == "ColumnBroadcast": + vector = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(args.batch, problem_size.m(), 1)) + ).astype(getattr(np, args.element_c)) + tensor_t = np.empty_like(tensor_D) + output_op = operation.epilogue_type( + c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()] + ) +else: + output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)) + +arguments = GemmArguments( + operation=operation, problem_size=problem_size, + A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, + output_op=output_op, + gemm_mode=getattr(cutlass.gemm.Mode, args.gemm_mode), + split_k_slices=args.split_k_slices, batch=args.batch +) + +if args.gemm_mode == "GemmSplitKParallel": + reduction_arguments = ReductionArguments( + operation=reduction_operation, + problem_size=[problem_size.m(), problem_size.n()], + partitions=args.split_k_slices, workspace=arguments.ptr_D, + destination=tensor_D, source=tensor_C, + output_op=reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), + bias = arguments.bias + ) + +operation.run(arguments) + +if args.gemm_mode == "GemmSplitKParallel": + reduction_operation.run(reduction_arguments) + reduction_arguments.sync() +else: + arguments.sync() + +# run the host reference module +reference = ReferenceModule(A, B, C) +tensor_D_ref = reference.run( + tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta, args.bias, args.batch) + +if args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]: + tensor_D_ref = (tensor_D_ref.reshape((args.batch, problem_size.m(), problem_size.n())) + vector).flatten() +tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args)) + +if args.epilogue_visitor in ["RowReduction", "ColumnReduction"]: + output_op.sync() + 
accum_ref = reference.run( + tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch) + tensor_D_ref, reduction_ref = epilogue_functor( + accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())), + tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())), + args.alpha, args.beta + ) + tensor_D_ref = tensor_D_ref.flatten() + reduction_ref = reduction_ref.flatten() + assert np.allclose(reduction_ref, reduction, atol=1e-2) + +elif args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]: + output_op.sync() + accum_ref = reference.run( + tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch) + + tensor_D_ref, tensor_T_ref = epilogue_functor( + accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())), + tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())), + vector, args.alpha, args.beta) + + tensor_D_ref = tensor_D_ref.flatten() + tensor_T_ref = tensor_T_ref.flatten() + + assert np.array_equal(tensor_t, tensor_T_ref) + +try: + assert np.array_equal(tensor_D, tensor_D_ref) +except: + assert np.allclose(tensor_D, tensor_D_ref, atol=1e-5) +print("Passed.") diff --git a/examples/40_cutlass_py/customizable/gemm_grouped.py b/examples/40_cutlass_py/customizable/gemm_grouped.py new file mode 100644 index 00000000..40f2bc8d --- /dev/null +++ b/examples/40_cutlass_py/customizable/gemm_grouped.py @@ -0,0 +1,287 @@ +################################################################################ +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################ +import numpy as np +import pycutlass +from pycutlass import * +import csv +import sys + +import argparse + +# parse the arguments +parser = argparse.ArgumentParser( + description="Launch CUTLASS GEMM Grouped kernels from Python") + +# Operation description +# math instruction description +parser.add_argument("-i", "--instruction_shape", + default=[1, 1, 1], nargs=3, type=int, + help="This option describes the size of MMA op") +parser.add_argument("-ta", "--element_a", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor A') +parser.add_argument("-tb", "--element_b", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor B') +parser.add_argument("-tc", "--element_c", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor C and output tensor D') +parser.add_argument("-tacc", "--element_acc", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of accumulator') +parser.add_argument('-m', "--math", default="multiply_add", + type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") +parser.add_argument('-op', "--opcode", default="simt", type=str, + choices=["Simt", 'TensorOp'], help='This option describes whether you want to use tensor \ + cores (TensorOp) or regular SIMT cores (Simt) on GPU SM') +# tile description +parser.add_argument("-b", "--threadblock_shape", + default=[128, 128, 8], nargs=3, type=int, + help="This option describes the tile size a thread block with compute") +parser.add_argument("-s", "--stages", default=4, + type=int, help="Number of pipelines you want to use") +parser.add_argument("-w", "--warp_count", default=[ + 4, 2, 1], nargs=3, type=int, + help="This option describes the number of warps along M, N, and K of the threadblock") +parser.add_argument("-cc", "--compute_capability", default=80, + type=int, help="This option describes CUDA SM architecture number") +# A +parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor A") +parser.add_argument('-aa', '--alignment_a', default=1, + type=int, help="Memory alignment of input tensor A") +# B +parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor B") +parser.add_argument('-ab', '--alignment_b', default=1, + type=int, help="Memory alignment of input tensor B") +# C +parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor C and output tensor D") +parser.add_argument('-ac', '--alignment_c', default=1, + type=int, help="Memory alignment of input tensor C and output tensor D") +# epilogue +parser.add_argument("-te", "--element_epilogue", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype') +parser.add_argument("-ep", "--epilogue_functor", 
default="LinearCombination", + type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], + help="This option describes the epilogue part of the kernel") +# swizzling +parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ + "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"], + help="This option describes how thread blocks are scheduled on GPU. \ + NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. \ + This parameter is passed in at present to match the APIs of other kernels. The parameter \ + is unused within the kernel") +# precompute mode +parser.add_argument("-pm", "--precompute_mode", + default="Device", type=str, choices=["Host", "Device"], + help="Grouped Gemm Scheduing on device only (Device) or using host precompute (Host)") +# arguments +parser.add_argument("-p", "--problem_size_dir", type=str, + help="path to the csv file contains the problem sizes") +parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha") +parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta") +parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") + +# Activation function +parser.add_argument("-activ", "--activation_function", default="identity", + choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") +parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, + help="addition arguments for activation") +parser.add_argument('--print_cuda', action="store_true", + help="print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +np.random.seed(0) + +element_a = getattr(cutlass, args.element_a) +element_b = getattr(cutlass, args.element_b) +element_c = getattr(cutlass, args.element_c) +element_acc = getattr(cutlass, args.element_acc) +math_operation = getattr(MathOperation, args.math) +opclass = getattr(cutlass.OpClass, args.opcode) + +math_inst = MathInstruction( + args.instruction_shape, element_a, element_b, + element_acc, opclass, math_operation +) + +tile_description = TileDescription( + args.threadblock_shape, args.stages, args.warp_count, + math_inst +) + +layout_a = getattr(cutlass, args.layout_a) +layout_b = getattr(cutlass, args.layout_b) +layout_c = getattr(cutlass, args.layout_c) + +A = TensorDescription( + element_a, layout_a, args.alignment_a +) + +B = TensorDescription( + element_b, layout_b, args.alignment_b +) + +C = TensorDescription( + element_c, layout_c, args.alignment_c +) + +element_epilogue = getattr(cutlass, args.element_epilogue) +if args.activation_function == "identity": + epilogue_functor = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) +else: + epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) +swizzling_functor = getattr(cutlass, args.swizzling_functor) +precompute_mode = getattr(SchedulerMode, args.precompute_mode) + +operation = GemmOperationGrouped( + arch=args.compute_capability, tile_description=tile_description, + A=A, B=B, C=C, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor, + 
precompute_mode=precompute_mode +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +pycutlass.compiler.add_module([operation, ]) + +reference_module = ReferenceModule(A, B, C) + +# get problems +problem_sizes = [] +with open(args.problem_size_dir) as csv_file: + reader = csv.reader(csv_file) + for row in reader: + problem_sizes.append( + cutlass.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2])) + ) + +problem_count = len(problem_sizes) + +tensor_As = [] +tensor_Bs = [] +tensor_Cs = [] +tensor_Ds = [] +problem_sizes_coord = [] +tensor_D_refs = [] + +for problem_size in problem_sizes: + if args.element_a != "int8": + if args.element_a == "bfloat16": + tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() + * problem_size.k(),))).astype(bfloat16) + else: + tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() + * problem_size.k(),))).astype(getattr(np, args.element_a)) + else: + tensor_A = np.random.uniform(low=-2, high=2, size=(problem_size.m() + * problem_size.k(),)).astype(getattr(np, args.element_a)) + + if args.element_b != "int8": + if args.element_b == "bfloat16": + tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k() + * problem_size.n(),))).astype(bfloat16) + else: + tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k() + * problem_size.n(),))).astype(getattr(np, args.element_b)) + else: + tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k() + * problem_size.n(),)).astype(getattr(np, args.element_b)) + + if args.element_c != "int8": + if args.bias: + if args.layout_c == "RowMajor": + c_size = problem_size.n() + elif args.layout_c == "ColumnMajor": + c_size = problem_size.m() + else: + raise ValueError(args.layout_c) + else: + c_size = problem_size.m() * problem_size.n() + if args.element_c == "bfloat16": + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(c_size,)) + ).astype(bfloat16) + else: + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(c_size,)) + ).astype(getattr(np, args.element_c)) + else: + tensor_C = np.random.uniform( + low=-2, high=2, size=(problem_size.m() * problem_size.n(),) + ).astype(getattr(np, args.element_c)) + tensor_D = np.zeros( + shape=(problem_size.m() * problem_size.n(),) + ).astype(getattr(np, args.element_c)) + + tensor_As.append(tensor_A) + tensor_Bs.append(tensor_B) + tensor_Cs.append(tensor_C) + tensor_Ds.append(tensor_D) + tensor_D_ref = reference_module.run( + tensor_A, tensor_B, tensor_C, problem_size, + args.alpha, args.beta, args.bias) + tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args)) + tensor_D_refs.append(tensor_D_ref) + problem_sizes_coord.append(problem_size) + +arguments = GemmGroupedArguments( + operation, problem_sizes_coord, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds, + output_op=operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)) +) + +operation.run(arguments) + +arguments.sync() + +for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs): + try: + assert np.array_equal(tensor_d, tensor_d_ref) + except: + assert np.allclose(tensor_d, tensor_d_ref, rtol=1e-5) + +print("Passed.") diff --git a/examples/40_cutlass_py/grouped_gemm_problem_size.csv b/examples/40_cutlass_py/customizable/grouped_gemm_problem_size.csv similarity index 100% rename from examples/40_cutlass_py/grouped_gemm_problem_size.csv rename to examples/40_cutlass_py/customizable/grouped_gemm_problem_size.csv diff --git 
a/examples/40_cutlass_py/gemm.py b/examples/40_cutlass_py/gemm.py index b9b7fabc..341177cf 100644 --- a/examples/40_cutlass_py/gemm.py +++ b/examples/40_cutlass_py/gemm.py @@ -29,417 +29,110 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # ################################################################################ -import numpy as np -import pycutlass -from pycutlass import * -import cutlass -from bfloat16 import bfloat16 +""" +Basic example of using the CUTLASS Python interface to run a GEMM +""" import argparse +import numpy as np +import sys + +import cutlass +import pycutlass +from pycutlass import * +import util -# parse the arguments -parser = argparse.ArgumentParser( - description="Launch CUTLASS GEMM kernels from python: 'D = alpha * A * B + beta * C'") - -# Operation description -# math instruction description -parser.add_argument("-i", "--instruction_shape", - default=[1, 1, 1], nargs=3, type=int, - help="This option describes the size of MMA op") -parser.add_argument("-ta", "--element_a", default="float32", type=str, - choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], - help='Data type of elements in input tensor A') -parser.add_argument("-tb", "--element_b", default="float32", type=str, - choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], - help='Data type of elements in input tensor B') -parser.add_argument("-tc", "--element_c", default="float32", type=str, - choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], - help='Data type of elements in input tensor C and output tensor D') -parser.add_argument("-tacc", "--element_acc", default="float32", type=str, - choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], - help='Data type of accumulator') -parser.add_argument('-m', "--math", default="multiply_add", - type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") -parser.add_argument('-op', "--opcode", default="simt", type=str, - choices=["Simt", 'TensorOp'], - help="This option describes whether you want to use tensor \ - cores (TensorOp) or regular SIMT cores (Simt) on GPU SM") -# tile description -parser.add_argument("-b", "--threadblock_shape", - default=[128, 128, 8], nargs=3, type=int, - help="This option describes the tile size a thread block with compute") -parser.add_argument("-s", "--stages", default=4, - type=int, help="Number of pipelines you want to use") -parser.add_argument("-w", "--warp_count", default=[4, 2, 1], nargs=3, type=int, - help="This option describes the number of warps along M, N, and K of the threadblock") -parser.add_argument("-cc", "--compute_capability", default=80, - type=int, help="This option describes CUDA SM architecture number") -# A -parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[ - "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], - help="Memory layout of input tensor A") -parser.add_argument('-aa', '--alignment_a', default=1, - type=int, help="Memory alignement of input tensor A") -# B -parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[ - "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], - help="Memory layout of input tensor B") -parser.add_argument('-ab', '--alignment_b', default=1, - type=int, help="Memory alignment of input tensor B") -# C -parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[ - "RowMajor", "ColumnMajor", 
"RowMajorInterleaved32", "ColumnMajorInterleaved32"], - help="Memory layout of input tensor C and output tensor D") -parser.add_argument('-ac', '--alignment_c', default=1, - type=int, help="Memory alignment of input tensor C and output tensor D") -# epilogue -parser.add_argument("-te", "--element_epilogue", default="float32", type=str, - choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype') -parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", - type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], - help="This option describes the epilogue part of the kernel") -parser.add_argument("-epv", "--epilogue_visitor", default=None, - type=str, choices=['RowReduction', 'ColumnReduction', 'RowBroadcast', 'ColumnBroadcast'], help="epilogue visitor for more complex epilogues") -# swizzling -parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ - "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle", "BatchedIdentitySwizzle"], - help="This option describes how thread blocks are scheduled on GPU") - -# Argument -parser.add_argument("-p", "--problem_size", - default=[128, 128, 128], nargs=3, type=int, - help="GEMM problem size M, N, K") -parser.add_argument("-alpha", "--alpha", default=1.0, type=float, - help="Scaling factor of A * B") -parser.add_argument("-beta", "--beta", default=0.0, type=float, - help="Scaling factor of C") -parser.add_argument("-gm", "--gemm_mode", default="Gemm", type=str, - choices=["Gemm", "GemmSplitKParallel", "Batched", "Array"], - help="GEMM mode. Gemm is used for non-splitK or serial-splitK. \ - GemmSplitKParallel is used for parallel splitK") -parser.add_argument('-k', '--split_k_slices', default=1, - type=int, help="Number of split-k partitions. (default 1)") -parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") -parser.add_argument('-batch', '--batch', default=1, type=int, help="batch size for batched GEMM") - -# Activation function -parser.add_argument("-activ", "--activation_function", default="identity", - choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") -parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, - help="addition arguments for activation") -parser.add_argument('--print_cuda', action="store_true", - help="print the underlying CUDA kernel") - +parser = argparse.ArgumentParser(description="Launch a GEMM kernel from Python: 'D = alpha * A * B + beta * C'") +parser.add_argument("--m", default=128, type=int, help="M dimension of the GEMM") +parser.add_argument("--n", default=128, type=int, help="N dimension of the GEMM") +parser.add_argument("--k", default=128, type=int, help="K dimension of the GEMM") +parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel") try: args = parser.parse_args() except: sys.exit(0) -pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) -pycutlass.compiler.nvcc() +# Check that the device is of a sufficient compute capability +cc = util.get_device_cc() +assert cc >= 70, "The CUTLASS Python GEMM example requires compute capability greater than or equal to 70." 
+ +alignment = 8 +assert args.m % alignment == 0, "M dimension of size {} is not divisible by alignment of {}".format(args.m, alignment) +assert args.n % alignment == 0, "N dimension of size {} is not divisible by alignment of {}".format(args.n, alignment) +assert args.k % alignment == 0, "K dimension of size {} is not divisible by alignment of {}".format(args.k, alignment) np.random.seed(0) -element_a = getattr(cutlass, args.element_a) -element_b = getattr(cutlass, args.element_b) -element_c = getattr(cutlass, args.element_c) -element_acc = getattr(cutlass, args.element_acc) -math_operation = getattr(MathOperation, args.math) -opclass = getattr(cutlass.OpClass, args.opcode) +# Allocate a pool of device memory to be used by the kernel +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +# Set the compiler to use to NVCC +pycutlass.compiler.nvcc() + +# Set up A, B, C and accumulator +A = TensorDescription(cutlass.float16, cutlass.ColumnMajor, alignment) +B = TensorDescription(cutlass.float16, cutlass.RowMajor, alignment) +C = TensorDescription(cutlass.float32, cutlass.ColumnMajor, alignment) +element_acc = cutlass.float32 +element_epilogue = cutlass.float32 math_inst = MathInstruction( - args.instruction_shape, element_a, element_b, - element_acc, opclass, math_operation + [16, 8, 8], # Shape of the Tensor Core instruction + A.element, B.element, element_acc, + cutlass.OpClass.TensorOp, + MathOperation.multiply_add ) tile_description = TileDescription( - args.threadblock_shape, args.stages, args.warp_count, + [128, 128, 32], # Threadblock shape + 2, # Number of stages + [2, 2, 1], # Number of warps within each dimension of the threadblock shape math_inst ) -layout_a = getattr(cutlass, args.layout_a) -layout_b = getattr(cutlass, args.layout_b) -layout_c = getattr(cutlass, args.layout_c) - -A = TensorDescription( - element_a, layout_a, args.alignment_a -) - -B = TensorDescription( - element_b, layout_b, args.alignment_b -) - -C = TensorDescription( - element_c, layout_c, args.alignment_c -) - -element_epilogue = getattr(cutlass, args.element_epilogue) -if (args.activation_function == "identity" - or (args.gemm_mode == "GemmSplitKParallel" and args.split_k_slices > 1)): - # - epilogue_functor = getattr(pycutlass, args.epilogue_functor)( - C.element, C.alignment, math_inst.element_accumulator, element_epilogue) -else: - epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( - getattr(pycutlass, args.activation_function)(element_epilogue), - C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - -swizzling_functor = getattr(cutlass, args.swizzling_functor) - -visitor = args.epilogue_visitor is not None - -if args.epilogue_visitor == "ColumnReduction": - class ColumnReduction_(EpilogueVisitTree): - def __call__( - self, accum: 'tensor', c: 'tensor', - alpha: 'scalar', beta: 'scalar'): - # - D = alpha * accum + beta * c - reduction = reduction_op(D, "column", "Add", args.threadblock_shape[0]) - return D, reduction - epilogue_functor = ColumnReduction_( - epilogue_functor, tile_description, math_inst.element_accumulator, - C.alignment, element_epilogue, C.element) - epilogue_functor.initialize() -elif args.epilogue_visitor == "RowReduction": - class RowReduction_(EpilogueVisitTree): - def __call__( - self, accum: 'tensor', c: 'tensor', - alpha: 'scalar', beta: 'scalar'): - # - D = alpha * accum + tanh.numpy(beta * c) - reduction = reduction_op(D, "row", "Add", args.threadblock_shape[1]) - return D, reduction - epilogue_functor = RowReduction_( - 
epilogue_functor, tile_description, math_inst.element_accumulator, - C.alignment, element_epilogue, C.element) - epilogue_functor.initialize() - -elif args.epilogue_visitor == "RowBroadcast": - class RowBroadcast_(EpilogueVisitTree): - def __call__( - self, accum: 'tensor', c: 'tensor', - vector: 'row', alpha: 'scalar', beta: 'scalar'): - # - T = accum + vector - scale_T = alpha * T - Z = relu.numpy(scale_T + beta * c) - return Z, T - epilogue_functor = RowBroadcast_( - epilogue_functor, tile_description, math_inst.element_accumulator, - C.alignment, element_epilogue, C.element) - epilogue_functor.initialize() -elif args.epilogue_visitor == "ColumnBroadcast": - class ColumnBroadcast_(EpilogueVisitTree): - def __call__( - self, accum: 'tensor', c: 'tensor', - vector: 'column', alpha: 'scalar', beta: 'scalar'): - # - T = accum + vector - scale_T = leaky_relu.numpy(alpha * T, 0.2) - Z = scale_T + beta * c - return Z, T - epilogue_functor = ColumnBroadcast_( - epilogue_functor, tile_description, math_inst.element_accumulator, - C.alignment, element_epilogue, C.element) - epilogue_functor.initialize() -else: - epilogue_functor = epilogue_functor +epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue) operation = GemmOperationUniversal( - arch=args.compute_capability, tile_description=tile_description, + arch=cc, tile_description=tile_description, A=A, B=B, C=C, - epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor, - visitor=visitor -) + epilogue_functor=epilogue_functor) if args.print_cuda: print(operation.rt_module.emit()) operations = [operation, ] -if args.gemm_mode == "GemmSplitKParallel": - if (args.activation_function == "identity"): - epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)( - C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - else: - epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")( - getattr(pycutlass, args.activation_function)(element_epilogue), - C.element, C.alignment, math_inst.element_accumulator, element_epilogue) - - reduction_operation = ReductionOperation( - shape=cutlass.MatrixCoord(4, 32 * C.alignment), - C=C, element_accumulator=element_acc, - element_compute=element_epilogue, - epilogue_functor=epilogue_functor_reduction, - count=C.alignment - ) - operations.append(reduction_operation) - +# Compile the operation pycutlass.compiler.add_module(operations) -# User-provide inputs +# Randomly initialize tensors +tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.m * args.k,))).astype(np.float16) +tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.k * args.n,))).astype(np.float16) +tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.m * args.n,))).astype(np.float32) +tensor_D = np.zeros(shape=(args.m * args.n,)).astype(np.float32) -problem_size = cutlass.gemm.GemmCoord( - args.problem_size[0], args.problem_size[1], args.problem_size[2]) - -tensor_a_size = args.batch * problem_size.m() * problem_size.k() -if args.element_a != "int8": - if args.element_a == "bfloat16": - tensor_A = np.ceil( - np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,)) - ).astype(bfloat16) - else: - tensor_A = np.ceil( - np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,)) - ).astype(getattr(np, args.element_a)) -else: - tensor_A = np.random.uniform( - low=-2, high=2,size=(tensor_a_size,) - ).astype(getattr(np, args.element_a)) - -tensor_b_size = args.batch * problem_size.k() * problem_size.n() 
-if args.element_b != "int8": - if args.element_b == "bfloat16": - tensor_B = np.ceil( - np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,)) - ).astype(bfloat16) - else: - tensor_B = np.ceil( - np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,)) - ).astype(getattr(np, args.element_b)) -else: - tensor_B = np.random.uniform( - low=-2, high=2, size=(tensor_b_size,) - ).astype(getattr(np, args.element_b)) - -if args.element_c != "int8": - if args.bias: - if args.layout_c == "RowMajor": - tensor_c_size = args.batch * problem_size.n() - elif args.layout_c == "ColumnMajor": - tensor_c_size = args.batch * problem_size.m() - else: - raise ValueError(args.layout_c) - else: - tensor_c_size = args.batch * problem_size.m() * problem_size.n() - if args.element_c == "bfloat16": - tensor_C = np.ceil( - np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,)) - ).astype(bfloat16) - else: - tensor_C = np.ceil( - np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,)) - ).astype(getattr(np, args.element_c)) -else: - tensor_C = np.random.uniform( - low=-2, high=2, size=(args.batch * problem_size.m() * problem_size.n(),) - ).astype(getattr(np, args.element_c)) - -tensor_D = np.zeros( - shape=(args.batch * problem_size.m() * problem_size.n(),) -).astype(getattr(np, args.element_c)) - -if args.epilogue_visitor == "RowReduction": - cta_n = args.threadblock_shape[1] - num_cta_n = (problem_size.n() + cta_n - 1) // cta_n - reduction = np.zeros(shape=(args.batch * problem_size.m() * num_cta_n,), dtype=getattr(np, args.element_c)) - output_op = operation.epilogue_type( - D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()] - ) -elif args.epilogue_visitor == "ColumnReduction": - cta_m = args.threadblock_shape[0] - num_cta_m = (problem_size.m() + cta_m - 1) // cta_m - reduction = np.zeros(shape=(args.batch * problem_size.n() * num_cta_m,), dtype=getattr(np, args.element_c)) - output_op = operation.epilogue_type( - D=tensor_D, alpha=args.alpha, beta=args.beta, c=tensor_C, reduction=reduction, problem_size=[problem_size.m(), problem_size.n()] - ) -elif args.epilogue_visitor == "RowBroadcast": - vector = np.ceil( - np.random.uniform(low=-8.5, high=7.5, size=(args.batch, 1, problem_size.n())) - ).astype(getattr(np, args.element_c)) - tensor_t = np.empty_like(tensor_D) - output_op = operation.epilogue_type( - c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()] - ) -elif args.epilogue_visitor == "ColumnBroadcast": - vector = np.ceil( - np.random.uniform(low=-8.5, high=7.5, size=(args.batch, problem_size.m(), 1)) - ).astype(getattr(np, args.element_c)) - tensor_t = np.empty_like(tensor_D) - output_op = operation.epilogue_type( - c=tensor_C, vector=vector, alpha=args.alpha, beta=args.beta, Z=tensor_D, T=tensor_t, problem_size=[problem_size.m(), problem_size.n()] - ) -else: - output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)) +problem_size = cutlass.gemm.GemmCoord(args.m, args.n, args.k) +alpha = 1. +beta = 0. 
arguments = GemmArguments( operation=operation, problem_size=problem_size, A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, - output_op=output_op, - gemm_mode=getattr(cutlass.gemm.Mode, args.gemm_mode), - split_k_slices=args.split_k_slices, batch=args.batch -) - -if args.gemm_mode == "GemmSplitKParallel": - reduction_arguments = ReductionArguments( - operation=reduction_operation, - problem_size=[problem_size.m(), problem_size.n()], - partitions=args.split_k_slices, workspace=arguments.ptr_D, - destination=tensor_D, source=tensor_C, - output_op=reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), - bias = arguments.bias - ) + output_op=operation.epilogue_type(alpha, beta)) +# Run the operation operation.run(arguments) +arguments.sync() -if args.gemm_mode == "GemmSplitKParallel": - reduction_operation.run(reduction_arguments) - reduction_arguments.sync() -else: - arguments.sync() - -# run the host reference module +# Run the host reference module and compare to the CUTLASS result reference = ReferenceModule(A, B, C) -tensor_D_ref = reference.run( - tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta, args.bias, args.batch) - -if args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]: - tensor_D_ref = (tensor_D_ref.reshape((args.batch, problem_size.m(), problem_size.n())) + vector).flatten() -tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args)) - -if args.epilogue_visitor in ["RowReduction", "ColumnReduction"]: - output_op.sync() - accum_ref = reference.run( - tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch) - tensor_D_ref, reduction_ref = epilogue_functor( - accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())), - tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())), - args.alpha, args.beta - ) - tensor_D_ref = tensor_D_ref.flatten() - reduction_ref = reduction_ref.flatten() - assert np.allclose(reduction_ref, reduction, atol=1e-2) - -elif args.epilogue_visitor in ["RowBroadcast", "ColumnBroadcast"]: - output_op.sync() - accum_ref = reference.run( - tensor_A, tensor_B, tensor_C, problem_size, 1.0, 0.0, args.bias, args.batch) - - tensor_D_ref, tensor_T_ref = epilogue_functor( - accum_ref.reshape((args.batch, problem_size.m(), problem_size.n())), - tensor_C.reshape((args.batch, problem_size.m(), problem_size.n())), - vector, args.alpha, args.beta) - - tensor_D_ref = tensor_D_ref.flatten() - tensor_T_ref = tensor_T_ref.flatten() - - assert np.array_equal(tensor_t, tensor_T_ref) +tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta) try: assert np.array_equal(tensor_D, tensor_D_ref) except: assert np.allclose(tensor_D, tensor_D_ref, atol=1e-5) + print("Passed.") diff --git a/examples/40_cutlass_py/gemm_grouped.py b/examples/40_cutlass_py/gemm_grouped.py index 46ea9fed..f62d8009 100644 --- a/examples/40_cutlass_py/gemm_grouped.py +++ b/examples/40_cutlass_py/gemm_grouped.py @@ -29,253 +29,125 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
# ################################################################################ -import pycutlass -from pycutlass import * -import csv +""" +Basic example of using the CUTLASS Python interface to run a grouped GEMM +""" import argparse +import numpy as np +import sys -# parse the arguments -parser = argparse.ArgumentParser( - description="Launch CUTLASS GEMM Grouped kernels from python") +import cutlass +import pycutlass +from pycutlass import * +import util -# Operation description -# math instruction description -parser.add_argument("-i", "--instruction_shape", - default=[1, 1, 1], nargs=3, type=int, - help="This option describes the size of MMA op") -parser.add_argument("-ta", "--element_a", default="float32", type=str, - choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], - help='Data type of elements in input tensor A') -parser.add_argument("-tb", "--element_b", default="float32", type=str, - choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], - help='Data type of elements in input tensor B') -parser.add_argument("-tc", "--element_c", default="float32", type=str, - choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], - help='Data type of elements in input tensor C and output tensor D') -parser.add_argument("-tacc", "--element_acc", default="float32", type=str, - choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], - help='Data type of accumulator') -parser.add_argument('-m', "--math", default="multiply_add", - type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") -parser.add_argument('-op', "--opcode", default="simt", type=str, - choices=["Simt", 'TensorOp'], help='This option describes whether you want to use tensor \ - cores (TensorOp) or regular SIMT cores (Simt) on GPU SM') -# tile description -parser.add_argument("-b", "--threadblock_shape", - default=[128, 128, 8], nargs=3, type=int, - help="This option describes the tile size a thread block with compute") -parser.add_argument("-s", "--stages", default=4, - type=int, help="Number of pipelines you want to use") -parser.add_argument("-w", "--warp_count", default=[ - 4, 2, 1], nargs=3, type=int, - help="This option describes the number of warps along M, N, and K of the threadblock") -parser.add_argument("-cc", "--compute_capability", default=80, - type=int, help="This option describes CUDA SM architecture number") -# A -parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[ - "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], - help="Memory layout of input tensor A") -parser.add_argument('-aa', '--alignment_a', default=1, - type=int, help="Memory alignment of input tensor A") -# B -parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[ - "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], - help="Memory layout of input tensor B") -parser.add_argument('-ab', '--alignment_b', default=1, - type=int, help="Memory alignment of input tensor B") -# C -parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[ - "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], - help="Memory layout of input tensor C and output tensor D") -parser.add_argument('-ac', '--alignment_c', default=1, - type=int, help="Memory alignment of input tensor C and output tensor D") -# epilogue -parser.add_argument("-te", "--element_epilogue", default="float32", type=str, - 
choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype') -parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", - type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], - help="This option describes the epilogue part of the kernel") -# swizzling -parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ - "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"], - help="This option describes how thread blocks are scheduled on GPU. \ - NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. \ - This parameter is passed in at present to match the APIs of other kernels. The parameter \ - is unused within the kernel") -# precompute mode -parser.add_argument("-pm", "--precompute_mode", - default="Device", type=str, choices=["Host", "Device"], - help="Grouped Gemm Scheduing on device only (Device) or using host precompute (Host)") -# arguments -parser.add_argument("-p", "--problem_size_dir", type=str, - help="path to the csv file contains the problem sizes") -parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha") -parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta") -parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") -# Activation function -parser.add_argument("-activ", "--activation_function", default="identity", - choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") -parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, - help="addition arguments for activation") -parser.add_argument('--print_cuda', action="store_true", - help="print the underlying CUDA kernel") +parser = argparse.ArgumentParser(description="Launch a grouped GEMM kernel from Python") +parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel") try: args = parser.parse_args() except: sys.exit(0) -pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) +# Check that the device is of a sufficient compute capability +cc = util.get_device_cc() +assert cc >= 70, "The CUTLASS Python grouped GEMM example requires compute capability greater than or equal to 70." 
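Here util.get_device_cc() comes from the new util.py added later in this patch: it concatenates the device's major and minor SM version (effectively major * 10 + minor for single-digit minor revisions, so SM 8.0 reports 80), and the assert above simply gates the example on compute capability 70 or newer. A small illustrative sketch of working with that encoding, with the print purely for demonstration:

import util

cc = util.get_device_cc()
major, minor = divmod(cc, 10)          # e.g. 80 -> (8, 0)
print(f"Found SM {major}.{minor} (cc = {cc})")
assert cc >= 70, "This example requires compute capability 70 or newer."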
np.random.seed(0) -element_a = getattr(cutlass, args.element_a) -element_b = getattr(cutlass, args.element_b) -element_c = getattr(cutlass, args.element_c) -element_acc = getattr(cutlass, args.element_acc) -math_operation = getattr(MathOperation, args.math) -opclass = getattr(cutlass.OpClass, args.opcode) +# Allocate a pool of device memory to be used by the kernel +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +# Set the compiler to use to NVCC +pycutlass.compiler.nvcc() + +# Set up A, B, C and accumulator +alignment = 1 +A = TensorDescription(cutlass.float16, cutlass.ColumnMajor, alignment) +B = TensorDescription(cutlass.float16, cutlass.RowMajor, alignment) +C = TensorDescription(cutlass.float32, cutlass.ColumnMajor, alignment) +element_acc = cutlass.float32 +element_epilogue = cutlass.float32 math_inst = MathInstruction( - args.instruction_shape, element_a, element_b, - element_acc, opclass, math_operation + [16, 8, 8], # Shape of the Tensor Core instruction + A.element, B.element, element_acc, + cutlass.OpClass.TensorOp, + MathOperation.multiply_add ) tile_description = TileDescription( - args.threadblock_shape, args.stages, args.warp_count, + [128, 128, 32], # Threadblock shape + 2, # Number of stages + [2, 2, 1], # Number of warps within each dimension of the threadblock shape math_inst ) -layout_a = getattr(cutlass, args.layout_a) -layout_b = getattr(cutlass, args.layout_b) -layout_c = getattr(cutlass, args.layout_c) - -A = TensorDescription( - element_a, layout_a, args.alignment_a -) - -B = TensorDescription( - element_b, layout_b, args.alignment_b -) - -C = TensorDescription( - element_c, layout_c, args.alignment_c -) - -element_epilogue = getattr(cutlass, args.element_epilogue) -if args.activation_function == "identity": - epilogue_functor = getattr(pycutlass, args.epilogue_functor)( - C.element, C.alignment, math_inst.element_accumulator, element_epilogue) -else: - epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( - getattr(pycutlass, args.activation_function)(element_epilogue), - C.element, C.alignment, math_inst.element_accumulator, element_epilogue) -swizzling_functor = getattr(cutlass, args.swizzling_functor) -precompute_mode = getattr(SchedulerMode, args.precompute_mode) +epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue) operation = GemmOperationGrouped( - arch=args.compute_capability, tile_description=tile_description, + arch=cc, tile_description=tile_description, A=A, B=B, C=C, - epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor, - precompute_mode=precompute_mode -) + epilogue_functor=epilogue_functor, + precompute_mode=SchedulerMode.Device) if args.print_cuda: print(operation.rt_module.emit()) -pycutlass.compiler.add_module([operation, ]) +operations = [operation, ] -reference_module = ReferenceModule(A, B, C) - -# get problems -problem_sizes = [] -with open(args.problem_size_dir) as csv_file: - reader = csv.reader(csv_file) - for row in reader: - problem_sizes.append( - cutlass.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2])) - ) +# Compile the operation +pycutlass.compiler.add_module(operations) +# Initialize tensors for each problem in the group +problem_sizes = [ + cutlass.gemm.GemmCoord(128, 128, 64), + cutlass.gemm.GemmCoord(512, 256, 128) +] problem_count = len(problem_sizes) +alpha = 1. +beta = 0. 
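The hard-coded tile above pairs a 128x128x32 threadblock tile with a [2, 2, 1] warp layout, so each of the four warps owns a 64x64x32 sub-tile, itself an integer multiple of the 16x8x8 Tensor Core instruction declared earlier, and each problem in the group is covered by ceil(M/128) * ceil(N/128) threadblock tiles. The sketch below (a hypothetical helper, not a pycutlass API) just spells out that arithmetic for the two problem sizes used in this example:

def tiles_for(problem_mnk, tb_shape=(128, 128, 32), warp_count=(2, 2, 1)):
    m, n, k = problem_mnk
    tb_m, tb_n, tb_k = tb_shape
    # Per-warp tile implied by the threadblock shape and warp arrangement.
    warp_tile = (tb_m // warp_count[0], tb_n // warp_count[1], tb_k // warp_count[2])
    # Threadblock tiles needed to cover the M x N output (ceiling division).
    num_tiles = (-(-m // tb_m)) * (-(-n // tb_n))
    return warp_tile, num_tiles

print(tiles_for((128, 128, 64)))   # ((64, 64, 32), 1)
print(tiles_for((512, 256, 128)))  # ((64, 64, 32), 8)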
+ tensor_As = [] tensor_Bs = [] tensor_Cs = [] tensor_Ds = [] -problem_sizes_coord = [] tensor_D_refs = [] +reference = ReferenceModule(A, B, C) + for problem_size in problem_sizes: - if args.element_a != "int8": - if args.element_a == "bfloat16": - tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() - * problem_size.k(),))).astype(bfloat16) - else: - tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() - * problem_size.k(),))).astype(getattr(np, args.element_a)) - else: - tensor_A = np.random.uniform(low=-2, high=2, size=(problem_size.m() - * problem_size.k(),)).astype(getattr(np, args.element_a)) - - if args.element_b != "int8": - if args.element_b == "bfloat16": - tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k() - * problem_size.n(),))).astype(bfloat16) - else: - tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k() - * problem_size.n(),))).astype(getattr(np, args.element_b)) - else: - tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k() - * problem_size.n(),)).astype(getattr(np, args.element_b)) - - if args.element_c != "int8": - if args.bias: - if args.layout_c == "RowMajor": - c_size = problem_size.n() - elif args.layout_c == "ColumnMajor": - c_size = problem_size.m() - else: - raise ValueError(args.layout_c) - else: - c_size = problem_size.m() * problem_size.n() - if args.element_c == "bfloat16": - tensor_C = np.ceil( - np.random.uniform(low=-8.5, high=7.5, size=(c_size,)) - ).astype(bfloat16) - else: - tensor_C = np.ceil( - np.random.uniform(low=-8.5, high=7.5, size=(c_size,)) - ).astype(getattr(np, args.element_c)) - else: - tensor_C = np.random.uniform( - low=-2, high=2, size=(problem_size.m() * problem_size.n(),) - ).astype(getattr(np, args.element_c)) - tensor_D = np.zeros( - shape=(problem_size.m() * problem_size.n(),) - ).astype(getattr(np, args.element_c)) + # Randomly initialize tensors + m = problem_size.m() + n = problem_size.n() + k = problem_size.k() + tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(m * k,))).astype(np.float16) + tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(k * n,))).astype(np.float16) + tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(m * n,))).astype(np.float32) + tensor_D = np.zeros(shape=(m * n,)).astype(np.float32) tensor_As.append(tensor_A) tensor_Bs.append(tensor_B) tensor_Cs.append(tensor_C) tensor_Ds.append(tensor_D) - tensor_D_ref = reference_module.run( - tensor_A, tensor_B, tensor_C, problem_size, - args.alpha, args.beta, args.bias) - tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args)) + + # Run the reference GEMM + tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta) tensor_D_refs.append(tensor_D_ref) - problem_sizes_coord.append(problem_size) arguments = GemmGroupedArguments( - operation, problem_sizes_coord, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds, - output_op=operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)) + operation, problem_sizes, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds, + output_op=operation.epilogue_type(alpha, beta) ) +# Run the operation operation.run(arguments) - arguments.sync() +# Compare the CUTLASS result to the host reference result for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs): try: assert np.array_equal(tensor_d, tensor_d_ref) diff --git a/examples/40_cutlass_py/util.py b/examples/40_cutlass_py/util.py new file mode 
100644 index 00000000..d37bd045 --- /dev/null +++ b/examples/40_cutlass_py/util.py @@ -0,0 +1,60 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utility functions for interacting with device +""" + +from cuda import cudart + + +# Raises an exception if `result` returned an error. Otherwise returns the result. +def check_cuda_errors(result: list): + # `result` is of the format : (cudaError_t, result...) + err = result[0] + if err.value: + raise RuntimeError("CUDA error: {}".format(cudart.cudaGetErrorName(err))) + + if len(result) == 1: + return None + elif len(result) == 2: + return result[1] + else: + return result[1:] + + +# Returns the integer representation of the device compute capability +def get_device_cc(device: int = 0): + deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device)) + major = str(deviceProp.major) + minor = str(deviceProp.minor) + return int(major + minor) diff --git a/examples/41_fused_multi_head_attention/CMakeLists.txt b/examples/41_fused_multi_head_attention/CMakeLists.txt new file mode 100644 index 00000000..f6995579 --- /dev/null +++ b/examples/41_fused_multi_head_attention/CMakeLists.txt @@ -0,0 +1,44 @@ + +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 41_fused_multi_head_attention_fixed_seqlen + fused_multihead_attention_fixed_seqlen.cu + ) + +cutlass_example_add_executable( + 41_fused_multi_head_attention_variable_seqlen + fused_multihead_attention_variable_seqlen.cu + ) + + +add_custom_target(41_fused_multi_head_attention +DEPENDS 41_fused_multi_head_attention_fixed_seqlen + 41_fused_multi_head_attention_variable_seqlen +) diff --git a/examples/42_fused_multi_head_attention/attention_scaling_coefs_updater.h b/examples/41_fused_multi_head_attention/attention_scaling_coefs_updater.h similarity index 90% rename from examples/42_fused_multi_head_attention/attention_scaling_coefs_updater.h rename to examples/41_fused_multi_head_attention/attention_scaling_coefs_updater.h index 9265b52b..0af53720 100644 --- a/examples/42_fused_multi_head_attention/attention_scaling_coefs_updater.h +++ b/examples/41_fused_multi_head_attention/attention_scaling_coefs_updater.h @@ -1,3 +1,34 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + #pragma once #include "cutlass/functional.h" diff --git a/examples/42_fused_multi_head_attention/debug_utils.h b/examples/41_fused_multi_head_attention/debug_utils.h similarity index 77% rename from examples/42_fused_multi_head_attention/debug_utils.h rename to examples/41_fused_multi_head_attention/debug_utils.h index 8e482661..4d8bc7d4 100644 --- a/examples/42_fused_multi_head_attention/debug_utils.h +++ b/examples/41_fused_multi_head_attention/debug_utils.h @@ -1,3 +1,34 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + #pragma once #include #include diff --git a/examples/41_fused_multi_head_attention/default_fmha_grouped.h b/examples/41_fused_multi_head_attention/default_fmha_grouped.h new file mode 100644 index 00000000..6931faf7 --- /dev/null +++ b/examples/41_fused_multi_head_attention/default_fmha_grouped.h @@ -0,0 +1,284 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "fmha_grouped.h" +#include "gemm_kernel_utils.h" +#include "find_default_mma.h" +#include "attention_scaling_coefs_updater.h" +#include "mma_from_smem.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The datatype of Q/K/V + typename scalar_t_, + // Architecture we are targeting (eg `cutlass::arch::Sm80`) + typename ArchTag_, + // If Q/K/V are correctly aligned in memory and we can run a fast kernel + bool isAligned_, + int kQueriesPerBlock, + int kKeysPerBlock, + bool kSingleValueIteration, + GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly + > +struct DefaultFMHAGrouped { + using scalar_t = scalar_t_; + using accum_t = float; + using output_t = scalar_t; + + // Accumulator between 2 iterations + // Using `accum_t` improves perf on f16 at the cost of + // numerical errors + using output_accum_t = accum_t; + + using ArchTag = ArchTag_; + static bool const kIsAligned = isAligned_; + static int const kWarpSize = 32; + static int const kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize); + + struct MM0 { + /* + In this first matmul, we compute a block of `Q @ K.T`. + While the calculation result is still hot in registers, we update + `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value + into a shared-memory ("AccumulatorSharedStorage") that is used later as + operand A for the second matmul (see MM1) + */ + + using GemmType = gemm_kernel_utils::DefaultGemmType; + using OpClass = typename GemmType::OpClass; + + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = scalar_t; + using ElementAccumulator = accum_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + ElementA, + ElementB, + ElementC, + ElementAccumulator + >; + + static int const kAlignmentA = + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; + static int const kAlignmentB = + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + + using ThreadblockShape = cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + static int const kStages = DefaultConfig::kStages; + using Operator = typename GemmType::Operator; + + using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages, + Operator + >::DefaultMma; + + using MmaCore = typename DefaultMma::MmaCore; + using IteratorA = typename DefaultMma::IteratorA; + using IteratorB = typename DefaultMma::IteratorB; + using Mma = typename DefaultMma::ThreadblockMma; + using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater< + typename Mma::Operator::IteratorC, + ElementAccumulator, + kWarpSize>::Updater; + + static_assert(MmaCore::WarpCount::kCount == kNumWarpsPerBlock, ""); + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MM1 { + /* + Second matmul: perform `attn @ V` where `attn` is the attention (not + normalized) and stored in shared memory + */ + + using GemmType = typename MM0::GemmType; + using OpClass = typename GemmType::OpClass; + + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = output_accum_t; + using ElementAccumulator = accum_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + ElementA, + ElementB, + ElementC, + ElementAccumulator + >; + + static int const kAlignmentA = DefaultConfig::kAlignmentA; + static int const kAlignmentB = + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + + using ThreadblockShape = typename MM0::ThreadblockShape; + using WarpShape = typename MM0::WarpShape; + using InstructionShape = typename MM0::InstructionShape; + + using EpilogueOutputOp = typename DefaultConfig::EpilogueOutputOp; + + static int const kStages = DefaultConfig::kStages; + using Operator = typename GemmType::Operator; + + using ThreadblockSwizzle = void; // Swizzling is unused + static bool const kSplitKSerial = false; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator>; + + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + typename MM0::AccumulatorSharedStorage>; + + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + static_assert(WarpCount::kCount == kNumWarpsPerBlock, ""); + + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_t>; + using OutputTileIteratorAccum = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_accum_t>; + + struct SharedStorageMM1 { + typename Mma::SharedStorage mm; + }; + }; + +/// Define the kernel in terms of the default kernel + using FMHAKernel = kernel::FMHAGrouped< + MM0, + MM1, + scalar_t, + accum_t, + output_t, + output_accum_t, + kSingleValueIteration, + GroupScheduleMode_ + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/42_fused_multi_head_attention/epilogue_pipelined.h b/examples/41_fused_multi_head_attention/epilogue_pipelined.h similarity index 100% rename from examples/42_fused_multi_head_attention/epilogue_pipelined.h rename to examples/41_fused_multi_head_attention/epilogue_pipelined.h diff --git a/examples/42_fused_multi_head_attention/epilogue_rescale_output.h b/examples/41_fused_multi_head_attention/epilogue_rescale_output.h similarity index 80% rename from examples/42_fused_multi_head_attention/epilogue_rescale_output.h rename to examples/41_fused_multi_head_attention/epilogue_rescale_output.h index 4a6b771e..30b8427b 100644 --- a/examples/42_fused_multi_head_attention/epilogue_rescale_output.h +++ b/examples/41_fused_multi_head_attention/epilogue_rescale_output.h @@ -1,3 +1,34 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + /*! \file \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. diff --git a/examples/42_fused_multi_head_attention/epilogue_thread_apply_logsumexp.h b/examples/41_fused_multi_head_attention/epilogue_thread_apply_logsumexp.h similarity index 100% rename from examples/42_fused_multi_head_attention/epilogue_thread_apply_logsumexp.h rename to examples/41_fused_multi_head_attention/epilogue_thread_apply_logsumexp.h diff --git a/examples/42_fused_multi_head_attention/find_default_mma.h b/examples/41_fused_multi_head_attention/find_default_mma.h similarity index 72% rename from examples/42_fused_multi_head_attention/find_default_mma.h rename to examples/41_fused_multi_head_attention/find_default_mma.h index 9cd64d6d..a20f615b 100644 --- a/examples/42_fused_multi_head_attention/find_default_mma.h +++ b/examples/41_fused_multi_head_attention/find_default_mma.h @@ -1,8 +1,39 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + /*! \file \brief Cutlass provides helper template functions to figure out the right datastructures to instanciate to run a GEMM with various parameters (see `cutlass/gemm/threadblock/default_mma.h`). However, due to template - instanciation priority rules, it will only create an MmaMultiStage with + instantiation priority rules, it will only create an MmaMultiStage with kStages=3 (otherwise creates an MmePipelined - which is not compatible with FastF32). kStages=3 uses too much shared memory and we want to use kStages=2, so we just copy-pasted some code from `default_mma.h` and diff --git a/examples/41_fused_multi_head_attention/fmha_grouped.h b/examples/41_fused_multi_head_attention/fmha_grouped.h new file mode 100644 index 00000000..2ad7f14d --- /dev/null +++ b/examples/41_fused_multi_head_attention/fmha_grouped.h @@ -0,0 +1,839 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Grouped FMHA kernel +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" + +#include "fmha_grouped_problem_visitor.h" +#include "gemm_kernel_utils.h" +#include "epilogue_rescale_output.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename MM0_, ///! Structure for computing P = Q @ K + typename MM1_, ///! Structure for computing O = P @ V + typename scalar_t_, + typename accum_t_, + typename output_t_, + typename output_accum_t_, + bool kKeepOutputInRF, ///! Whether the intermediate output from MM0_ should be kept in the register file + GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to perform +> +struct FMHAGrouped { +public: + using MM0 = MM0_; + using MM1 = MM1_; + + using scalar_t = scalar_t_; + using accum_t = accum_t_; + using output_t = output_t_; + using output_accum_t = output_accum_t_; + + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + + static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && + !cutlass::platform::is_same::value; + + // Parameters to satisfy BaseGrouped + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = accum_t; + using LayoutA = typename MM0::LayoutA; + using LayoutB = typename MM0::ElementB; + using LayoutC = typename MM1::ElementC; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + static int const kAlignmentA = MM0::kAlignmentA; + static int const kAlignmentB = MM0::kAlignmentB; + static int const kAlignmentC = 1; + using Mma = typename MM1::Mma; + using EpilogueOutputOp = typename MM1::EpilogueOutputOp; + using ThreadblockSwizzle = void; + using Operator = typename MM1::Operator; + using WarpShape = typename MM1::WarpShape; + using InstructionShape = typename MM1::InstructionShape; + + using ElementQ = scalar_t; + using ElementK = scalar_t; + using ElementP = accum_t; + using ElementV = scalar_t; + using ElementO = output_t; + using ElementOAccum = output_accum_t; + using ElementAccumulator = accum_t; + + using LayoutQ = typename MM0::LayoutA; + using LayoutK = typename MM0::LayoutB; + using LayoutP = typename MM0::LayoutC; + using LayoutV = typename MM1::LayoutB; + using LayoutO = typename MM1::LayoutC; + + static bool const kPreloadV = (MM1::Mma::ArchTag::kMinComputeCapability >= 80 && + cutlass::sizeof_bits::value == 16); + + static int const kAlignmentQ = MM0::kAlignmentA; + static int const kAlignmentK = MM0::kAlignmentB; + static int const kAlignmentV = 1; + + using ThreadblockShape = typename MM0::ThreadblockShape; + + static int const kQueriesPerBlock = ThreadblockShape::kM; + static int const kKeysPerBlock = ThreadblockShape::kN; + + /// Warp count (concept: GemmShape) + using WarpCount = typename MM1::WarpCount; + static int const kThreadsPerWarp = 32; + static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount; + + using ProblemVisitor = FMHAGroupedProblemVisitor< + ThreadblockShape, + kGroupScheduleMode, + kThreadCount, + kThreadCount>; 
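Rather than materializing the full attention matrix, FMHAGrouped's operator() (later in this file) walks the key dimension kKeysPerBlock entries at a time, keeping only a per-query running maximum (mi / m_prime) and running sum of exponentials (s_prime) and rescaling the partial output as new key blocks arrive; logits are scaled by 1/sqrt(head_dim), matching the 1 / sqrt(problem_size0.k()) factor passed to the updater below. The NumPy sketch that follows illustrates this update rule in isolation; the names and block size are arbitrary and unrelated to the template parameters here, and it checks the result against an ordinary softmax.

import numpy as np

def streaming_attention(Q, K, V, keys_per_block=64):
    scale = 1.0 / np.sqrt(Q.shape[-1])
    m = np.full(Q.shape[0], -np.inf)           # running row max (mi / m_prime)
    s = np.zeros(Q.shape[0])                   # running row sum of exp (s_prime)
    acc = np.zeros((Q.shape[0], V.shape[1]))   # unnormalized output accumulator
    for start in range(0, K.shape[0], keys_per_block):
        Kb = K[start:start + keys_per_block]
        Vb = V[start:start + keys_per_block]
        logits = (Q @ Kb.T) * scale
        m_new = np.maximum(m, logits.max(axis=1))
        correction = np.exp(m - m_new)          # rescale what was accumulated so far
        p = np.exp(logits - m_new[:, None])
        s = s * correction + p.sum(axis=1)
        acc = acc * correction[:, None] + p @ Vb
        m = m_new
    return acc / s[:, None]                     # final rescale, as in the epilogue

Q, K, V = np.random.randn(16, 32), np.random.randn(200, 32), np.random.randn(200, 32)
logits = (Q @ K.T) / np.sqrt(32)
p = np.exp(logits - logits.max(axis=1, keepdims=True))
assert np.allclose(streaming_attention(Q, K, V), (p / p.sum(axis=1, keepdims=True)) @ V)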
+ + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord *problem_sizes0; + GemmCoord *problem_sizes1; + + int problem_count; + int threadblock_count; + + ElementQ ** ptr_Q; + ElementK ** ptr_K; + ElementP ** ptr_P; + ElementV ** ptr_V; + ElementO ** ptr_O; + ElementOAccum ** ptr_O_accum; + + typename LayoutQ::Stride::LongIndex *ldq; + typename LayoutK::Stride::LongIndex *ldk; + typename LayoutP::Stride::LongIndex *ldv; + typename LayoutO::Stride::LongIndex *ldo; + + // Whether causal masking is to be performed + bool causal; + + // Only used by device-level operator + GemmCoord *host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): + problem_count(0), + threadblock_count(0), + ptr_Q(nullptr), + ptr_K(nullptr), + ptr_P(nullptr), + ptr_V(nullptr), + ptr_O(nullptr), + ptr_O_accum(nullptr), + ldq(nullptr), + ldk(nullptr), + ldv(nullptr), + ldo(nullptr), + causal(false), + host_problem_sizes(nullptr) + { + + } + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord *problem_sizes0, + GemmCoord *problem_sizes1, + int problem_count, + int threadblock_count, + ElementQ ** ptr_Q, + ElementK ** ptr_K, + ElementP ** ptr_P, + ElementV ** ptr_V, + ElementO ** ptr_O, + ElementOAccum ** ptr_O_accum, + typename LayoutQ::Stride::LongIndex *ldq, + typename LayoutK::Stride::LongIndex *ldk, + typename LayoutP::Stride::LongIndex *ldp, + typename LayoutV::Stride::LongIndex *ldv, + typename LayoutO::Stride::LongIndex *ldo, + bool causal, + GemmCoord *host_problem_sizes=nullptr + ): + problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + problem_count(problem_count), + threadblock_count(threadblock_count), + ptr_Q(ptr_Q), + ptr_K(ptr_K), + ptr_P(ptr_P), + ptr_V(ptr_V), + ptr_O(ptr_O), + ptr_O_accum(kNeedsOutputAccumulatorBuffer ? 
ptr_O_accum : (accum_t**)ptr_O), + ldq(ldq), + ldk(ldk), + ldv(ldv), + ldo(ldo), + causal(causal), + host_problem_sizes(host_problem_sizes) + { + + } + + bool __host__ check_supported() { + CHECK_ALIGNED_PTR(ptr_Q, kAlignmentQ); + CHECK_ALIGNED_PTR(ptr_K, kAlignmentK); + CHECK_ALIGNED_PTR(ptr_V, kAlignmentV); + XFORMERS_CHECK(ldq % kAlignmentQ == 0, "query is not correctly aligned"); + XFORMERS_CHECK(ldk % kAlignmentK == 0, "key is not correctly aligned"); + XFORMERS_CHECK(ldv % kAlignmentV == 0, "value is not correctly aligned"); + return true; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + ElementQ ** ptr_Q; + ElementK ** ptr_K; + ElementP ** ptr_P; + ElementV ** ptr_V; + ElementO ** ptr_O; + ElementOAccum ** ptr_O_accum; + + typename LayoutQ::Stride::LongIndex *ldq; + typename LayoutK::Stride::LongIndex *ldk; + typename LayoutP::Stride::LongIndex *ldv; + typename LayoutO::Stride::LongIndex *ldo; + + bool causal; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + ptr_Q(nullptr), + ptr_K(nullptr), + ptr_P(nullptr), + ptr_V(nullptr), + ptr_O(nullptr), + ptr_O_accum(nullptr), + ldq(nullptr), + ldk(nullptr), + ldv(nullptr), + ldo(nullptr), + causal(false) + { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args, + void *workspace = nullptr, + int tile_count = 0): + problem_visitor(args.problem_sizes0, args.problem_sizes1, args.problem_count, workspace, tile_count), + threadblock_count(args.threadblock_count), + ptr_Q(args.ptr_Q), + ptr_K(args.ptr_K), + ptr_P(args.ptr_P), + ptr_V(args.ptr_V), + ptr_O(args.ptr_O), + ptr_O_accum(kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O), + ldq(args.ldq), + ldk(args.ldk), + ldv(args.ldv), + ldo(args.ldo), + causal(args.causal) + { + + } + + CUTLASS_HOST_DEVICE + void update( + Arguments const &args, + void *workspace = nullptr, + int tile_count = 0) { + + problem_visitor = typename ProblemVisitor::Params(args.problem_sizes0, + args.problem_sizes1, + args.problem_count, + workspace, tile_count); + threadblock_count = args.threadblock_count; + ptr_Q = args.ptr_Q; + ptr_K = args.ptr_K; + ptr_P = args.ptr_P; + ptr_V = args.ptr_V; + ptr_O = args.ptr_O; + ptr_O_accum = kNeedsOutputAccumulatorBuffer ? 
args.ptr_O_accum : (accum_t**)args.ptr_O; + ldq = args.ldq; + ldk = args.ldk; + ldv = args.ldv; + ldo = args.ldo; + causal = args.causal; + } + }; + + // Shared storage - depends on kernel params + struct ScalingCoefs { + cutlass::Array m_prime; + cutlass::Array s_prime; + cutlass::Array mi; + }; + + struct SharedStorageEpilogueAtEnd : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + typename MM0::AccumulatorSharedStorage si; + typename MM1::SharedStorageMM1 mm1; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return epilogue; + } + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + struct SharedStorageEpilogueInLoop : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + typename MM0::AccumulatorSharedStorage si; + typename MM1::SharedStorageMM1 mm1; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return after_mm0.epilogue; + } + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + using SharedStorage = typename cutlass::platform::conditional< + kKeepOutputInRF, + SharedStorageEpilogueAtEnd, + SharedStorageEpilogueInLoop>::type; + +private: + + // Parameters to be used by an individual tile + struct TileParams { + + CUTLASS_HOST_DEVICE + static int query_start(int threadblock_idx) { + return threadblock_idx * kQueriesPerBlock; + } + + // Returns whether this threadblock computes within the number of queries, + // which is determined by the M dimension of problem 0 + CUTLASS_HOST_DEVICE + static bool can_compute(int threadblock_idx, const GemmCoord& problem_size0) { + return query_start(threadblock_idx) < problem_size0.m(); + } + + CUTLASS_HOST_DEVICE + static int num_queries(int threadblock_idx, const GemmCoord& problem_size0) { + return problem_size0.m() - query_start(threadblock_idx); + } + + CUTLASS_HOST_DEVICE + static int num_keys(int threadblock_idx, const GemmCoord& problem_size0, bool causal) { + int nk = problem_size0.n(); + if (causal) { + nk = cutlass::fast_min(int32_t(query_start(threadblock_idx) + kQueriesPerBlock), nk); + } + return nk; + } + + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + FMHAGrouped() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return Status::kSuccess; + } + + static CUTLASS_DEVICE int16_t thread_id() { + return threadIdx.x; + } + + static CUTLASS_DEVICE int8_t warp_id() { + return threadIdx.x / kThreadsPerWarp; + } + + static CUTLASS_DEVICE int8_t lane_id() { + return threadIdx.x % kThreadsPerWarp; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + auto& m_prime = shared_storage.m_prime; + auto& s_prime = shared_storage.s_prime; + auto& si = shared_storage.after_mm0.si; + auto& mi = shared_storage.mi; + + ProblemVisitor problem_visitor( 
+ params.problem_visitor, + shared_storage.problem_visitor, + blockIdx.x); + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + + GemmCoord problem_size0 = problem_visitor.problem_size0(); + GemmCoord problem_size1 = problem_visitor.problem_size1(); + const int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + if (!TileParams::can_compute(threadblock_idx, problem_size0)) { + problem_visitor.advance(gridDim.x); + continue; + } + + const int32_t problem_idx = problem_visitor.problem_index(); + + if (thread_id() < kQueriesPerBlock) { + s_prime[thread_id()] = ElementAccumulator(0); + m_prime[thread_id()] = + -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + } + + ElementO *ptr_O = params.ptr_O[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx]; + ElementOAccum *ptr_O_accum = params.ptr_O_accum[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx]; + const int num_queries = TileParams::num_queries(threadblock_idx, problem_size0); + + auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { + using OutputTileIterator = typename MM1::OutputTileIterator; + return OutputTileIterator( + typename OutputTileIterator::Params{(int32_t)params.ldo[problem_idx]}, + ptr_O, + typename OutputTileIterator::TensorCoord{ + num_queries, problem_size1.n()}, + thread_id(), + {0, col}); + }; + + auto createOutputAccumIter = [&](int col) -> + typename MM1::OutputTileIteratorAccum { + using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; + return OutputTileIteratorAccum( + typename OutputTileIteratorAccum::Params{(int32_t)params.ldo[problem_idx]}, + ptr_O_accum, + typename OutputTileIteratorAccum::TensorCoord{ + num_queries, problem_size1.n()}, + thread_id(), + {0, col}); + }; + + typename MM1::Mma::FragmentC accum_o; + accum_o.clear(); + + const int num_keys = TileParams::num_keys(threadblock_idx, problem_size0, params.causal); + + for (int32_t iter_key_start = 0; iter_key_start < num_keys; + iter_key_start += kKeysPerBlock) { + int32_t problem_size_0_m = + cutlass::fast_min((int32_t)kQueriesPerBlock, num_queries); + int32_t problem_size_0_n = cutlass::fast_min( + (int32_t)kKeysPerBlock, num_keys - iter_key_start); + int32_t const& problem_size_0_k = problem_size0.k(); + int32_t const& problem_size_1_n = problem_size1.n(); + int32_t const& problem_size_1_k = problem_size_0_n; + + auto prologueV = [&](int blockN) { + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])}, + params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx], + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + + MM1::Mma::prologue( + shared_storage.after_mm0.mm1.mm, + iterator_V, + thread_id(), + problem_size_1_k); + }; + + __syncthreads(); // Need to have shared memory initialized, and `m_prime` + // updated from end of prev iter + + // + // MATMUL: Q.K_t + // + // Computes the block-matrix product of: + // (a) query[query_start:query_end, :] + // with + // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] + // and stores that into `shared_storage.si` + // + + ElementQ *ptr_Q = params.ptr_Q[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldq[problem_idx]; + + // Construct iterators to A and B operands + typename MM0::IteratorA iterator_A( + typename MM0::IteratorA::Params( 
+ typename MM0::MmaCore::LayoutA(params.ldq[problem_idx])), + ptr_Q, + {problem_size_0_m, problem_size_0_k}, + thread_id(), + {0, 0}); + + typename MM0::IteratorB iterator_B( + typename MM0::IteratorB::Params( + typename MM0::MmaCore::LayoutB(params.ldk[problem_idx])), + params.ptr_K[problem_idx] + iter_key_start * params.ldk[problem_idx], + {problem_size_0_k, problem_size_0_n}, + thread_id(), + {0, 0}); + + // Construct thread-scoped matrix multiply + typename MM0::Mma mma( + shared_storage.mm0, thread_id(), warp_id(), lane_id()); + + typename MM0::Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + + if (kPreloadV) { + prologueV(0); + } + + typename MM0::Mma::Operator::IteratorC::TensorCoord + iteratorC_tile_offset = { + (warp_id() % MM0::Mma::WarpCount::kM), + (warp_id() / MM0::Mma::WarpCount::kM) + }; + + // Mask out last if causal + if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) { + auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( + lane_id(), warp_id(), iteratorC_tile_offset); + int32_t last_col; + MM0::ScalingCoefsUpdater::iterateRows( + lane_offset, + [&](int accum_m) { + last_col = TileParams::query_start(threadblock_idx) + accum_m - iter_key_start; + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n > last_col) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + num_keys - iter_key_start >= kKeysPerBlock, + kFullColumns, + ([&] { + // Update `mi` from accum stored in registers + // Also updates `accum` with accum[i] <- + // exp(accum[i] * scale + // - mi) + MM0::ScalingCoefsUpdater::update< + kQueriesPerBlock, + kFullColumns, + kIsFirst, + kKeepOutputInRF>( + accum_o, + accum, + mi, + m_prime, + s_prime, + lane_id(), + thread_id(), + warp_id(), + num_keys - iter_key_start, + iteratorC_tile_offset, + 1.0f / cutlass::fast_sqrt(float(problem_size0.k()))); + })); + })); + + // Output results to shared-memory + int warp_idx_mn_0 = warp_id() % + (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; + + MM0::B2bGemm::accumToSmem( + shared_storage.after_mm0.si, accum, lane_id(), output_tile_coords); + + __syncthreads(); + + // + // MATMUL: Attn . V + // Run the matmul `attn @ V` for a block of attn and V. + // `attn` is read from shared memory (in `shared_storage_si`) + // `V` is read from global memory (with iterator_B) + // + + const int64_t nBlockN = kKeepOutputInRF ? 
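/*
  Illustrative sketch of the causal-mask rule applied above, written as a plain nested
  loop over a logits tile. `query_start`, `key_start` and `head_size` are stand-ins for
  the tile offsets and problem size the kernel uses; the kernel folds the
  1/sqrt(head_size) scale into the exp() update rather than scaling in place.

  ```
  #include <cmath>
  #include <limits>

  // A query at absolute position q may only attend to keys at positions <= q.
  void scale_and_mask(float* tile, int rows, int cols, int ld,
                      int query_start, int key_start, int head_size) {
    const float scale = 1.0f / std::sqrt(float(head_size));
    const float neg_inf = -std::numeric_limits<float>::infinity();
    for (int r = 0; r < rows; ++r) {
      int last_visible = (query_start + r) - key_start;  // last key column this row may see
      for (int c = 0; c < cols; ++c) {
        tile[r * ld + c] = (c > last_visible) ? neg_inf : tile[r * ld + c] * scale;
      }
    }
  }
  ```
*/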
1 + : ceil_div( + (int64_t)problem_size_1_n, + int64_t(MM1::ThreadblockShape::kN)); + + // Iterate over the N dimension of GEMM1 + for (int blockN = 0; blockN < nBlockN; ++blockN) { + int gemm_k_iterations = + (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add and store it in accum + // (in registers) + if (!kPreloadV) { + __syncthreads(); // we share shmem between mma and epilogue + } + + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])}, + params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx], + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + + typename MM1::Mma mma_pv( + shared_storage.after_mm0.mm1.mm, + shared_storage.after_mm0.si, + (int)thread_id(), + (int)warp_id(), + (int)lane_id(), + (int)problem_size_1_k); + + mma_pv.set_prologue_done(kPreloadV); + if (!kKeepOutputInRF) { + accum_o.clear(); + } + + mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); + __syncthreads(); + + if (kPreloadV && !kKeepOutputInRF && blockN + 1 < nBlockN) { + prologueV(blockN + 1); + } + + if (!kKeepOutputInRF) { + DISPATCH_BOOL( + iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + (iter_key_start + kKeysPerBlock) >= num_keys, + kIsLast, + ([&] { + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename cutlass::epilogue:: + thread::MemoryEfficientAttentionNormalize< + typename cutlass::platform::conditional< + kIsLast, + output_t, + output_accum_t>::type, + output_accum_t, + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, + output_accum_t, + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = typename cutlass::epilogue::threadblock:: + EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename cutlass::platform::conditional< + kIsLast, + typename MM1::OutputTileIterator, + typename MM1::OutputTileIteratorAccum>::type, + typename DefaultEpilogue:: + AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // Read + // iterator + >; + + int col = blockN * MM1::Mma::Shape::kN; + auto source_iter = createOutputAccumIter(col); + auto dest_iter = gemm_kernel_utils::call_conditional< + kIsLast, + decltype(createOutputIter), + decltype(createOutputAccumIter)>:: + apply(createOutputIter, createOutputAccumIter, col); + EpilogueOutputOp rescale(s_prime, m_prime); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, dest_iter, accum_o, source_iter); + })); + })); + if (!kKeepOutputInRF) { + __syncthreads(); + } + } + } + __syncthreads(); // we modify `m_prime` after + } + + if (kKeepOutputInRF) { + const bool kIsFirst = true; + const bool kIsLast = true; + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< + 
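/*
  Illustrative sketch of what the rescaling epilogue used here amounts to once every
  key block has been visited: each accumulated output row is divided by its softmax
  denominator. (For intermediate iterations the partial output written to the
  accumulator buffer is, roughly, also rescaled by exp(m_prime - mi) before being
  combined; only the final division is shown.)

  ```
  #include <vector>

  // Scalar form of the final normalization: out_row holds sum_i exp(l_i - m) * V_i,
  // s_prime_row holds sum_i exp(l_i - m) for the same query row.
  void normalize_output(std::vector<float>& out_row, float s_prime_row) {
    float inv = 1.0f / s_prime_row;
    for (float& x : out_row) x *= inv;
  }
  ```
*/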
output_t, // output + output_accum_t, // source + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, // accum + output_accum_t, // compute + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MM1::OutputTileIterator, // destination + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // source tile + >; + auto dest_iter = createOutputIter(0); + EpilogueOutputOp rescale(s_prime, m_prime); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, dest_iter, accum_o); + } + + // Next tile + problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h b/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h new file mode 100644 index 00000000..25284c2f --- /dev/null +++ b/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h @@ -0,0 +1,178 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Scheduler for grouped FMHA +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +// Helper for correctly representing problem sizes in grouped kernels +template +struct FMHAGroupedProblemSizeHelper { + + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { + // FMHA only partitions tiles across the M dimension. + return cutlass::gemm::GemmCoord( + ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), 1, 1); + } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) {} + + CUTLASS_HOST_DEVICE + static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { + return grid.m() * grid.n(); + } +}; + +} // namespace detail + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct FMHAGroupedProblemVisitor : public GroupedProblemVisitor< + detail::FMHAGroupedProblemSizeHelper, + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount> { + + using ProblemSizeHelper = detail::FMHAGroupedProblemSizeHelper; + using Base = GroupedProblemVisitor; + using BaseParams = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + + struct Params { + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + int32_t problem_count; + void const *workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params(): problem_sizes0(nullptr), problem_sizes1(nullptr), + problem_count(0), workspace(nullptr), tile_count(0) { } + + /// Ctor + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const *problem_sizes0, + cutlass::gemm::GemmCoord const *problem_sizes1, + int32_t problem_count, + void const *workspace = nullptr, + int32_t tile_count = 0 + ): + problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + problem_count(problem_count), + workspace(workspace), + tile_count(tile_count) + {} + + /// Convert the FMHA-specific parameters to those used by the base class + CUTLASS_HOST_DEVICE + BaseParams to_base() const { + return BaseParams(// Set problem_sizes as problem_sizes1 because these determine + // shape of the final output of FMHA + problem_sizes1, + problem_count, + workspace, + tile_count); + } + + }; + + // + // Methods + // + CUTLASS_DEVICE + FMHAGroupedProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base ( + params_.to_base(), + shared_storage_, block_idx), + problem_sizes0(params_.problem_sizes0), + problem_sizes1(params_.problem_sizes1) + {} + + /// Returns the problem size 0 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size0() const { + GemmCoord problem = problem_sizes0[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + /// Returns the problem size 1 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size1() const { + GemmCoord problem = 
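/*
  Illustrative host-side sketch of the tile accounting this scheduler performs. FMHA
  partitions work only along the M (query) dimension, so each problem contributes
  ceil(M / ThreadblockShape::kM) tiles, and the persistent threadblocks walk the
  flattened tile list in strides of gridDim.x (what next_tile()/advance() implement).

  ```
  #include <cstdint>
  #include <vector>

  // Total number of tiles across all grouped problems, tiling only the M dimension.
  int64_t total_tiles(const std::vector<int>& problem_m, int tile_m) {
    int64_t tiles = 0;
    for (int m : problem_m) {
      tiles += (m + tile_m - 1) / tile_m;   // ceil division, as in grid_shape() above
    }
    return tiles;
  }

  // Persistent-kernel iteration pattern, in spirit:
  //   for (int64_t tile = blockIdx.x; tile < total; tile += gridDim.x) { ... }
  ```
*/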
problem_sizes1[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/42_fused_multi_head_attention/fused_multihead_attention.cu b/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu similarity index 94% rename from examples/42_fused_multi_head_attention/fused_multihead_attention.cu rename to examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu index 45e35a80..70767393 100644 --- a/examples/42_fused_multi_head_attention/fused_multihead_attention.cu +++ b/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu @@ -77,21 +77,17 @@ Examples: # Run an attention example with default setup - $ ./examples/42_fused_multi_head_attention/42_fused_multi_head_attention + $ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_fixed_seqlen # Run an attention example with custom setup - $ ./examples/42_fused_multi_head_attention/42_fused_multi_head_attention --head_number=2 --batch_size=3 --head_size=32 --head_size_v=64 --seq_length=512 --seq_length_kv=1024 --causal=true + $ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_fixed_seqlen --head_number=2 --batch_size=3 --head_size=32 --head_size_v=64 --seq_length=512 --seq_length_kv=1024 --causal=true + Acknowledgement: Fixed-sequence-length FMHA code was upstreamed by Meta xFormers (https://github.com/facebookresearch/xformers). */ ///////////////////////////////////////////////////////////////////////////////////////////////// -#include -#include -#include #include -#include -#include #include "cutlass/cutlass.h" #include "cutlass/gemm/gemm.h" @@ -241,8 +237,8 @@ struct Options { for (int i = 0; i < batch_size; ++i) { // problems belonging to the same batch share the same seq len - int m_real = seq_length; // (rand() % seq_length); - int mkv_real = seq_length_kv; // (rand() % seq_length_kv); + int m_real = seq_length; + int mkv_real = seq_length_kv; int m = (m_real + alignment - 1) / alignment * alignment; int mkv = (mkv_real + alignment - 1) / alignment * alignment; int k0 = head_size; @@ -260,7 +256,6 @@ struct Options { problem_sizes0_real.push_back(problem0_real); problem_sizes1_real.push_back(problem1_real); } - } } } @@ -268,7 +263,7 @@ struct Options { /// Prints the usage statement. 
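/*
  For reference, the two GEMMs behind each attention problem have the following shapes
  (seq_q and seq_kv are the rounded-up sequence lengths computed above, k = head_size,
  kv = head_size_v). A small sketch of how the GemmCoord pair is formed; the helper
  names are illustrative only:

  ```
  #include "cutlass/gemm/gemm.h"   // cutlass::gemm::GemmCoord

  // GEMM0:  P (seq_q x seq_kv) = Q (seq_q x k)      . K^T (k x seq_kv)
  // GEMM1:  O (seq_q x kv)     = P (seq_q x seq_kv) . V   (seq_kv x kv)
  cutlass::gemm::GemmCoord make_problem0(int seq_q, int seq_kv, int k) {
    return cutlass::gemm::GemmCoord(seq_q, seq_kv, k);
  }
  cutlass::gemm::GemmCoord make_problem1(int seq_q, int seq_kv, int kv) {
    return cutlass::gemm::GemmCoord(seq_q, kv, seq_kv);
  }
  ```
*/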
std::ostream & print_usage(std::ostream &out) const { - out << "42_fused_multi_head_attention\n\n" + out << "41_fused_multi_head_attention_fixed_seqlen\n\n" << "Options:\n\n" << " --help If specified, displays this usage statement.\n\n" << " --head_number= Head number in multi-head attention (default: --head_number=12)\n" @@ -276,7 +271,7 @@ struct Options { << " --head_size= Head size in multi-head attention (default: --head_size=64)\n" << " --head_size_v= Head size in multi-head attention for V (default: --head_size_v=head_size)\n" << " --seq_length= Sequence length in multi-head attention for Q (default: --seq_length=1024)\n" - << " --seq_length_kv= Sequence length in multi-head attention for K/V(default: --seq_length_kv=seq_length)\n" + << " --seq_length_kv= Sequence length in multi-head attention for K/V (default: --seq_length_kv=seq_length)\n" << " --use_mask= If true, performs padding-like masking in softmax.\n" << " --iterations= Number of profiling iterations to perform.\n" << " --reference-check= If true, performs reference check.\n" @@ -342,8 +337,7 @@ public: using ElementSoftmaxCompute = typename Attention::accum_t; using LayoutQ = cutlass::layout::RowMajor; - using LayoutK = cutlass::layout::RowMajor; - using LayoutK_T = cutlass::layout::ColumnMajor; // transposed + using LayoutK = cutlass::layout::ColumnMajor; using LayoutP = cutlass::layout::RowMajor; using LayoutV = cutlass::layout::RowMajor; using LayoutO = cutlass::layout::RowMajor; @@ -516,7 +510,7 @@ private: auto problem1 = options.problem_sizes1.at(i); ldq_host.at(i) = LayoutQ::packed({problem0.m(), problem0.k()}).stride(0); - ldk_host.at(i) = LayoutK::packed({problem0.n(), problem0.k()}).stride(0); + ldk_host.at(i) = LayoutK::packed({problem0.k(), problem0.n()}).stride(0); ldp_host.at(i) = LayoutP::packed({problem0.m(), problem0.n()}).stride(0); ldv_host.at(i) = LayoutV::packed({problem1.k(), problem1.n()}).stride(0); ldo_host.at(i) = LayoutO::packed({problem1.m(), problem1.n()}).stride(0); @@ -541,7 +535,6 @@ private: total_elements_P += elements_P; total_elements_V += elements_V; total_elements_O += elements_O; - } problem_sizes_device0.reset(problem_count()); @@ -641,7 +634,7 @@ private: float abs_diff = fabs(diff); float abs_ref = fabs((float)vector_Input_Ref.at(i) + 1e-5f); float relative_diff = abs_diff / abs_ref; - if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) { + if ( (isnan(vector_Input_Ref.at(i)) || isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) { printf("[%d/%d] diff = %f, rel_diff = %f, {computed=%f, ref=%f}.\n", int(i), int(size), abs_diff, relative_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); return false; } @@ -661,7 +654,7 @@ private: cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(i); LayoutQ layout_Q(ldq_host.at(i)); - LayoutK_T layout_K(ldk_host.at(i)); + LayoutK layout_K(ldk_host.at(i)); LayoutP layout_P(ldp_host.at(i)); LayoutV layout_V(ldv_host.at(i)); LayoutO layout_O(ldo_host.at(i)); @@ -673,7 +666,7 @@ private: MatrixCoord extent_O{problem1.m(), problem1.k()}; cutlass::TensorView view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q); - cutlass::TensorView view_K(block_K.get() + offset_K.at(i), layout_K, extent_K); + cutlass::TensorView view_K(block_K.get() + offset_K.at(i), layout_K, extent_K); cutlass::TensorView view_P(block_P.get() + offset_P.at(i), layout_P, extent_P); cutlass::TensorView view_V(block_V.get() + offset_V.at(i), layout_V, extent_V); @@ -686,7 
+679,7 @@ private: // Reference GEMM cutlass::reference::device::GemmComplex< ElementQ, LayoutQ, - ElementK, LayoutK_T, + ElementK, LayoutK, ElementP, LayoutP, ElementCompute, ElementAccumulator >( @@ -988,6 +981,40 @@ public: /////////////////////////////////////////////////////////////////////////////////////////////////// +template < + int kQueriesPerBlock, + int kKeysPerBlock, + bool kSingleValueIteration +> +int run_attention(Options& options) { + using Attention = AttentionKernel< + cutlass::half_t, // scalar_t + cutlass::arch::Sm80, // ArchTag + true, // Memory is aligned + kQueriesPerBlock, + kKeysPerBlock, + kSingleValueIteration + >; + + // + // Test and profile + // + + TestbedAttention testbed(options); + + Result result = testbed.profile_grouped(); + if (!result.passed) { + std::cout << "Profiling CUTLASS attention has failed.\n"; + std::cout << "\nFailed\n"; + return -1; + } + + std::cout << "\nPassed\n"; + return 0; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + int main(int argc, char const **args) { // @@ -1041,52 +1068,25 @@ int main(int argc, char const **args) { std::cerr << "--alignment=1 is the only supported value\n"; return -2; } - using ArchTag = cutlass::arch::Sm80; - constexpr bool kIs64x64 = true; - // Set grid size - constexpr int64_t kQueriesPerBlock = kIs64x64 ? 64 : 32; - constexpr int64_t kKeysPerBlock = kIs64x64 ? 64 : 128; - if (kIs64x64 && options.head_size_v > kKeysPerBlock) { - std::cerr << "WARNING: you will get better performance with `kIs64x64=false`\n"; + // Determine kernel configuration based on head size. + // If head size is less than or equal to 64, each block operates over 64 queries and + // 64 keys, and parital results can be stored in the register file. + // If head size is greater than 64, each block operates over 32 queries and 128 keys, + // and partial results are stored in shared memory. + if (options.head_size_v > 64) { + static int const kQueriesPerBlock = 32; + static int const kKeysPerBlock = 128; + if (options.head_size_v <= kKeysPerBlock) { + return run_attention(options); + } else { + return run_attention(options); + } + } else { + static int const kQueriesPerBlock = 64; + static int const kKeysPerBlock = 64; + return run_attention(options); } - - constexpr bool kSingleValueIteration = true; - if (kSingleValueIteration && options.head_size_v > kKeysPerBlock) { - std::cerr << "ERROR : Use kSingleValueIteration to keep output in RF. 
" \ - "This requires to have `head_size <= kKeysPerBlock` " \ - "but head_size_v=" << options.head_size_v << " and kKeysPerBlock=" << kKeysPerBlock << "\n"; - return -2; - } - if (!kSingleValueIteration && options.head_size_v <= kKeysPerBlock) { - std::cerr << "WARNING: you will get better performance with `kSingleValueIteration=true` (keeps the output in RF rather than GMEM)\n"; - } - - using Attention = AttentionKernel< - cutlass::half_t, // scalar_t - ArchTag, - true, // memory is aligned - kQueriesPerBlock, - kKeysPerBlock, - kSingleValueIteration - >; - - // - // Test and profile - // - - TestbedAttention testbed(options); - - Result result = testbed.profile_grouped(); - if (!result.passed) { - std::cout << "Profiling CUTLASS attention has failed.\n"; - std::cout << "\nFailed\n"; - return -1; - } - - std::cout << "\nPassed\n"; - - return 0; } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_multi_head_attention/fused_multihead_attention.cu b/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu similarity index 62% rename from examples/41_multi_head_attention/fused_multihead_attention.cu rename to examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu index 86166daa..0738277b 100644 --- a/examples/41_multi_head_attention/fused_multihead_attention.cu +++ b/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu @@ -32,19 +32,58 @@ /*! \file \brief CUTLASS Attention Example. - This workload computes an attention example with non-fixed sequence length input. Pointers of arrays - are fed into grouped-GEMM functions fused with softmax for computation. + This workload computes a fused multi head attention that supports variable sequence lengths. + Because it keeps the attention matrix in shared memory, it's both faster and + uses less global memory. + + This is based on `"Self-Attention Does Not Need O(n^2) Memory" `_, + and very similar to `"FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" `_. + + Algorithm: + In short, we can compute the output incrementally in blocks of size B, + we just need to divide the final result by the sum of all coefficients in + the softmax (which we compute incrementally) with the following pseudo-code: + + ``` + s_prime = torch.zeros([num_queries, B]) + O = torch.zeros([num_queries, head_size_v]) + for i in range(0, K.shape[0], B): + si = exp((Q . K[i * B:(i+1) * B].t) * scale) + sum_coefs += attn_unscaled.sum(-1) + O += si . V[i * B:(i+1) * B] + O = O / s_prime + ``` + + In practice, and for numerical stability reasons, + we also substract the maximum so far (`mi`) before doing + the exponential. When we encounter new keys, the maximum + used to compute O so far (`m_prime`) can differ from the + current maximum, so we update O before accumulating with + + ``` + O = O * exp(m_prime - mi) + m_prime = mi + ``` + + Implementation details: + - `si` is stored in shared memory between the 2 back to back gemms + - we keep and accumulate the output + directly in registers if we can (`head_size_v <= 128`). 
+ Otherwise, we store it & accumulate in global memory (slower) + - blocks are parallelized across the batch dimension, the number + of heads, and the query sequence size + Examples: - # Run an attention example with default setup (max sequence length = 1024, batch size = 16, head size = 64, head number = 12) - $ ./examples/41_multi_head_attention/41_multi_head_attention + # Run an attention example with default setup + $ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_variable_seqlen - # Run an attention example with batch size = 64 and head number = 16 without checking the correctness - $ ./examples/41_multi_head_attention/41_multi_head_attention --head_number=16 --batch_size=64 --reference-check=false - - Acknowledgement: this example is inspired by the idea originally prototyped by ByteDance Inc. + # Run an attention example with custom setup + $ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_variable_seqlen --head_number=2 --batch_size=3 --head_size=32 --head_size_v=64 --seq_length=512 --seq_length_kv=1024 --causal=true + Acknowledgement: Fixed-sequence-length FMHA code was upstreamed by Meta xFormers (https://github.com/facebookresearch/xformers). + Using grouped GEMM to handle variable sequence lengths is inspired by an idea originally prototyped by ByteDance Inc. */ ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -53,10 +92,7 @@ #include "cutlass/cutlass.h" #include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/gemm_grouped.h" -#include "cutlass/gemm/kernel/default_gemm_grouped.h" #include "cutlass/gemm/device/gemm_grouped.h" -#include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/util/command_line.h" #include "cutlass/util/distribution.h" @@ -71,16 +107,14 @@ #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/layout/matrix.h" -#include "cutlass/gemm/kernel/gemm_grouped.h" -#include "cutlass/gemm/kernel/gemm_transpose_operands.h" #include "cutlass/gemm/kernel/default_gemm.h" #include "cutlass/gemm/kernel/default_gemm_complex.h" #include "cutlass/gemm/device/default_gemm_configuration.h" #include "cutlass/gemm/gemm.h" -#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" #include "cutlass/fast_math.h" -#include "gemm_attention.h" + +#include "default_fmha_grouped.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -115,6 +149,8 @@ struct Options { bool error; bool reference_check; bool use_mask; + bool causal; + bool fixed_seq_length; std::vector problem_sizes0; std::vector problem_sizes1; @@ -126,9 +162,11 @@ struct Options { int head_number; int batch_size; int head_size; + int head_size_v; int seq_length; + int seq_length_kv; int iterations; - int cuda_streams; + int problem_count; // alpha0, alpha1 and beta are fixed // in this multi-head attention example @@ -136,6 +174,8 @@ struct Options { float alpha1; float beta; + cutlass::gemm::kernel::GroupScheduleMode scheduler_mode; + // // Methods // @@ -143,15 +183,20 @@ struct Options { Options(): help(false), error(false), - alignment(16), + alignment(1), reference_check(true), head_number(12), batch_size(16), head_size(64), + head_size_v(64), seq_length(1024), + seq_length_kv(1024), use_mask(false), iterations(20), - cuda_streams(0) + causal(false), + fixed_seq_length(false), + problem_count(batch_size * head_number), + scheduler_mode(cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly) { } // Parses the command line @@ -163,23 
+208,50 @@ struct Options { return; } - cmd.get_cmd_line_argument("alignment", alignment, 16); + cmd.get_cmd_line_argument("alignment", alignment, 1); cmd.get_cmd_line_argument("head_number", head_number, 12); cmd.get_cmd_line_argument("batch_size", batch_size, 16); cmd.get_cmd_line_argument("head_size", head_size, 64); + cmd.get_cmd_line_argument("head_size_v", head_size_v, head_size); cmd.get_cmd_line_argument("seq_length", seq_length, 1024); + cmd.get_cmd_line_argument("seq_length_kv", seq_length_kv, seq_length); cmd.get_cmd_line_argument("use_mask", use_mask, false); cmd.get_cmd_line_argument("iterations", iterations, 20); - cmd.get_cmd_line_argument("streams", cuda_streams, 0); cmd.get_cmd_line_argument("reference-check", reference_check, true); + cmd.get_cmd_line_argument("causal", causal, true); + cmd.get_cmd_line_argument("fixed_seq_length", fixed_seq_length, false); + + std::vector scheduler_mode_strs; + cmd.get_cmd_line_arguments("scheduler-mode", scheduler_mode_strs); + + if (!scheduler_mode_strs.empty()) { + if (scheduler_mode_strs.size() > 1) { + std::cerr << "Only one scheduler mode may be passed in" << std::endl; + error = true; + return; + } + std::string scheduler_mode_str = scheduler_mode_strs[0]; + if (scheduler_mode_str == "kDeviceOnly") { + scheduler_mode = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly; + } else if (scheduler_mode_str == "kHostPrecompute") { + scheduler_mode = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute; + } else { + std::cerr << "Unrecognized scheduler mode '" << scheduler_mode_str << "'" << std::endl; + error = true; + return; + } + } + + if (fixed_seq_length) { + std::cout << "NOTE: Better performance is expected for fixed-sized sequence length from 41_fused_multi_head_attention_fixed_seqlen." << std::endl; + } randomize_problems(); - } void randomize_problems() { - int problem_count = head_number * batch_size; + problem_count = head_number * batch_size; problem_sizes0.reserve(problem_count); problem_sizes1.reserve(problem_count); @@ -193,20 +265,38 @@ struct Options { for (int i = 0; i < batch_size; ++i) { // problems belonging to the same batch share the same seq len - int m_real = (rand() % seq_length); - int m = (m_real + 1 + alignment - 1) / alignment * alignment; - int n = m; - int k = head_size; + + int m_real, mkv_real; + if (fixed_seq_length) { + m_real = seq_length; + mkv_real = seq_length_kv; + } else { + m_real = (rand() % seq_length) + 1; + + // Only randomize seq_length_kv if it was set to a different value than + // seq_length originally. 
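/*
  The two scheduler modes parsed above differ, roughly, in where the tile-to-problem
  mapping is computed: kDeviceOnly derives it on the fly inside the kernel, while
  kHostPrecompute fills it into the kernel workspace on the host. A minimal selection
  sketch; `precompute_on_host` is a hypothetical flag:

  ```
  #include "cutlass/gemm/kernel/grouped_problem_visitor.h"

  using cutlass::gemm::kernel::GroupScheduleMode;

  GroupScheduleMode pick_mode(bool precompute_on_host) {
    return precompute_on_host ? GroupScheduleMode::kHostPrecompute
                              : GroupScheduleMode::kDeviceOnly;
  }
  ```
*/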
+ if (seq_length != seq_length_kv) { + mkv_real = (rand() % seq_length_kv) + 1; + } else { + mkv_real = m_real; + } + } + + int m = (m_real + alignment - 1) / alignment * alignment; + int mkv = (mkv_real + alignment - 1) / alignment * alignment; + int k0 = head_size; + int k1 = head_size_v; for (int j = 0; j < head_number; ++j) { - cutlass::gemm::GemmCoord problem0(m, n, k); - cutlass::gemm::GemmCoord problem1(m, k, n); + cutlass::gemm::GemmCoord problem0(m, mkv, k0); + cutlass::gemm::GemmCoord problem1(m, k1, mkv); + problem_sizes0.push_back(problem0); problem_sizes1.push_back(problem1); if (use_mask) { - cutlass::gemm::GemmCoord problem0_real(m_real, m_real, k); - cutlass::gemm::GemmCoord problem1_real(m_real, k, m_real); + cutlass::gemm::GemmCoord problem0_real(m_real, mkv_real, k0); + cutlass::gemm::GemmCoord problem1_real(m_real, k1, mkv_real); problem_sizes0_real.push_back(problem0_real); problem_sizes1_real.push_back(problem1_real); } @@ -215,19 +305,31 @@ struct Options { } } + void print_problems() { + std::cout << " Running " << batch_size << " batches, each with " << head_number << " heads of size " << head_size << ":" << std::endl; + for (int i = 0; i < batch_size; ++i) { + int idx = i * head_number; + std::cout << " [" << i << "] seq_length = " << problem_sizes0[idx].m() << " seq_length_kv = " << problem_sizes0[idx].n() << std::endl; + } + } + /// Prints the usage statement. std::ostream & print_usage(std::ostream &out) const { - out << "41_multi_head_attention\n\n" + out << "41_fused_multi_head_attention_variable_seqlen\n\n" << "Options:\n\n" << " --help If specified, displays this usage statement.\n\n" << " --head_number= Head number in multi-head attention (default: --head_number=12)\n" << " --batch_size= Batch size in multi-head attention (default: --batch_size=16)\n" << " --head_size= Head size in multi-head attention (default: --head_size=64)\n" - << " --seq_length= Max sequence length in multi-head attention (default: --seq_length=1024)\n" + << " --head_size_v= Head size in multi-head attention for V (default: --head_size_v=head_size)\n" + << " --seq_length= Sequence length in multi-head attention for Q (default: --seq_length=1024)\n" + << " --seq_length_kv= Sequence length in multi-head attention for K/V (default: --seq_length_kv=seq_length)\n" << " --use_mask= If true, performs padding-like masking in softmax.\n" << " --iterations= Number of profiling iterations to perform.\n" - << " --reference-check= If true, performs reference check.\n"; + << " --reference-check= If true, performs reference check.\n" + << " --causal= If true, uses causal masking.\n" + << " --fixed_seq_length= If true, uses the same sequence length for each item in the batch.\n"; return out; } @@ -236,15 +338,31 @@ struct Options { double gflops(double runtime_s) const { // Number of real-valued multiply-adds - int64_t fmas = int64_t(); + int64_t fops = int64_t(); - for (auto const & problem : problem_sizes0) { - // Two flops per multiply-add - fmas += problem.product() * 2; + for (int i = 0; i < problem_sizes0.size(); ++i) { + auto const& problem0 = problem_sizes0[i]; + auto const& problem1 = problem_sizes1[i]; + + for (int row = 0; row < problem0.m(); ++row) { + int num_cols0 = problem0.n(); + if (causal) { + num_cols0 = std::min(row + 1, num_cols0); + } + // P <- Q . K_t + fops += 2 * num_cols0 * problem0.k(); + // P <- exp(P - max(P)) + fops += 2 * num_cols0; + // S <- sum(P) + fops += num_cols0 - 1; + // O <- P . 
V + fops += 2 * num_cols0 * problem1.n(); + // O <- O / S + fops += num_cols0 * problem1.n(); + } } - - // Multiply another '2' because of the back-to-back GEMM problems in attention - return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + + return double(fops) / double(1.0e9) / runtime_s; } }; @@ -260,30 +378,34 @@ public: // Type definitions // - using ElementQ = typename Attention::ElementQ; - using ElementK = typename Attention::ElementK; - using ElementP = typename Attention::ElementP; - using ElementAccumulator = typename Attention::GemmGrouped0::ElementAccumulator; - using ElementV = typename Attention::ElementV; - using ElementO = typename Attention::ElementOutput; + using scalar_t = typename Attention::GemmKernel::scalar_t; + using accum_t = typename Attention::GemmKernel::accum_t; + using output_t = typename Attention::GemmKernel::output_t; + using output_accum_t = typename Attention::GemmKernel::output_accum_t; - using EpilogueOutputOp = typename Attention::GemmGrouped0::GemmKernel::EpilogueVisitor::ElementwiseFunctor; - using ElementCompute = typename EpilogueOutputOp::ElementCompute; + using ElementQ = scalar_t; + using ElementK = scalar_t; + using ElementP = accum_t; + using ElementAccumulator = accum_t; + using ElementV = scalar_t; + using ElementO = output_t; + using ElementOAccum = output_accum_t; - using ElementNorm = typename Attention::ElementNorm; - using ElementSum = typename Attention::ElementSum; - using ElementSoftmaxCompute = typename Attention::ElementSoftmaxCompute; + using ElementCompute = accum_t; - using LayoutQ = typename Attention::LayoutQ; - using LayoutK = typename Attention::LayoutK; - using LayoutP = typename Attention::LayoutP; - using LayoutV = typename Attention::LayoutV; - using LayoutO = typename Attention::LayoutO; + using ElementNorm = accum_t; + using ElementSum = accum_t; + using ElementSoftmaxCompute = accum_t; + + using LayoutQ = cutlass::layout::RowMajor; + using LayoutK = cutlass::layout::ColumnMajor; + using LayoutP = cutlass::layout::RowMajor; + using LayoutV = cutlass::layout::RowMajor; + using LayoutO = cutlass::layout::RowMajor; using MatrixCoord = typename LayoutP::TensorCoord; - using ProblemVisitor0 = typename Attention::GemmKernel0::ProblemVisitor; - using ProblemVisitor1 = typename Attention::GemmKernel1::ProblemVisitor; + static bool const kNeedsOutputAccumulatorBuffer = Attention::GemmKernel::kNeedsOutputAccumulatorBuffer; private: @@ -310,8 +432,6 @@ private: std::vector offset_P; std::vector offset_V; std::vector offset_O; - std::vector offset_Norm; - std::vector offset_Sum; std::vector ldq_host; std::vector ldk_host; @@ -332,20 +452,19 @@ private: cutlass::DeviceAllocation block_P; cutlass::DeviceAllocation block_V; cutlass::DeviceAllocation block_O; + cutlass::DeviceAllocation block_O_accumulate; cutlass::DeviceAllocation block_Norm; cutlass::DeviceAllocation block_Sum; cutlass::DeviceAllocation offset_P_Device; - cutlass::DeviceAllocation offset_Norm_Device; - cutlass::DeviceAllocation offset_Sum_Device; cutlass::DeviceAllocation ptr_Q; cutlass::DeviceAllocation ptr_K; cutlass::DeviceAllocation ptr_P; cutlass::DeviceAllocation ptr_V; cutlass::DeviceAllocation ptr_O; - cutlass::DeviceAllocation ptr_Max; - cutlass::DeviceAllocation ptr_Sum; + cutlass::DeviceAllocation ptr_O_accumulate; + public: @@ -382,7 +501,7 @@ private: Element scope_max, scope_min; int bits_input = cutlass::sizeof_bits::value; - int bits_output = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; if (bits_input == 1) { 
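/*
  A compact, standalone restatement of the per-row FLOP accounting in gflops() above;
  under causal masking, row r only sees min(r + 1, seq_kv) key columns.

  ```
  #include <algorithm>
  #include <cstdint>

  // 2*c*k (Q.K^T) + 2*c (exp) + (c-1) (row sum) + 2*c*d (P.V) + c*d (divide by sum),
  // where c is the number of visible key columns for the row and d = head_size_v.
  int64_t attention_flops(int seq_q, int seq_kv, int head_size, int head_size_v, bool causal) {
    int64_t fops = 0;
    for (int row = 0; row < seq_q; ++row) {
      int64_t c = causal ? std::min<int64_t>(row + 1, seq_kv) : seq_kv;
      fops += 2 * c * head_size;     // P <- Q . K^T
      fops += 2 * c;                 // P <- exp(P - max(P))
      fops += c - 1;                 // S <- sum(P)
      fops += 2 * c * head_size_v;   // O <- P . V
      fops += c * head_size_v;       // O <- O / S
    }
    return fops;
  }
  ```
*/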
scope_max = 2; @@ -444,8 +563,6 @@ private: int64_t total_elements_V = 0; int64_t total_elements_O = 0; - int64_t total_elements_partial_norm = 0; - ldq_host.resize(problem_count()); ldk_host.resize(problem_count()); ldp_host.resize(problem_count()); @@ -455,42 +572,35 @@ private: for (int32_t i = 0; i < problem_count(); ++i) { - auto problem = options.problem_sizes0.at(i); + auto problem0 = options.problem_sizes0.at(i); + auto problem1 = options.problem_sizes1.at(i); - ldq_host.at(i) = LayoutQ::packed({problem.m(), problem.k()}).stride(0); - ldk_host.at(i) = LayoutK::packed({problem.k(), problem.n()}).stride(0); - ldp_host.at(i) = LayoutP::packed({problem.m(), problem.n()}).stride(0); - ldv_host.at(i) = LayoutV::packed({problem.n(), problem.k()}).stride(0); - ldo_host.at(i) = LayoutO::packed({problem.m(), problem.k()}).stride(0); + ldq_host.at(i) = LayoutQ::packed({problem0.m(), problem0.k()}).stride(0); + ldk_host.at(i) = LayoutK::packed({problem0.k(), problem0.n()}).stride(0); + ldp_host.at(i) = LayoutP::packed({problem0.m(), problem0.n()}).stride(0); + ldv_host.at(i) = LayoutV::packed({problem1.k(), problem1.n()}).stride(0); + ldo_host.at(i) = LayoutO::packed({problem1.m(), problem1.n()}).stride(0); // m = n for attention problems. - int64_t non_leading_dim = ldp_host.at(i); - int64_t threadblock_n = Attention::GemmGrouped0::GemmKernel::EpilogueVisitor::ThreadblockShape::kN; - int64_t threadblock_num = (ldp_host.at(i) + threadblock_n - 1) / threadblock_n; - - seqlen_host.at(i) = problem.m(); + seqlen_host.at(i) = problem0.m(); offset_Q.push_back(total_elements_Q); offset_K.push_back(total_elements_K); offset_P.push_back(total_elements_P); offset_V.push_back(total_elements_V); offset_O.push_back(total_elements_O); - offset_Norm.push_back(total_elements_partial_norm); - offset_Sum.push_back(total_elements_partial_norm); - int64_t elements_Q = problem.m() * problem.k(); - int64_t elements_K = problem.k() * problem.n(); - int64_t elements_P = problem.m() * problem.n(); - int64_t elements_V = problem.n() * problem.k(); - int64_t elements_O = problem.m() * problem.k(); - int64_t elements_norm = non_leading_dim * threadblock_num; + int64_t elements_Q = problem0.m() * problem0.k(); + int64_t elements_K = problem0.k() * problem0.n(); + int64_t elements_P = problem0.m() * problem0.n(); + int64_t elements_V = problem1.k() * problem1.n(); + int64_t elements_O = problem1.m() * problem1.n(); total_elements_Q += elements_Q; total_elements_K += elements_K; total_elements_P += elements_P; total_elements_V += elements_V; total_elements_O += elements_O; - total_elements_partial_norm += elements_norm; } @@ -527,23 +637,22 @@ private: block_P.reset(total_elements_P); block_V.reset(total_elements_V); block_O.reset(total_elements_O); - block_Norm.reset(total_elements_partial_norm); - block_Sum.reset(total_elements_partial_norm); + + if (kNeedsOutputAccumulatorBuffer) { + block_O_accumulate.reset(total_elements_O); + } offset_P_Device.reset(problem_count()); - offset_Norm_Device.reset(problem_count()); - offset_Sum_Device.reset(problem_count()); // sync offset with device cutlass::device_memory::copy_to_device(offset_P_Device.get(), offset_P.data(), offset_P.size()); - cutlass::device_memory::copy_to_device(offset_Norm_Device.get(), offset_Norm.data(), offset_Norm.size()); - cutlass::device_memory::copy_to_device(offset_Sum_Device.get(), offset_Sum.data(), offset_Sum.size()); std::vector ptr_Q_host(problem_count()); std::vector ptr_K_host(problem_count()); std::vector ptr_P_host(problem_count()); std::vector 
ptr_V_host(problem_count()); std::vector ptr_O_host(problem_count()); + std::vector ptr_O_accumulate_host(problem_count()); std::vector ptr_norm_host(problem_count()); std::vector ptr_sum_host(problem_count()); @@ -553,8 +662,10 @@ private: ptr_P_host.at(i) = block_P.get() + offset_P.at(i); ptr_V_host.at(i) = block_V.get() + offset_V.at(i); ptr_O_host.at(i) = block_O.get() + offset_O.at(i); - ptr_norm_host.at(i) = block_Norm.get() + offset_Norm.at(i); - ptr_sum_host.at(i) = block_Sum.get() + offset_Sum.at(i); + + if (kNeedsOutputAccumulatorBuffer) { + ptr_O_accumulate_host.at(i) = block_O_accumulate.get() + offset_O.at(i); + } } ptr_Q.reset(problem_count()); @@ -572,11 +683,10 @@ private: ptr_O.reset(problem_count()); ptr_O.copy_from_host(ptr_O_host.data()); - ptr_Max.reset(problem_count()); - ptr_Max.copy_from_host(ptr_norm_host.data()); - - ptr_Sum.reset(problem_count()); - ptr_Sum.copy_from_host(ptr_sum_host.data()); + if (kNeedsOutputAccumulatorBuffer) { + ptr_O_accumulate.reset(problem_count()); + ptr_O_accumulate.copy_from_host(ptr_O_accumulate_host.data()); + } // // Initialize the problems of the workspace @@ -606,12 +716,12 @@ private: float abs_ref = fabs((float)vector_Input_Ref.at(i) + 1e-5f); float relative_diff = abs_diff / abs_ref; if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) { - printf("diff = %f, rel_diff = %f, {%f, %f}.\n", abs_diff, relative_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); + printf("[%d/%d] diff = %f, rel_diff = %f, {computed=%f, ref=%f}.\n", int(i), int(size), abs_diff, relative_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); return false; } } - + return true; } @@ -621,7 +731,7 @@ private: bool passed = true; for (int32_t i = 0; i < problem_count(); ++i) { - cutlass::gemm::GemmCoord problem = options.problem_sizes0.at(i); + cutlass::gemm::GemmCoord problem0 = options.problem_sizes0.at(i); cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(i); LayoutQ layout_Q(ldq_host.at(i)); @@ -630,11 +740,11 @@ private: LayoutV layout_V(ldv_host.at(i)); LayoutO layout_O(ldo_host.at(i)); - MatrixCoord extent_Q{problem.m(), problem.k()}; - MatrixCoord extent_K{problem.k(), problem.n()}; - MatrixCoord extent_P{problem.m(), problem.n()}; - MatrixCoord extent_V{problem.n(), problem.k()}; - MatrixCoord extent_O{problem.m(), problem.k()}; + MatrixCoord extent_Q{problem0.m(), problem0.k()}; + MatrixCoord extent_K{problem0.n(), problem0.k()}; + MatrixCoord extent_P{problem0.m(), problem0.n()}; + MatrixCoord extent_V{problem1.k(), problem1.n()}; + MatrixCoord extent_O{problem1.m(), problem1.n()}; cutlass::TensorView view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q); cutlass::TensorView view_K(block_K.get() + offset_K.at(i), layout_K, extent_K); @@ -646,6 +756,7 @@ private: cutlass::DeviceAllocation block_Ref_O(layout_O.capacity(extent_O)); cutlass::TensorView view_Ref_O_device(block_Ref_O.get(), layout_O, extent_O); + cutlass::reference::device::TensorFill(view_Ref_O_device, ElementO(0)); // Reference GEMM cutlass::reference::device::GemmComplex< @@ -654,12 +765,12 @@ private: ElementP, LayoutP, ElementCompute, ElementAccumulator >( - problem, + problem0, ElementAccumulator(options.alpha0), view_Q, - Attention::GemmGrouped0::kTransformA, + Attention::GemmKernel::MM0::Mma::kTransformA, view_K, - Attention::GemmGrouped0::kTransformB, + Attention::GemmKernel::MM0::Mma::kTransformB, ElementAccumulator(options.beta), view_P, view_Ref_device, @@ -672,41 +783,49 @@ private: 
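/*
  The pointer setup above follows the usual grouped-kernel pattern: one contiguous
  allocation per operand, a per-problem base pointer at a precomputed offset, and a
  device-side pointer array the kernel indexes with problem_idx. An illustrative
  helper (names are not from this example):

  ```
  #include <cstdint>
  #include <vector>
  #include "cutlass/util/device_memory.h"

  template <typename Element>
  void build_pointer_array(cutlass::DeviceAllocation<Element>& block,
                           const std::vector<int64_t>& offsets,
                           cutlass::DeviceAllocation<Element*>& device_ptrs) {
    std::vector<Element*> host_ptrs(offsets.size());
    for (size_t i = 0; i < offsets.size(); ++i) {
      host_ptrs[i] = block.get() + offsets[i];     // base pointer for problem i
    }
    device_ptrs.reset(host_ptrs.size());
    device_ptrs.copy_from_host(host_ptrs.data());  // read on device as device_ptrs[problem_idx]
  }
  ```
*/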
std::vector matrix_Ref(layout_P.capacity(extent_P)); cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_Ref.size()); cutlass::TensorView view_Ref_host(matrix_Ref.data(), layout_P, extent_P); - std::vector vector_Norm_Ref(problem.m()); - std::vector vector_Sum_Ref(problem.m()); + std::vector vector_Norm_Ref(problem0.m()); + std::vector vector_Sum_Ref(problem0.m()); - int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem.n(); + int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem0.n(); - // Compute softmax for referece matrix + // Compute softmax for reference matrix // Assumed a row-major storage - for (int m = 0; m < problem.m(); m++) { + for (int m = 0; m < problem0.m(); m++) { + int n_dim_row = n_dim; + if (options.causal) { + n_dim_row = std::min(m + 1, n_dim); + } ElementSoftmaxCompute max = ElementSoftmaxCompute(view_Ref_host.ref().at({m, 0})); - for (int n = 1; n < n_dim; n++) { + for (int n = 1; n < n_dim_row; n++) { max = std::max(max, ElementSoftmaxCompute(view_Ref_host.ref().at({m, n}))); } vector_Norm_Ref.at(m) = ElementNorm(max); ElementSoftmaxCompute sum = ElementSoftmaxCompute(); - for (int n = 0; n < n_dim; n++) { + for (int n = 0; n < n_dim_row; n++) { sum += std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ); } ElementSoftmaxCompute inv_sum = ElementSoftmaxCompute(1.0f / sum); vector_Sum_Ref.at(m) = ElementSum(inv_sum); - for (int n = 0; n < n_dim; n++) { + for (int n = 0; n < n_dim_row; n++) { view_Ref_host.ref().at({m, n}) = ElementP( std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ) * inv_sum ); } + // Mask out the rest of the attention matrix + for (int n = n_dim_row; n < n_dim; ++n) { + view_Ref_host.ref().at({m, n}) = ElementP(0); + } } // when not using mask, problem_real and problem share the same sizes if (options.use_mask) { - for (int m = 0; m < problem.m(); m++) { - for (int n = n_dim; n < problem.n(); n++) { + for (int m = 0; m < problem0.m(); m++) { + for (int n = n_dim; n < problem0.n(); n++) { view_Ref_host.ref().at({m, n}) = ElementP(0); } } @@ -724,9 +843,9 @@ private: problem1, ElementAccumulator(options.alpha1), view_P, - Attention::GemmGrouped0::kTransformA, + Attention::GemmKernel::MM0::Mma::kTransformA, view_V, - Attention::GemmGrouped0::kTransformB, + Attention::GemmKernel::MM0::Mma::kTransformB, ElementAccumulator(options.beta), view_Ref_O_device, view_Ref_O_device, @@ -734,41 +853,20 @@ private: ); // Copy to host memory - - int64_t threadblock_n = Attention::GemmGrouped0::GemmKernel::EpilogueVisitor::ThreadblockShape::kN; - int64_t threadblock_num = (problem.m() + threadblock_n - 1) / threadblock_n; - - std::vector vector_Norm(problem.m() * threadblock_num); - std::vector vector_Sum(problem.m() * threadblock_num); - - cutlass::device_memory::copy_to_host(vector_Norm.data(), block_Norm.get() + offset_Norm.at(i), vector_Norm.size()); - cutlass::device_memory::copy_to_host(vector_Sum.data(), block_Sum.get() + offset_Sum.at(i), vector_Sum.size()); - cutlass::TensorView view_Ref(matrix_Ref.data(), layout_P, extent_P); std::vector matrix_O(layout_O.capacity(extent_O)); cutlass::device_memory::copy_to_host(matrix_O.data(), block_O.get() + offset_O.at(i), matrix_O.size()); - std::vector matrix_Ref_O(layout_O.capacity(extent_O)); + std::vector matrix_Ref_O(layout_O.capacity(extent_O)); cutlass::device_memory::copy_to_host(matrix_Ref_O.data(), block_Ref_O.get(), matrix_Ref_O.size()); - bool verified_N = false; - bool verified_S = 
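/*
  Host-side sketch of the row softmax the reference check computes above: with causal
  masking, row m normalizes only over its first min(m + 1, N) columns and the remaining
  columns are zeroed out.

  ```
  #include <algorithm>
  #include <cmath>
  #include <vector>

  // Row-wise softmax over an M x N row-major matrix with optional causal masking.
  void reference_softmax(std::vector<float>& P, int M, int N, bool causal) {
    for (int m = 0; m < M; ++m) {
      int n_valid = causal ? std::min(m + 1, N) : N;
      float* row = P.data() + m * N;
      float max_val = row[0];
      for (int n = 1; n < n_valid; ++n) max_val = std::max(max_val, row[n]);
      float sum = 0.f;
      for (int n = 0; n < n_valid; ++n) sum += std::exp(row[n] - max_val);
      float inv_sum = 1.f / sum;
      for (int n = 0; n < n_valid; ++n) row[n] = std::exp(row[n] - max_val) * inv_sum;
      for (int n = n_valid; n < N; ++n) row[n] = 0.f;  // masked-out columns
    }
  }
  ```
*/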
false; + bool verified_O = false; - - if (!verified_N) { - verified_N = verify_tensor_(vector_Norm, vector_Norm_Ref); - } - - if (!verified_S) { - verified_S = verify_tensor_(vector_Sum, vector_Sum_Ref); - } - - if (!verified_O) { verified_O = verify_tensor_(matrix_O, matrix_Ref_O); } - passed = passed && verified_N && verified_S && verified_O; + passed = passed && verified_O; if (!passed) { std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl; @@ -777,14 +875,6 @@ private: std::cout << "Final matrix output is incorrect" << std::endl; } - if (!verified_N) { - std::cout << "Max is incorrect" << std::endl; - } - - if (!verified_S) { - std::cout << "Sum is incorrect" << std::endl; - } - return passed; } @@ -795,40 +885,18 @@ private: public: - /// Returns the number of threadblocks to launch if the kernel can run on the target - /// device. Otherwise, returns zero. - int sufficient() const { - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - int occupancy = Attention::GemmGrouped0::maximum_active_blocks(); - - return properties.multiProcessorCount * occupancy; - - } - /// Executes a CUTLASS Attention kernel and measures runtime. - Result profile_grouped() { + Result profile() { Result result; + result.passed = false; - int threadblock_count = sufficient(); + int threadblock_count = Attention::sufficient(options.problem_sizes1.data(), options.problem_count); // Early exit if (!threadblock_count) { - std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Attention kernel." << std::endl; + std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped FMHA kernel." 
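/*
  The removed sufficient() helper above sized the persistent launch as (number of SMs)
  x (max active threadblocks per SM); the Attention::sufficient() call used here
  performs the equivalent check internally. A hedged sketch of that calculation, with
  the occupancy value taken as an input rather than queried for a specific kernel:

  ```
  #include <cuda_runtime.h>

  // One "wave" of persistent threadblocks that fills the GPU.
  int persistent_threadblock_count(int max_active_blocks_per_sm) {
    int device = 0;
    cudaDeviceProp props;
    if (cudaGetDevice(&device) != cudaSuccess ||
        cudaGetDeviceProperties(&props, device) != cudaSuccess) {
      return 0;
    }
    return props.multiProcessorCount * max_active_blocks_per_sm;
  }
  ```
*/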
<< std::endl; return result; } @@ -840,65 +908,40 @@ public: typename Attention::Arguments args( problem_sizes_device0.get(), problem_sizes_device1.get(), - problem_count(), + options.problem_count, threadblock_count, ptr_Q.get(), ptr_K.get(), ptr_P.get(), ptr_V.get(), ptr_O.get(), - ptr_Max.get(), - ptr_Sum.get(), - block_P.get(), - block_Norm.get(), - block_Sum.get(), - offset_P_Device.get(), - offset_Norm_Device.get(), - offset_Sum_Device.get(), + ptr_O_accumulate.get(), ldq.get(), ldk.get(), ldp.get(), ldv.get(), ldo.get(), - ElementAccumulator(options.alpha0), - ElementAccumulator(options.alpha1), - ElementAccumulator(options.beta), - options.head_number, - options.batch_size, - options.seq_length, - options.problem_sizes0.data(), - options.problem_sizes1.data(), - problem_sizes_device0_real.get() + options.causal, + options.problem_sizes1.data() ); - size_t workspace_size0 = ProblemVisitor0::kRequiresPrecomputation ?\ - ProblemVisitor0::get_workspace_size(options.problem_sizes0.data(),\ - problem_count(),\ - threadblock_count)\ - : 0; + Attention fmha; - size_t workspace_size1 = ProblemVisitor1::kRequiresPrecomputation ?\ - ProblemVisitor1::get_workspace_size(options.problem_sizes1.data(),\ - problem_count(),\ - threadblock_count)\ - : 0; + size_t workspace_size = fmha.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); - cutlass::DeviceAllocation workspace0(workspace_size0); - cutlass::DeviceAllocation workspace1(workspace_size1); - - Attention attention; - - result.status = attention.initialize(args, workspace0.get(), workspace1.get()); + result.status = fmha.initialize(args, workspace.get()); if (result.status != cutlass::Status::kSuccess) { - std::cerr << "Failed to initialize CUTLASS Attention kernel." << std::endl; + std::cerr << "Failed to initialize CUTLASS Grouped FMHA kernel." << std::endl; return result; } - result.status = attention.run(); + // Run the grouped FMHA object + result.status = fmha.run(); if (result.status != cutlass::Status::kSuccess) { - std::cerr << "Failed to initialize CUTLASS Attention kernel." << std::endl; + std::cerr << "Failed to run CUTLASS Grouped FMHA kernel." << std::endl; return result; } @@ -906,7 +949,7 @@ public: result.error = cudaDeviceSynchronize(); if (result.error != cudaSuccess) { - std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error) << std::endl; return result; } @@ -920,13 +963,12 @@ public: } // - // Warm-up run of the grouped GEMM object + // Warm-up run of the grouped FMHA object // - - result.status = attention.run(); + result.status = fmha.run(); if (result.status != cutlass::Status::kSuccess) { - std::cerr << "Failed to run CUTLASS Attention kernel." << std::endl; + std::cerr << "Failed to run CUTLASS Grouped FMHA kernel." << std::endl; return result; } @@ -944,7 +986,7 @@ public: } } - // Record an event at the start of a series of GEMM operations + // Record an event at the start of a series of FMHA operations result.error = cudaEventRecord(events[0]); if (result.error != cudaSuccess) { std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; @@ -955,8 +997,8 @@ public: // Run profiling loop // - for (int iter = 0; iter < options.iterations; ++iter) { - attention(); + for (int iter = 0; iter < this->options.iterations; ++iter) { + fmha(); } // @@ -986,8 +1028,8 @@ public: } // Compute average runtime and GFLOPs. 
- result.runtime_ms = double(runtime_ms) / double(options.iterations); - result.gflops = options.gflops(result.runtime_ms / 1000.0); + result.runtime_ms = double(runtime_ms) / double(this->options.iterations); + result.gflops = this->options.gflops(result.runtime_ms / 1000.0); // // Cleanup @@ -1000,8 +1042,10 @@ public: std::cout << std::endl; std::cout << "CUTLASS Attention:\n" << "====================================================" << std::endl; - std::cout << " " << " {max sequence length, head size, head number, batch size} = {" << options.seq_length \ - << ", " << options.head_size << ", " << options.head_number << ", " << options.batch_size << "}." << std::endl; + std::cout << " " << " {seq length Q, seq length KV, head size, head size V, head number, batch size} = {" << options.seq_length \ + << ", " << options.seq_length_kv << ", " << options.head_size << ", " << options.head_size_v << ", " << options.head_number\ + << ", " << options.batch_size << "}." << std::endl; + options.print_problems(); std::cout << std::endl; std::cout << " " << "Runtime: " << result.runtime_ms << " ms" << std::endl; std::cout << " " << "GFLOPs: " << result.gflops << std::endl; @@ -1012,6 +1056,65 @@ public: }; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int kQueriesPerBlock, + int kKeysPerBlock, + bool kSingleValueIteration, + cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ +> +int run_grouped(Options& options) { + using AttentionKernel = typename cutlass::gemm::kernel::DefaultFMHAGrouped< + cutlass::half_t, // scalar_t + cutlass::arch::Sm80, // ArchTag + true, // Memory is aligned + kQueriesPerBlock, + kKeysPerBlock, + kSingleValueIteration, + GroupScheduleMode_ + >::FMHAKernel; + + using FMHA = cutlass::gemm::device::GemmGrouped; + + // + // Test and profile + // + + TestbedAttention testbed(options); + + Result result = testbed.profile(); + if (!result.passed) { + std::cout << "Profiling CUTLASS attention has failed.\n"; + std::cout << "\nFailed\n"; + return -1; + } + + std::cout << "\nPassed\n"; + return 0; +} + + +template < + int kQueriesPerBlock, + int kKeysPerBlock, + bool kSingleValueIteration +> +int run_attention(Options& options) { + if (options.scheduler_mode == cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly) { + return run_grouped(options); + } else { + return run_grouped(options); + } +} + + /////////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, char const **args) { @@ -1046,7 +1149,7 @@ int main(int argc, char const **args) { // Options options; - + options.parse(argc, args); if (options.help) { @@ -1059,84 +1162,33 @@ int main(int argc, char const **args) { return -1; } - // - // Define the CUTLASS Attention type - // - - using ElementOutput = cutlass::half_t; - using ElementAccumulator = cutlass::half_t; - - using ElementQ = cutlass::half_t; - using ElementK = cutlass::half_t; - using ElementP = ElementOutput; - - using LayoutQ = cutlass::layout::RowMajor; - using LayoutK = cutlass::layout::ColumnMajor; - using LayoutP = cutlass::layout::RowMajor; - - static bool const UseMask = false; - - if (UseMask != options.use_mask) { - std::cerr << "UseMask and user-defined use_mask need to be consistant, " - << " aborted execution.\n"; + if (options.use_mask) { + std::cerr << "--use_mask is not supported at the moment\n"; + return -2; + } + if (options.alignment != 1) { + std::cerr << "--alignment=1 is the only supported value\n"; return 
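/*
  Both drivers in this example pick the tile configuration from the V head size, as
  the comments above and below describe. A sketch of that dispatch with the
  run_attention template parameters written out explicitly (its declaration earlier in
  this patch is template <int kQueriesPerBlock, int kKeysPerBlock, bool kSingleValueIteration>):

  ```
  // head_size_v <= 64 : 64x64 tiles, partial results kept in the register file.
  // head_size_v  > 64 : 32x128 tiles; the output stays in registers only while
  //                     head_size_v <= kKeysPerBlock (the third template parameter).
  int dispatch_by_head_size(Options& options) {
    if (options.head_size_v > 64) {
      constexpr int kQueriesPerBlock = 32;
      constexpr int kKeysPerBlock = 128;
      if (options.head_size_v <= kKeysPerBlock) {
        return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
      }
      return run_attention<kQueriesPerBlock, kKeysPerBlock, false>(options);
    }
    return run_attention<64, 64, true>(options);
  }
  ```
*/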
-2; } - using OperatorClass = cutlass::arch::OpClassTensorOp; - using ArchTag = cutlass::arch::Sm80; - - using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 128, 32>; - using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; - - using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 32>; - using WarpShape1 = cutlass::gemm::GemmShape<32, 32, 32>; - - static int const Stages0 = 3; - static int const Stages1 = 4; - - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; - - using Attention = cutlass::FusedMultiHeadAttention< - ElementQ, - LayoutQ, - ElementK, - LayoutK, - ElementP, - LayoutP, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape0, - ThreadblockShape1, - WarpShape0, - WarpShape1, - InstructionShape, - Stages0, - Stages1, - UseMask - >; - - // - // Test and profile - // - - TestbedAttention testbed(options); - - if (!testbed.sufficient()) { - std::cout << "The active CUDA device lacks sufficient hardware resources to execute this kernel.\n"; - return 0; + // Determine kernel configuration based on head size. + // If head size is less than or equal to 64, each block operates over 64 queries and + // 64 keys, and parital results can be stored in the register file. + // If head size is greater than 64, each block operates over 32 queries and 128 keys, + // and partial results are stored in shared memory. + if (options.head_size_v > 64) { + static int const kQueriesPerBlock = 32; + static int const kKeysPerBlock = 128; + if (options.head_size_v <= kKeysPerBlock) { + return run_attention(options); + } else { + return run_attention(options); + } + } else { + static int const kQueriesPerBlock = 64; + static int const kKeysPerBlock = 64; + return run_attention(options); } - - Result result = testbed.profile_grouped(); - if (!result.passed) { - std::cout << "Profiling CUTLASS attention has failed.\n"; - std::cout << "\nFailed\n"; - return -1; - } - - std::cout << "\nPassed\n"; - - return 0; } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/42_fused_multi_head_attention/gemm/custom_mma.h b/examples/41_fused_multi_head_attention/gemm/custom_mma.h similarity index 54% rename from examples/42_fused_multi_head_attention/gemm/custom_mma.h rename to examples/41_fused_multi_head_attention/gemm/custom_mma.h index c0f1cd50..df7b6d15 100644 --- a/examples/42_fused_multi_head_attention/gemm/custom_mma.h +++ b/examples/41_fused_multi_head_attention/gemm/custom_mma.h @@ -1,3 +1,34 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + #pragma once #include "custom_mma_multistage.h" diff --git a/examples/42_fused_multi_head_attention/gemm/custom_mma_base.h b/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h similarity index 100% rename from examples/42_fused_multi_head_attention/gemm/custom_mma_base.h rename to examples/41_fused_multi_head_attention/gemm/custom_mma_base.h diff --git a/examples/42_fused_multi_head_attention/gemm/custom_mma_multistage.h b/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h similarity index 100% rename from examples/42_fused_multi_head_attention/gemm/custom_mma_multistage.h rename to examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h diff --git a/examples/42_fused_multi_head_attention/gemm/custom_mma_pipelined.h b/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h similarity index 100% rename from examples/42_fused_multi_head_attention/gemm/custom_mma_pipelined.h rename to examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h diff --git a/examples/42_fused_multi_head_attention/gemm_kernel_utils.h b/examples/41_fused_multi_head_attention/gemm_kernel_utils.h similarity index 84% rename from examples/42_fused_multi_head_attention/gemm_kernel_utils.h rename to examples/41_fused_multi_head_attention/gemm_kernel_utils.h index eff9cbc6..a0d68d4d 100644 --- a/examples/42_fused_multi_head_attention/gemm_kernel_utils.h +++ b/examples/41_fused_multi_head_attention/gemm_kernel_utils.h @@ -1,3 +1,34 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + #pragma once #include "cutlass/arch/mma.h" diff --git a/examples/42_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h b/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h similarity index 100% rename from examples/42_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h rename to examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h diff --git a/examples/41_fused_multi_head_attention/iterators/make_residual_last.h b/examples/41_fused_multi_head_attention/iterators/make_residual_last.h new file mode 100644 index 00000000..128829f2 --- /dev/null +++ b/examples/41_fused_multi_head_attention/iterators/make_residual_last.h @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "predicated_tile_access_iterator_residual_last.h" +#include "predicated_tile_iterator_residual_last.h" + +namespace cutlass { +namespace transform { +namespace threadblock { + +template +struct MakeIteratorResidualLast; + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize, + bool Gather> +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileIteratorResidualLast< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessSize, + Gather>; +}; + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + typename AccessType, + bool Gather> +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessType, + Gather>; +}; +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/examples/42_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h b/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h similarity index 100% rename from examples/42_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h rename to examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h diff --git a/examples/42_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h b/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h similarity index 100% rename from examples/42_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h rename to examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h diff --git a/examples/42_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h similarity index 95% rename from examples/42_fused_multi_head_attention/kernel_forward.h rename to examples/41_fused_multi_head_attention/kernel_forward.h index fb0855d5..e6880d31 100644 --- a/examples/42_fused_multi_head_attention/kernel_forward.h +++ b/examples/41_fused_multi_head_attention/kernel_forward.h @@ -1,3 +1,34 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + #pragma once #ifdef HAS_PYTORCH diff --git a/examples/42_fused_multi_head_attention/mma_from_smem.h b/examples/41_fused_multi_head_attention/mma_from_smem.h similarity index 100% rename from examples/42_fused_multi_head_attention/mma_from_smem.h rename to examples/41_fused_multi_head_attention/mma_from_smem.h diff --git a/examples/41_multi_head_attention/gemm_attention.h b/examples/41_multi_head_attention/gemm_attention.h deleted file mode 100644 index 9990c0fb..00000000 --- a/examples/41_multi_head_attention/gemm_attention.h +++ /dev/null @@ -1,626 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holdvr nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Defines the FusedMultiHeadAttention Class - - The class contains the following: - 1) GEMM0 with epilogue fusion, - 2) GEMM1 with mainloop fusion, and - 3) A lightweight full softmax reduction kernel. 
- -*/ - -#pragma once - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include -#include -#include -#include - -#include "cutlass/cutlass.h" -#include "cutlass/arch/memory.h" -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h" -#include "cutlass/epilogue/thread/scale_type.h" -#include "cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h" -#include "cutlass/reduction/kernel/reduce_softmax_final.h" -#include "gemm_grouped_with_softmax_visitor.h" - -namespace cutlass { - -template < - typename ElementQ_, - typename LayoutQ_, - typename ElementK_, - typename LayoutK_, - typename ElementP_, - typename LayoutP_, - typename ElementCompute_, - typename OperatorClass_, - typename ArchTag_, - typename ThreadblockShape0_, - typename ThreadblockShape1_, - typename WarpShape0_, - typename WarpShape1_, - typename InstructionShape_, - int kStages0_, - int kStages1_, - bool UseMasking_ = false, - cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode0_ = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute, - cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode1_ = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute, - int Alignment = 128 / cutlass::sizeof_bits::value, - typename ElementSoftmax_ = ElementP_ -> -class FusedMultiHeadAttention { -public: - - using ElementQ = ElementQ_; - using ElementK = ElementK_; - using ElementP = ElementP_; - using ElementV = ElementK; - using ElementOutput = ElementP; - using ElementAccumulator = ElementCompute_; - - using LayoutQ = LayoutQ_; - using LayoutK = LayoutK_; - using LayoutP = LayoutP_; - using LayoutV = LayoutK; - using LayoutO = LayoutP; - - using ElementNorm = cutlass::half_t; - using ElementSum = cutlass::half_t; - using ElementSoftmaxCompute = float; - using LayoutNorm = cutlass::layout::RowMajor; - - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; - - using OperatorClass = OperatorClass_; - using ArchTag = ArchTag_; - - using ThreadblockShape0 = ThreadblockShape0_; - using WarpShape0 = WarpShape0_; - - using ThreadblockShape1 = ThreadblockShape1_; - using WarpShape1 = WarpShape1_; - - static int const Stages0 = kStages0_; - static int const Stages1 = kStages1_; - - using InstructionShape = InstructionShape_; - - using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>; - - using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombination< - ElementOutput, 128 / cutlass::sizeof_bits::value, - ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::Nothing>; - - using Operator = typename cutlass::gemm::device::DefaultGemmConfiguration< - OperatorClass, ArchTag, ElementQ, ElementK, ElementP, - ElementAccumulator>::Operator; - static bool const kInternalTranspose = cutlass::platform::is_same::value; - - static bool const kUseMasking = UseMasking_; - - static cutlass::gemm::kernel::GroupScheduleMode const kGroupScheduleMode0 = GroupScheduleMode0_; - static cutlass::gemm::kernel::GroupScheduleMode const kGroupScheduleMode1 = GroupScheduleMode1_; - - using MapArguments = cutlass::gemm::kernel::detail::MapArguments< - ElementQ, - LayoutQ, - cutlass::ComplexTransform::kNone, - 8, - ElementK, - LayoutK, - cutlass::ComplexTransform::kNone, - 8, - LayoutP, - kInternalTranspose - >; 
- - using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm< - typename MapArguments::ElementA, - typename MapArguments::LayoutA, - MapArguments::kAlignmentA, - typename MapArguments::ElementB, - typename MapArguments::LayoutB, - MapArguments::kAlignmentB, - ElementP, - typename MapArguments::LayoutC, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape0, - WarpShape0, - InstructionShape, - EpilogueOutputOp0, - ThreadblockSwizzle, - Stages0, - true, - Operator, - cutlass::gemm::SharedMemoryClearOption::kNone - >::GemmKernel; - - using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorSoftmax< - ThreadblockShape0, - DefaultGemmKernel::kThreadCount, - typename DefaultGemmKernel::Epilogue::OutputTileIterator, - typename EpilogueOutputOp0::ElementCompute, - ElementNorm, - ElementSum, - ElementSoftmaxCompute, - EpilogueOutputOp0, - kUseMasking - >; - - using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue< - EpilogueVisitor, - typename DefaultGemmKernel::Epilogue - >::Epilogue; - - using GemmKernel0 = cutlass::gemm::kernel::GemmGroupedWithEpilogueVistor< - typename DefaultGemmKernel::Mma, - Epilogue, - ThreadblockSwizzle, - kGroupScheduleMode0, - kInternalTranspose, - kUseMasking - >; - - using GemmGrouped0 = cutlass::gemm::device::GemmGrouped; - - using ApplyFinalReductionDevice = cutlass::reduction::kernel::ApplySoftmaxFinalReduction< - ElementNorm, - ElementSum, - typename GemmGrouped0::GemmKernel::EpilogueVisitor::ElementSoftmaxCompute, - typename GemmGrouped0::GemmKernel::EpilogueVisitor::ThreadblockShape, - true - >; - - using GemmKernel1 = typename cutlass::gemm::kernel::DefaultGemmGroupedSoftmaxMainloopFusion< - ElementP, - LayoutP, - cutlass::ComplexTransform::kNone, - 128 / cutlass::sizeof_bits::value, - ElementV, - LayoutV, - cutlass::ComplexTransform::kNone, - 128 / cutlass::sizeof_bits::value, - ElementNorm, - LayoutNorm, - ElementOutput, - LayoutO, - ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape1, - WarpShape1, - InstructionShape, - EpilogueOutputOp1, - ThreadblockSwizzle, - Stages1, - kGroupScheduleMode1 - >::GemmKernel; - - using GemmGrouped1 = cutlass::gemm::device::GemmGrouped; - -public: - - /// Arguments class - struct Arguments { - cutlass::gemm::GemmCoord *problem_sizes0; - cutlass::gemm::GemmCoord *problem_sizes0_real; - cutlass::gemm::GemmCoord *problem_sizes1; - int problem_count; - int threadblock_count; - - ElementQ ** ptr_Q; - ElementK ** ptr_K; - ElementP ** ptr_P; - ElementP ** ptr_V; - ElementP ** ptr_O; - - ElementNorm **ptr_Max; - ElementSum **ptr_Sum; - - ElementP *block_P; - ElementNorm *block_Norm; - ElementSum *block_Sum; - int64_t *offset_P; - int64_t *offset_Norm_Device; - int64_t *offset_Sum_Device; - - typename LayoutQ::Stride::LongIndex *ldq; - typename LayoutK::Stride::LongIndex *ldk; - typename LayoutP::Stride::LongIndex *ldp; - typename LayoutP::Stride::LongIndex *ldv; - typename LayoutP::Stride::LongIndex *ldo; - - cutlass::gemm::GemmCoord *problem_sizes0_host; - cutlass::gemm::GemmCoord *problem_sizes1_host; - - ElementAccumulator alpha0; - ElementAccumulator alpha1; - ElementAccumulator beta; - - int head_number; - int batch_size; - int seq_length; - - typename ApplyFinalReductionDevice::Arguments reduction; - - // - // Methods - // - Arguments(): - problem_count(0), - threadblock_count(0), - ptr_Q(nullptr), - ptr_K(nullptr), - ptr_P(nullptr), - ptr_V(nullptr), - ptr_O(nullptr), - ptr_Max(nullptr), - ptr_Sum(nullptr), - 
block_P(nullptr), - block_Norm(nullptr), - block_Sum(nullptr), - offset_P(nullptr), - offset_Norm_Device(nullptr), - offset_Sum_Device(nullptr), - ldq(nullptr), - ldk(nullptr), - ldp(nullptr), - ldv(nullptr), - ldo(nullptr), - head_number(0), - batch_size(0), - seq_length(0) - { - - } - - Arguments( - cutlass::gemm::GemmCoord *problem_sizes0, - cutlass::gemm::GemmCoord *problem_sizes1, - int problem_count, - int threadblock_count, - ElementQ ** ptr_Q, - ElementK ** ptr_K, - ElementP ** ptr_P, - ElementP ** ptr_V, - ElementP ** ptr_O, - ElementNorm **ptr_Max, - ElementSum **ptr_Sum, - ElementP *block_P, - ElementNorm *block_Norm, - ElementSum *block_Sum, - int64_t *offset_P, - int64_t *offset_Norm_Device, - int64_t *offset_Sum_Device, - typename LayoutQ::Stride::LongIndex *ldq, - typename LayoutK::Stride::LongIndex *ldk, - typename LayoutP::Stride::LongIndex *ldp, - typename LayoutP::Stride::LongIndex *ldv, - typename LayoutP::Stride::LongIndex *ldo, - ElementAccumulator alpha0, - ElementAccumulator alpha1, - ElementAccumulator beta, - int head_number, - int batch_size, - int seq_length, - cutlass::gemm::GemmCoord *problem_sizes0_host = nullptr, - cutlass::gemm::GemmCoord *problem_sizes1_host = nullptr, - cutlass::gemm::GemmCoord *problem_sizes0_real = nullptr - ): - problem_sizes0(problem_sizes0), - problem_sizes1(problem_sizes1), - problem_count(problem_count), - threadblock_count(threadblock_count), - ptr_Q(ptr_Q), - ptr_K(ptr_K), - ptr_P(ptr_P), - ptr_V(ptr_V), - ptr_O(ptr_O), - ptr_Max(ptr_Max), - ptr_Sum(ptr_Sum), - block_P(block_P), - block_Norm(block_Norm), - block_Sum(block_Sum), - offset_P(offset_P), - offset_Norm_Device(offset_Norm_Device), - offset_Sum_Device(offset_Sum_Device), - ldq(ldq), - ldk(ldk), - ldp(ldp), - ldv(ldv), - ldo(ldo), - alpha0(alpha0), - alpha1(alpha1), - beta(beta), - head_number(head_number), - batch_size(batch_size), - seq_length(seq_length), - problem_sizes0_host(problem_sizes0_host), - problem_sizes1_host(problem_sizes1_host), - problem_sizes0_real(problem_sizes0_real), - reduction( - problem_sizes0, - block_Norm, - block_Sum, - offset_Norm_Device, - offset_Sum_Device - ) - { - - } - - - }; - - struct Params { - cutlass::gemm::GemmCoord *problem_sizes0; - cutlass::gemm::GemmCoord *problem_sizes0_real; - cutlass::gemm::GemmCoord *problem_sizes1; - int problem_count; - int threadblock_count; - - ElementQ ** ptr_Q; - ElementK ** ptr_K; - ElementP ** ptr_P; - ElementP ** ptr_V; - ElementP ** ptr_O; - - ElementNorm **ptr_Max; - ElementSum **ptr_Sum; - - ElementP *block_P; - ElementNorm *block_Norm; - ElementSum *block_Sum; - int64_t *offset_P; - int64_t *offset_Norm_Device; - int64_t *offset_Sum_Device; - - typename LayoutQ::Stride::LongIndex *ldq; - typename LayoutK::Stride::LongIndex *ldk; - typename LayoutP::Stride::LongIndex *ldp; - typename LayoutP::Stride::LongIndex *ldv; - typename LayoutP::Stride::LongIndex *ldo; - - cutlass::gemm::GemmCoord *problem_sizes0_host; - cutlass::gemm::GemmCoord *problem_sizes1_host; - - ElementAccumulator alpha0; - ElementAccumulator alpha1; - ElementAccumulator beta; - - int head_number; - int batch_size; - int seq_length; - - typename ApplyFinalReductionDevice::Params reduction; - - Params(): - problem_count(0), - threadblock_count(0), - ptr_Q(nullptr), - ptr_K(nullptr), - ptr_P(nullptr), - ptr_V(nullptr), - ptr_O(nullptr), - ptr_Max(nullptr), - ptr_Sum(nullptr), - block_P(nullptr), - block_Norm(nullptr), - block_Sum(nullptr), - offset_P(nullptr), - offset_Norm_Device(nullptr), - offset_Sum_Device(nullptr), - 
ldq(nullptr), - ldk(nullptr), - ldp(nullptr), - ldv(nullptr), - ldo(nullptr), - problem_sizes0(nullptr), - problem_sizes1(nullptr), - problem_sizes0_real(nullptr), - head_number(0), - batch_size(0), - seq_length(0) - { - - } - - Params(Arguments const &args, void *workspace = nullptr): - problem_sizes0(args.problem_sizes0), - problem_sizes1(args.problem_sizes1), - problem_count(args.problem_count), - threadblock_count(args.threadblock_count), - ptr_Q(args.ptr_Q), - ptr_K(args.ptr_K), - ptr_P(args.ptr_P), - ptr_V(args.ptr_V), - ptr_O(args.ptr_O), - ptr_Max(args.ptr_Max), - ptr_Sum(args.ptr_Sum), - block_P(args.block_P), - block_Norm(args.block_Norm), - block_Sum(args.block_Sum), - offset_P(args.offset_P), - offset_Norm_Device(args.offset_Norm_Device), - offset_Sum_Device(args.offset_Sum_Device), - ldq(args.ldq), - ldk(args.ldk), - ldp(args.ldp), - ldv(args.ldv), - ldo(args.ldo), - problem_sizes0_host(args.problem_sizes0_host), - problem_sizes1_host(args.problem_sizes1_host), - problem_sizes0_real(args.problem_sizes0_real), - alpha0(args.alpha0), - alpha1(args.alpha1), - beta(args.beta), - head_number(args.head_number), - batch_size(args.batch_size), - seq_length(args.seq_length), - reduction(args.reduction) - { - - } - }; - - -private: - - Params params_; - GemmGrouped0 gemm_grouped0; - GemmGrouped1 gemm_grouped1; - - -public: - - /// Ctor - FusedMultiHeadAttention() { - - } - - /// Initialize - Status initialize(Arguments const &args, - void *workspace0 = nullptr, - void *workspace1 = nullptr) { - - params_ = Params(args); - - typename GemmGrouped0::Arguments args_gemm0( - params_.problem_sizes0, - params_.problem_count, - params_.threadblock_count, - params_.ptr_Q, - params_.ptr_K, - params_.ptr_P, - params_.ptr_P, - params_.ptr_Max, - params_.ptr_Sum, - params_.ldq, - params_.ldk, - params_.ldp, - params_.ldp, - typename GemmGrouped0::GemmKernel::EpilogueVisitor::Arguments( - { - params_.alpha0, - params_.beta - } - ), - params_.problem_sizes0_host, - params_.problem_sizes0_real - ); - - - Status result0 = gemm_grouped0.initialize(args_gemm0, workspace0); - - typename EpilogueOutputOp1::Params epilogue_op1(params_.alpha1, params_.beta); - - typename GemmGrouped1::Arguments args_gemm1( - params_.problem_sizes1, - params_.problem_count, - params_.threadblock_count, - epilogue_op1, - params_.ptr_P, - params_.ptr_V, - params_.ptr_O, - params_.ptr_O, - (void**)params_.ptr_Max, - (void**)params_.ptr_Sum, - params_.ldp, - params_.ldv, - params_.ldo, - params_.ldo, - params_.problem_sizes1_host - ); - - Status result1 = gemm_grouped1.initialize(args_gemm1, workspace1); - - if ((result0 == cutlass::Status::kSuccess) && (result1 == cutlass::Status::kSuccess) ) { - return cutlass::Status::kSuccess; - }else{ - if (result0 != cutlass::Status::kSuccess) { - return result0; - }else{ - return result1; - } - } - } - - /// Run - Status run(cudaStream_t stream = nullptr) { - - Status result = gemm_grouped0.run(); - cudaError_t error_info; - - if (result != cutlass::Status::kSuccess) { - return cutlass::Status::kErrorInternal; - } - - int thread_per_block = 1024; - - dim3 final_reduction_grid(params_.head_number * params_.batch_size); - dim3 final_reduction_block(thread_per_block); - - cutlass::Kernel<<< - final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionDevice::SharedStorage), stream - >>>(params_.reduction); - - error_info = cudaGetLastError(); - - if (error_info != cudaSuccess) { - return cutlass::Status::kErrorInternal; - } - - result = gemm_grouped1.run(); - - if (result 
!= cutlass::Status::kSuccess) { - return cutlass::Status::kErrorInternal; - } - - return cutlass::Status::kSuccess; - } - - /// Function call operator - Status operator()(cudaStream_t stream = nullptr) { - return run(stream); - } -}; - -} diff --git a/examples/41_multi_head_attention/gemm_grouped_with_softmax_visitor.h b/examples/41_multi_head_attention/gemm_grouped_with_softmax_visitor.h deleted file mode 100644 index 755c1252..00000000 --- a/examples/41_multi_head_attention/gemm_grouped_with_softmax_visitor.h +++ /dev/null @@ -1,522 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Grouped GEMM kernel with epilogue visitor customized for softmax -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/complex.h" -#include "cutlass/semaphore.h" -#include "cutlass/util/device_memory.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/trace.h" -#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_, ///! Threadblock swizzling function - GroupScheduleMode GroupScheduleMode_, ///! 
Type of scheduling to perform - bool Transposed_ = false, - bool UseMask_ = false -> -struct GemmGroupedWithEpilogueVistor { -public: - - using Mma = Mma_; - using Epilogue = Epilogue_; - using ThreadblockSwizzle = ThreadblockSwizzle_; - - static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; - - using EpilogueVisitor = typename Epilogue::Visitor; - using EpilogueOutputOp = typename EpilogueVisitor::ElementwiseFunctor; - static bool const kTransposed = Transposed_; - - // Optional transpose - using MapArguments = kernel::detail::MapArguments< - typename Mma::IteratorA::Element, - typename Mma::IteratorA::Layout, - Mma::kTransformA, - Mma::IteratorA::AccessType::kElements, - typename Mma::IteratorB::Element, - typename Mma::IteratorB::Layout, - Mma::kTransformB, - Mma::IteratorB::AccessType::kElements, - typename Mma::LayoutC, - kTransposed - >; - - // Public-facing type definitions related to operand element type, layout, and complex conjugate - // operation. Must interact with the 'kTransposed' notion. - using ElementA = typename MapArguments::ElementA; - using LayoutA = typename MapArguments::LayoutA; - using ElementB = typename MapArguments::ElementB; - using LayoutB = typename MapArguments::LayoutB; - using ElementC = typename EpilogueVisitor::ElementOutput; - using LayoutC = typename MapArguments::LayoutC; - - using ElementNorm = typename EpilogueVisitor::ElementNorm; - using ElementSum = typename EpilogueVisitor::ElementSum; - - static ComplexTransform const kTransformA = MapArguments::kTransformA; - static ComplexTransform const kTransformB = MapArguments::kTransformB; - - // Type definitions about the mainloop. - using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = MapArguments::kAlignmentA; - static int const kAlignmentB = MapArguments::kAlignmentB; - static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - using ProblemVisitor = GemmGroupedProblemVisitor< - ThreadblockShape, - kGroupScheduleMode, - kThreadCount, - kThreadCount, - kTransposed>; - - // - // Structures - // - - /// Argument structure - struct Arguments { - - // - // Data members - // - - GemmCoord *problem_sizes; - // when using mask, real problem sizes may not be aligned - // then we need to mask out unpadded elements in softmax - GemmCoord *problem_sizes_real; - int problem_count; - int threadblock_count; - - ElementA ** ptr_A; - ElementB ** ptr_B; - ElementC ** ptr_C; - ElementC ** ptr_D; - - ElementNorm **ptr_Max; - ElementSum **ptr_Sum; - - typename LayoutA::Stride::LongIndex *lda; - typename LayoutB::Stride::LongIndex *ldb; - typename LayoutC::Stride::LongIndex *ldc; - typename LayoutC::Stride::LongIndex *ldd; - - typename EpilogueVisitor::Arguments epilogue_visitor; - - // Only used by device-level operator - GemmCoord *host_problem_sizes; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments(): - problem_count(0), - threadblock_count(0), - ptr_A(nullptr), - ptr_B(nullptr), - ptr_C(nullptr), - ptr_D(nullptr), - ptr_Max(nullptr), - ptr_Sum(nullptr), - lda(nullptr), - 
ldb(nullptr), - ldc(nullptr), - ldd(nullptr), - host_problem_sizes(nullptr) - { - - } - - /// Ctor - CUTLASS_HOST_DEVICE - Arguments( - GemmCoord *problem_sizes, - int problem_count, - int threadblock_count, - ElementA ** ptr_A, - ElementB ** ptr_B, - ElementC ** ptr_C, - ElementC ** ptr_D, - ElementNorm **ptr_Max, - ElementSum **ptr_Sum, - typename LayoutA::Stride::LongIndex *lda, - typename LayoutB::Stride::LongIndex *ldb, - typename LayoutC::Stride::LongIndex *ldc, - typename LayoutC::Stride::LongIndex *ldd, - typename EpilogueVisitor::Arguments epilogue_visitor_, - GemmCoord *host_problem_sizes=nullptr, - GemmCoord *problem_sizes_real=nullptr - ): - problem_sizes(problem_sizes), - problem_count(problem_count), - threadblock_count(threadblock_count), - ptr_A(ptr_A), - ptr_B(ptr_B), - ptr_C(ptr_C), - ptr_D(ptr_D), - ptr_Max(ptr_Max), - ptr_Sum(ptr_Sum), - lda(lda), - ldb(ldb), - ldc(ldc), - ldd(ldd), - epilogue_visitor(epilogue_visitor_), - host_problem_sizes(host_problem_sizes), - problem_sizes_real(problem_sizes_real) - { - - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params { - - typename ProblemVisitor::Params problem_visitor; - GemmCoord *problem_sizes_real; - int threadblock_count; - - ElementA ** ptr_A; - ElementB ** ptr_B; - ElementC ** ptr_C; - ElementC ** ptr_D; - - ElementNorm **ptr_Max; - ElementSum **ptr_Sum; - - typename LayoutA::Stride::LongIndex *lda; - typename LayoutB::Stride::LongIndex *ldb; - typename LayoutC::Stride::LongIndex *ldc; - typename LayoutC::Stride::LongIndex *ldd; - - typename EpilogueVisitor::Params epilogue_visitor; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params(): - ptr_A(nullptr), - ptr_B(nullptr), - ptr_C(nullptr), - ptr_D(nullptr), - ptr_Max(nullptr), - ptr_Sum(nullptr), - lda(nullptr), - ldb(nullptr), - ldc(nullptr), - ldd(nullptr), - problem_sizes_real(problem_sizes_real) - { } - - CUTLASS_HOST_DEVICE - Params(Arguments const &args, void *workspace = nullptr, int32_t tile_count = 0): - problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count), - threadblock_count(args.threadblock_count), - ptr_A(args.ptr_A), - ptr_B(args.ptr_B), - ptr_C(args.ptr_C), - ptr_D(args.ptr_D), - ptr_Max(args.ptr_Max), - ptr_Sum(args.ptr_Sum), - lda(args.lda), - ldb(args.ldb), - ldc(args.ldc), - ldd(args.ldd), - epilogue_visitor(args.epilogue_visitor), - problem_sizes_real(args.problem_sizes_real) - { - - } - - CUTLASS_HOST_DEVICE - void update( - Arguments const &args, - void *workspace = nullptr, - int32_t tile_count = -1) { - - problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count); - threadblock_count = args.threadblock_count; - ptr_A = args.ptr_A; - ptr_B = args.ptr_B; - ptr_C = args.ptr_C; - ptr_D = args.ptr_D; - ptr_Max = args.ptr_Max; - ptr_Sum = args.ptr_Sum; - lda = args.lda; - ldb = args.ldb; - ldc = args.ldc; - ldd = args.ldd; - problem_sizes_real = args.problem_sizes_real; - } - }; - - /// Shared memory storage structure - struct SharedStorage { - union { - typename Mma::SharedStorage main_loop; - struct { - typename Epilogue::SharedStorage epilogue; - typename EpilogueVisitor::SharedStorage visitor; - } epilogue; - } kernel; - - // ProblemVisitor shared storage can't be overlapped with others - typename ProblemVisitor::SharedStorage problem_visitor; - }; - - -public: - - // - // Methods - // - - CUTLASS_DEVICE - GemmGroupedWithEpilogueVistor() { } - - /// Determines whether kernel 
satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { - return Status::kSuccess; - } - - static Status can_implement(Arguments const &args) { - return Status::kSuccess; - } - - static size_t get_extra_workspace_size( - Arguments const &args, - cutlass::gemm::GemmCoord const &grid_tiled_shape) { - - return 0; - } - - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - // - // These types shadow the type-level definitions and support the ability to implement - // a 'transposed' GEMM that computes the transposed problems. - // - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename EpilogueVisitor::ElementOutput; - using LayoutC = typename Mma::LayoutC; - - // - // Problem visitor. - // - ProblemVisitor problem_visitor( - params.problem_visitor, - shared_storage.problem_visitor, - blockIdx.x); - - // Outer 'persistent' loop to iterate over tiles - while (problem_visitor.next_tile()) { - - GemmCoord problem_size = problem_visitor.problem_size(); - int32_t problem_idx = problem_visitor.problem_index(); - int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); - - GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); - - cutlass::gemm::GemmCoord threadblock_offset( - int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, - int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, - 0); - - // Load element pointers. Exchange pointers and strides if working on the transpose - ElementA *ptr_A = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); - typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); - - ElementB *ptr_B = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); - typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_offset.m(), - 0, - }; - - cutlass::MatrixCoord tb_offset_B{ - 0, - threadblock_offset.n() - }; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - LayoutA(ldm_A), - ptr_A, - {problem_size.m(), problem_size.k()}, - thread_idx, - tb_offset_A); - - typename Mma::IteratorB iterator_B( - LayoutB(ldm_B), - ptr_B, - {problem_size.k(), problem_size.n()}, - thread_idx, - tb_offset_B); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - - int lane_idx = threadIdx.x % 32; - - // - // Matrix multiply phase - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Wait for all threads to finish their epilogue phases from the previous tile. 
- __syncthreads(); - - // Compute threadblock-scoped matrix multiply-add - mma( - gemm_k_iterations, - accumulators, - iterator_A, - iterator_B, - accumulators); - - ElementC *ptr_C = params.ptr_C[problem_idx]; - ElementC *ptr_D = params.ptr_D[problem_idx]; - - ElementNorm *ptr_Max = params.ptr_Max[problem_idx]; - ElementSum *ptr_Sum = params.ptr_Sum[problem_idx]; - - LayoutC layout_C(params.ldc[problem_idx]); - LayoutC layout_D(params.ldd[problem_idx]); - - int column_offset = (threadblock_offset.n() / ThreadblockShape::kN) * problem_size.m(); - - typename EpilogueVisitor::OutputTileIterator::Params params_C(layout_C); - typename EpilogueVisitor::OutputTileIterator::Params params_D(layout_D); - - // - // Construct the epilogue visitor - // - - EpilogueVisitor epilogue_visitor( - params.epilogue_visitor, - shared_storage.kernel.epilogue.visitor, - problem_size.mn(), - thread_idx, - warp_idx, - lane_idx, - params_C, - params_D, - ptr_C, - ptr_D, - ptr_Max, - ptr_Sum, - threadblock_offset.mn(), - column_offset, - params.problem_sizes_real[problem_idx].mn() - ); - - // Construct the epilogue - Epilogue epilogue( - shared_storage.kernel.epilogue.epilogue, - thread_idx, - warp_idx, - lane_idx); - - // Execute the epilogue operator to update the destination tensor - epilogue(epilogue_visitor, accumulators); - - // Next tile - problem_visitor.advance(gridDim.x); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/42_fused_multi_head_attention/CMakeLists.txt b/examples/42_ampere_tensorop_group_conv/CMakeLists.txt similarity index 96% rename from examples/42_fused_multi_head_attention/CMakeLists.txt rename to examples/42_ampere_tensorop_group_conv/CMakeLists.txt index c1c5c094..37bad683 100644 --- a/examples/42_fused_multi_head_attention/CMakeLists.txt +++ b/examples/42_ampere_tensorop_group_conv/CMakeLists.txt @@ -30,7 +30,7 @@ cutlass_example_add_executable( - 42_fused_multi_head_attention - fused_multihead_attention.cu + 42_ampere_tensorop_group_conv + ampere_tensorop_group_conv.cu ) diff --git a/examples/42_ampere_tensorop_group_conv/ampere_tensorop_group_conv.cu b/examples/42_ampere_tensorop_group_conv/ampere_tensorop_group_conv.cu new file mode 100644 index 00000000..2134642c --- /dev/null +++ b/examples/42_ampere_tensorop_group_conv/ampere_tensorop_group_conv.cu @@ -0,0 +1,706 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** +This example shows how to run group convolution kernels using functions and data structures +provided by CUTLASS using tensor cores; which we run on a NVIDIA Ampere GPU. + +There are 2 group conv mode: + 1. cutlass::conv::GroupMode::kSingleGroup + This mode is for large K problem size: k_per_group (K/groups) equals or larger than + threadblock_tile_N. One or multiple threadblocks calculate data of one group. + 2. cutlass::conv::GroupMode::kMultipleGroup + This mode is for small K problem size: k_per_group (K/groups) is smaller than threadblock_tile_N. + One threadblock will calculate data from more than one group. + +Function profile_convolution_selecter() shows how to choose kernel with different group mode according +to problem size and threadblock_tile size. +*/ + +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/conv/kernel/default_conv2d_group_fprop.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/tensor_view_io.h" + +#include "helper.h" + +// The code section below describes datatype for input, output tensors and computation between +// elements +using ElementAccumulator = float; // Data type of accumulator +using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) +using ElementInputA = cutlass::half_t; // Data type of elements in input tensor +using ElementInputB = cutlass::half_t; // Data type of elements in input tensor +using ElementOutput = float; // Data type of elements in output tensor + +using LayoutInputA = cutlass::layout::TensorNHWC; +using LayoutInputB = cutlass::layout::TensorNHWC; +using LayoutOutput = cutlass::layout::TensorNHWC; + +// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM +using MMAOp = cutlass::arch::OpClassTensorOp; + +// This code section describes CUDA SM architecture number +using SmArch = cutlass::arch::Sm80; + +// This code section describes the tile size a thread block will compute +using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; // Threadblock tile shape + +// This code section describes tile size a warp will compute +using 
WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; // Warp tile shape + +// This code section describes the size of MMA op +using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape + +// This code section describes how threadblocks are scheduled on GPU +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + +// Number of pipelines you want to use +constexpr int NumStages = 3; + +// This code section describes the epilogue part of the kernel, we use default value +using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // Data type of output matrix. + 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. + // memory access. This becomes the vector width of + // math instructions in the epilogue too. + ElementAccumulator, // Data type of accumulator + ElementComputeEpilogue>; // Data type for alpha/beta in linear combination + +// Analytic kernel and operation for single group problem size +using AnalyticSingleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< + ElementInputA, LayoutInputA, + ElementInputB, LayoutInputB, + ElementOutput, LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::GroupMode::kSingleGroup, + cutlass::conv::IteratorAlgorithm::kAnalytic +>::Kernel; +using AnalyticSingleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution; + +// Analytic kernel and operation for multiple group problem size +using AnalyticMultipleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< + ElementInputA, LayoutInputA, + ElementInputB, LayoutInputB, + ElementOutput, LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::GroupMode::kMultipleGroup, + cutlass::conv::IteratorAlgorithm::kAnalytic +>::Kernel; +using AnalyticMultipleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution; + +// Optimized kernel and operation for single group problem size +using OptimizedSingleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< + ElementInputA, LayoutInputA, + ElementInputB, LayoutInputB, + ElementOutput, LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::GroupMode::kSingleGroup, + cutlass::conv::IteratorAlgorithm::kOptimized +>::Kernel; +using OptimizedSingleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + cutlass::Tensor4DCoord input_size; + cutlass::Tensor4DCoord filter_size; + cutlass::Tensor4DCoord padding; + cutlass::MatrixCoord conv_stride; + cutlass::MatrixCoord dilation; + int groups; + bool reference_check; + bool measure_performance; + int iterations; + ElementComputeEpilogue alpha; + ElementComputeEpilogue beta; + bool optimized; + std::string tag; + + Options(): + help(false), + input_size(1, 32, 32, 32), + filter_size(32, 3, 3, 32), + padding(1, 1, 1, 1), + conv_stride(1, 1), + dilation(1, 1), + groups(1), + reference_check(false), + measure_performance(false), 
+ iterations(20), + alpha(1), + beta(0), + optimized(false) { } + + // Verify the problem size is compatible with the CUTLASS Convolution implementation. + bool valid() { + + // + // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, + // all pointers, strides, and tensor extents must be divisible by 8 elements. + // + int const kAlignment = 8; + + if ((input_size.c() % kAlignment) || + (filter_size.n() % kAlignment)) { + + // misaligned tensors + return false; + } + + // Invalid padding + if ((padding.h() != filter_size.h() / 2) || + (padding.w() != filter_size.w() / 2)) { + + return false; + } + + return true; + } + + /// Updates input and filter sizes + void update( + cutlass::Tensor4DCoord input_size, + cutlass::Tensor4DCoord filter_size) { + + this->input_size = input_size; + this->filter_size = filter_size; + + padding.n() = filter_size.h() / 2; + padding.h() = filter_size.h() / 2; + padding.w() = filter_size.w() / 2; + padding.c() = filter_size.w() / 2; + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + if (cmd.check_cmd_line_flag("ref-check")) { + reference_check = true; + } + + if (cmd.check_cmd_line_flag("perf-check")) { + measure_performance = true; + } + + if (cmd.check_cmd_line_flag("optimized")) { + optimized = true; + } + + cmd.get_cmd_line_argument("n", input_size.n()); + cmd.get_cmd_line_argument("h", input_size.h()); + cmd.get_cmd_line_argument("w", input_size.w()); + cmd.get_cmd_line_argument("c", input_size.c()); + + cmd.get_cmd_line_argument("k", filter_size.n()); + cmd.get_cmd_line_argument("r", filter_size.h()); + cmd.get_cmd_line_argument("s", filter_size.w()); + + cmd.get_cmd_line_argument("g", groups); + filter_size.c() = input_size.c() / groups; + + cmd.get_cmd_line_argument("u", conv_stride.row()); + cmd.get_cmd_line_argument("v", conv_stride.column()); + + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("tag", tag); + + if (filter_size.h() == 3 && filter_size.w() == 3) { + padding = {1, 1, 1, 1}; + } + else { + filter_size.h() = 1; + filter_size.w() = 1; + padding = {0, 0, 0, 0}; + } + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "42_ampere_tensorop_group_conv example\n\n" + << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" + << " forward grouped convolution on tensors of layout NHWC.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --n= Input tensor extent N\n" + << " --h= Input tensor extent H\n" + << " --w= Input tensor extent W\n" + << " --c= Input tensor extent C\n" + << " --k= Filter extent K\n" + << " --r= Filter extent R\n" + << " --s= Filter extent S\n\n" + << " --g= Conv groups G\n\n" + << " --u= Conv stride_h\n\n" + << " --v= Conv stride_w\n\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --ref-check If set (true), reference check is computed\n" + << " --perf-check If set (true), performance is measured.\n" + << " --optimized If set (true), use optimized kernel, otherwise use analytic kernel.\n" + << " --iterations= Number of profiling iterations to perform.\n" + << " --tag= String to replicate across the first column in the results table\n"; + + out << "\n\nExamples:\n\n" + << "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=8 --ref-check\n\n" + << "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=2 --ref-check\n\n" + << "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=2 --ref-check --optimized\n\n"; + + return out; + } + + /// Computes the output tensor size (NPQK) + cutlass::Tensor4DCoord output_size() const { + return cutlass::Tensor4DCoord( + input_size.n(), + (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, + (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, + filter_size.n()); + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of multiply-adds = NPQK * CRS + int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct Result { + double runtime_ms; + double gflops; + cutlass::Status status; + cutlass::Status reference_check; + cudaError_t error; + + Result(): + runtime_ms(0), + gflops(0), + status(cutlass::Status::kSuccess), + reference_check(cutlass::Status::kInvalid), + error(cudaSuccess) { } + + static std::ostream & print_header(std::ostream &out, Options const &options) { + + if (!options.tag.empty()) { + out << "Name,"; + } + + out << "Layer,N,H,W,C,K,R,S,G,Runtime,GFLOPs"; + + return out; + } + + std::ostream & print(std::ostream &out, int idx, Options const &options) { + + if (!options.tag.empty()) { + out << options.tag << ","; + } + + out + << "conv_" << idx << "," + << options.input_size.n() << "," + << options.input_size.h() << "," + << options.input_size.w() << "," + << options.input_size.c() << "," + << options.filter_size.n() << "," + << options.filter_size.h() << "," + << options.filter_size.w() << "," + << options.groups << "," + << runtime_ms << "," + << gflops; + + return out; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 
Runs one benchmark +template +Result profile_convolution(Options const &options) { + + Result result; + + // + // Allocate host-device tensors using the CUTLASS Utilities. + // + + cutlass::HostTensor tensor_a(options.input_size); + cutlass::HostTensor tensor_b(options.filter_size); + cutlass::HostTensor tensor_c(options.output_size()); + cutlass::HostTensor tensor_d(options.output_size()); + cutlass::HostTensor tensor_ref_d(options.output_size()); + + // + // Initialize tensors + // + + // Fill tensor A on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), + 1, + ElementInputA(7), + ElementInputA(-8), + 0); + + // Fill tensor B on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_b.host_view(), + 1, + ElementInputB(7), + ElementInputB(-8), + 0); + + // Fill tensor C on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(7), + ElementOutput(-8), + 0); + + // Fill tensor D on host with zeros + cutlass::reference::host::TensorFill( + tensor_d.host_view()); + + // Fill tensor D for reference on host with zeros + cutlass::reference::host::TensorFill( + tensor_ref_d.host_view()); + + // Copy data from host to GPU + tensor_a.sync_device(); + tensor_b.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + tensor_ref_d.sync_device(); + + // + // Define arguments for CUTLASS Convolution + // + + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; + + // Split K dimension into 1 partitions + int split_k_slices = 1; + + // Construct Conv2dProblemSize with user defined output size + cutlass::conv::Conv2dProblemSize problem_size( + options.input_size, + options.filter_size, + options.padding, + options.conv_stride, + options.dilation, + options.output_size(), + mode, + split_k_slices, + options.groups + ); + + // Construct Conv2dOperation::Argument structure with conv2d + // problem size, data pointers, and epilogue values + typename Conv2dOperation::Arguments arguments{ + problem_size, + tensor_a.device_ref(), + tensor_b.device_ref(), + tensor_c.device_ref(), + tensor_d.device_ref(), + {options.alpha, options.beta}, + }; + + // + // Initialize CUTLASS Convolution + // + + Conv2dOperation implicit_gemm_op; + + size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + result.status = implicit_gemm_op.can_implement(arguments); + CUTLASS_CHECK(result.status); + + result.status = implicit_gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(result.status); + + // + // Launch initialized CUTLASS kernel + // + result.status = implicit_gemm_op(); + + CUTLASS_CHECK(result.status); + + // + // Optional reference check + // + + if (options.reference_check) { + std::cout << "Verification on device...\n"; + + // Compute with reference implementation + cutlass::reference::device::Conv2dFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementComputeEpilogue, + ElementAccumulator, + cutlass::NumericConverter + >( + problem_size, + tensor_a.device_ref(), + tensor_b.device_ref(), + tensor_c.device_ref(), + tensor_ref_d.device_ref(), + options.alpha, + options.beta + ); + + tensor_ref_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + 
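+ // Both the CUTLASS result (tensor_d) and the device reference result (tensor_ref_d) are
+ // copied back to the host so that TensorEquals can compare them element by element; any
+ // mismatch is reported as a reference-check failure below.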
tensor_d.sync_host(); + + bool passed = cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); + + if (!passed) { + result.reference_check = cutlass::Status::kErrorInternal; + std::cout << "ERROR - results miscompared.\n"; + } else { + result.reference_check = cutlass::Status::kSuccess; + std::cout << "Passed.\n"; + } + } else { + result.reference_check = cutlass::Status::kInvalid; + } + + // + // Performance measurement + // + + if (options.measure_performance) { + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + } + + // Record an event at the start of a series of convolution operations. + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Launch a sequence of implicit GEMM operations on the device + for (int iteration = 0; iteration < options.iterations; ++iteration) { + result.status = implicit_gemm_op(); + CUTLASS_CHECK(result.status); + } + + // Record an event when the convolutions have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Print average runtime and GFLOPs. 
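+ // The two CUDA events bracket the loop of options.iterations kernel launches above, so the
+ // elapsed time divided by the iteration count gives the average per-launch runtime.
+ // Options::gflops() then converts that to throughput using 2 * N*P*Q*K * R*S*(C/G)
+ // floating-point operations; filter_size.c() already holds C / groups after command-line
+ // parsing, so the reported GFLOPs account for grouped convolution.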
+ result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // Cleanup + for (auto event : events) { + (void)cudaEventDestroy(event); + } + } + + return result; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +Result profile_convolution_selecter(Options const &options) { + int k_per_group = options.filter_size.n() / options.groups; + + // In group conv, if k_per_group < threadblock_N, one Threadblock will calculate multiple groups + if (k_per_group < ThreadblockShape::kN) { // MultipleGroup mode + if (options.optimized) { + std::cerr << "Invalid problem: optimized group conv kernel doesn't support MultipleGroup (one CTA calculate multiple groups) mode" << std::endl; + exit(-1); + } else { + std::cout << "Select AnalyticMultipleGroupOperation\n"; + return profile_convolution(options); + } + } else { // SingleGroup mode + if (options.optimized) { + std::cout << "Select OptimizedSingleGroupOperation\n"; + return profile_convolution(options); + } else { + std::cout << "Select AnalyticSingleGroupOperation\n"; + return profile_convolution(options); + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + bool notSupported = false; + + // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. + if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; + notSupported = true; + } + + cudaDeviceProp props; + CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); + + if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." + << std::endl; + notSupported = true; + } + + if (notSupported) { + return 0; + } + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // Execute one problem size + if (!options.valid()) { + std::cerr << "Invalid problem." 
<< std::endl; + return -1; + } + + Result result = profile_convolution_selecter(options); + + Result::print_header(std::cout, options) << std::endl; + result.print(std::cout, 1, options) << std::endl; + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/42_fused_multi_head_attention/iterators/make_residual_last.h b/examples/42_fused_multi_head_attention/iterators/make_residual_last.h deleted file mode 100644 index 18b55100..00000000 --- a/examples/42_fused_multi_head_attention/iterators/make_residual_last.h +++ /dev/null @@ -1,66 +0,0 @@ -#pragma once - -#include "predicated_tile_access_iterator_residual_last.h" -#include "predicated_tile_iterator_residual_last.h" - -namespace cutlass { -namespace transform { -namespace threadblock { - -template -struct MakeIteratorResidualLast; - -template < - typename Shape, - typename Element, - typename Layout, - int AdvanceRank, - typename ThreadMap, - int AccessSize, - bool Gather> -struct MakeIteratorResidualLast> { - using Iterator = PredicatedTileIteratorResidualLast< - Shape, - Element, - Layout, - AdvanceRank, - ThreadMap, - AccessSize, - Gather>; -}; - -template < - typename Shape, - typename Element, - typename Layout, - int AdvanceRank, - typename ThreadMap, - typename AccessType, - bool Gather> -struct MakeIteratorResidualLast> { - using Iterator = PredicatedTileAccessIteratorResidualLast< - Shape, - Element, - Layout, - AdvanceRank, - ThreadMap, - AccessType, - Gather>; -}; -} // namespace threadblock -} // namespace transform -} // namespace cutlass \ No newline at end of file diff --git a/examples/41_multi_head_attention/CMakeLists.txt b/examples/43_ell_block_sparse_gemm/CMakeLists.txt similarity index 96% rename from examples/41_multi_head_attention/CMakeLists.txt rename to examples/43_ell_block_sparse_gemm/CMakeLists.txt index 442048f6..78fbdb80 100644 --- a/examples/41_multi_head_attention/CMakeLists.txt +++ b/examples/43_ell_block_sparse_gemm/CMakeLists.txt @@ -1,4 +1,3 @@ - # Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # @@ -28,9 +27,8 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - cutlass_example_add_executable( - 41_multi_head_attention - fused_multihead_attention.cu + 43_ell_block_sparse_gemm + ell_block_sparse_gemm.cu ) diff --git a/examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu b/examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu new file mode 100644 index 00000000..7ea43895 --- /dev/null +++ b/examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu @@ -0,0 +1,740 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+ \brief Blocked-ELL sparse GEMM example.
+
+ This example performs a sparse-matrix dense-matrix multiplication (SpMM) operation.
+ Matrix A is stored in the Blocked-Ellpack (Blocked-ELL) storage format, while matrix B is a
+ dense matrix. Details about the Blocked-ELL storage format can be found here:
+ https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-spmat-create-blockedell
+
+ The Blocked-Ellpack (Blocked-ELL) storage format comprises two matrices. The first is a packed
+ matrix (the ellValue matrix), represented by tensor_a in this example, that stores the non-zero
+ values in consecutive blocks. The second is a matrix of indices (the ellColInd matrix),
+ represented by tensor_ell_idx in this example, that stores the column indices of the
+ corresponding non-zero blocks. All block rows must have the same number of blocks; ellColInd
+ may contain -1 to mark an empty block. Both matrices store their elements in row-major order.
+ For example, a 4x8 sparse matrix with a_ell_blocksize=2 and two non-zero blocks per block row
+ is stored as a 4x4 ellValue matrix plus a 2x2 ellColInd matrix of block-column indices.
+
+ Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format
+ for this example:
+ a_rows - Rows in the sparse matrix.
+ a_cols - Columns in the sparse matrix.
+ a_ell_blocksize - Size of the ELL-Blocks.
+ a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns) + tensor_a - ellValue matrix, whose size is (a_rows * a_ell_num_columns) + tensor_ell_idx - Blocked-ELL Column indices (ellColInd), whose size is + (a_rows / a_ell_blocksize) * (a_ell_num_columns / a_ell_blocksize) + tensor_b - Input dense matrix whose size is (a_cols * n) + tensor_c/tensor_d - Output dense matrix whose size is (a_rows * n) + {a_rows, n, a_cols} - Problem size + +*/ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/device/ell_gemm.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/host_uncompress.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result { + + double runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + // + // Methods + // + + Result( + double runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess + ): + runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool reference_check; + int iterations; + int cuda_streams; + int a_rows, n, a_cols; + int a_ell_num_columns; + int a_ell_blocksize; + int a_base; + float alpha; + float beta; + + // + // Methods + // + + Options(): + help(false), + reference_check(true), + iterations(20), + cuda_streams(0), + a_rows(1024), + n(1024), + a_cols(1024), + a_ell_num_columns(512), + a_ell_blocksize(16), + a_base(0), + alpha(1), + beta() + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("alpha", alpha, 1.0f); + cmd.get_cmd_line_argument("beta", beta, 0.0f); + cmd.get_cmd_line_argument("iterations", iterations, 20); + cmd.get_cmd_line_argument("streams", cuda_streams, 0); + cmd.get_cmd_line_argument("reference-check", reference_check, true); + + cmd.get_cmd_line_argument("a_rows", a_rows, 1024); + cmd.get_cmd_line_argument("n", n, 1024); + cmd.get_cmd_line_argument("a_cols", a_cols, 1024); + + cmd.get_cmd_line_argument("a_ell_num_columns", a_ell_num_columns, 512); + cmd.get_cmd_line_argument("a_ell_blocksize", a_ell_blocksize, 16); + cmd.get_cmd_line_argument("a_base", a_base, 0); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "43_ell_block_sparse_gemm\n\n" + << " This example profiles the performance of a ELL block sparse GEMM kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --a_rows= Sets the number of the rows of the sparse matrix.\n" + << " --n= Sets the N dimension.\n" + << " --a_cols= Sets the number of columns of the sparse matrix.\n" + << " --a_ell_num_columns= Sets the actual number of columns of the Blocked-Ellpack format.\n" + << " --a_ell_blocksize= Sets the size of the ELL-Block.\n" + << " --a_base= Sets the base index.\n" + << " --alpha= Epilogue scalar alpha (real part)\n" + << " --beta= Epilogue scalar beta (real part)\n\n" + << " --iterations= Number of profiling iterations to perform.\n" + << " --reference-check= If true, performs reference check.\n"; + + out << "\n\nExamples:\n\n" + + << "# Runs a 1024x1024x1024 ELL block sparse GEMM with 16x16 block size and actual 512 non-zero columns in A operand\n" + << "$ ./examples/43_ell_block_sparse_gemm/43_ell_block_sparse_gemm --a_rows=1024 --n=1024 --a_cols=1024 --a_ell_num_columns=512 --a_ell_blocksize=16\n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of real-valued multiply-adds + int64_t fmas = (int64_t)a_rows * (int64_t)a_cols * (int64_t)n; + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Testbed { +public: + + // + // Type definitions + // + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + + using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + + using MatrixCoord = typename LayoutC::TensorCoord; + +private: + + // + // Data members + // + + Options options; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_ELL; + uint32_t seed; + + cutlass::HostTensor tensor_a; + cutlass::HostTensor tensor_b; + cutlass::HostTensor tensor_c; + cutlass::HostTensor tensor_d; + + cutlass::HostTensor tensor_a_uncompressed; + cutlass::HostTensor reference_d; + + cutlass::HostTensor tensor_ell_idx; + +public: + + // + // Methods + // + + Testbed( + Options const &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_ELL_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), init_ELL(init_ELL_), seed(seed_) { } + +private: + + /// Helper to initialize a tensor view + template + void initialize_tensor_( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = 
cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian( + view, seed, Element(), Element(0.5f)); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + // Fill with increasing elements + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity(), Element(1), Element()); + } else { + + // Fill with all 1s + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity(), Element(), Element(1)); + } + } + + /// Initializes data structures + void initialize_() { + tensor_a.resize(cutlass::make_Coord(options.a_rows, options.a_ell_num_columns)); + tensor_b.resize(cutlass::make_Coord(options.a_cols, options.n)); + tensor_c.resize(cutlass::make_Coord(options.a_rows, options.n)); + tensor_d.resize(cutlass::make_Coord(options.a_rows, options.n)); + + tensor_a_uncompressed.resize(cutlass::make_Coord(options.a_rows, options.a_cols)); + reference_d.resize(cutlass::make_Coord(options.a_rows, options.n)); + + tensor_ell_idx.resize(cutlass::make_Coord(options.a_rows / options.a_ell_blocksize, + options.a_ell_num_columns / options.a_ell_blocksize)); + + // + // Initialize the problems of the workspace + // + + initialize_tensor_(tensor_a.host_view(), init_A, seed * 2021); + initialize_tensor_(tensor_b.host_view(), init_B, seed * 2022); + initialize_tensor_(tensor_c.host_view(), init_C, seed * 2023); + + if (init_ELL == cutlass::Distribution::Uniform) { + cutlass::reference::host::TensorFillRandomEllIdx( + tensor_ell_idx.host_view(), seed, + options.a_rows / options.a_ell_blocksize, + options.a_ell_num_columns / options.a_ell_blocksize, + options.a_cols / options.a_ell_blocksize); + + } else { + for(int i = 0; i < options.a_rows / options.a_ell_blocksize; ++i) { + for(int j = 0; j < options.a_ell_num_columns / options.a_ell_blocksize; ++j) { + tensor_ell_idx.at({i, j}) = j+3; + } + } + } + + tensor_a.sync_device(); + tensor_b.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + tensor_ell_idx.sync_device(); + } + + /// Verifies the result is a GEMM + bool verify_() { + + bool passed = true; + + tensor_d.sync_host(); + + cutlass::uncompress_ell_block_sparse( + tensor_a_uncompressed.host_ref(), + tensor_a.host_ref(), + tensor_ell_idx.host_ref(), + options.a_rows, + options.a_cols, + options.a_ell_num_columns, + options.a_ell_blocksize + ); + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + {options.a_rows, options.n, options.a_cols}, + options.alpha, + tensor_a_uncompressed.host_ref(), + tensor_b.host_ref(), + options.beta, + reference_d.host_ref(), + ElementAccumulator(0) + ); + + // Reference check + passed = cutlass::reference::host::TensorEquals(tensor_d.host_view(), reference_d.host_view()); + + if (!passed) { + std::cerr << "\n***\nError - problem failed the QA check\n***\n" << 
std::endl; + + std::stringstream fname; + + fname << "error_43_ell_block_sparse_gemm" + << "mnk_" + << options.a_rows << "x" + << options.n << "x" + << options.a_cols << "_" + << options.a_ell_num_columns << "_" + << options.a_ell_blocksize << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results + << "alpha: " << ElementCompute(options.alpha) << "\n" + << "beta: " << ElementCompute(options.beta) << "\n" + << "block size: " << options.a_ell_blocksize << "\n" + << "\nA:\n" << tensor_a.host_view() << "\n" + << "\nA Ell Index:\n" << tensor_ell_idx.host_view() << "\n" + << "\nB:\n" << tensor_b.host_view() << "\n" + << "\nC:\n" << tensor_c.host_view() << "\n" + << "\nD reference:\n" << reference_d.host_view() << "\n" + << "\nD computed:\n" << tensor_d.host_view() << "\n"; + + + return passed; + } + + return passed; + } + +public: + + /// Returns the number of threadblocks to launch if the kernel can run on the target + /// device. Otherwise, returns zero. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes a BlockedEll SpMM kernel and measures runtime. + Result profile() { + + Result result; + + // Early exit + if (!sufficient()) { + std::cout << "Active CUDA device lacks hardware resources to run CUTLASS BlockedEll SpMM kernel." << std::endl; + return result; + } + + result.passed = false; + + // Initialize the problem + initialize_(); + + // Configure the GEMM arguments + typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); + + // Configure GEMM arguments + typename Gemm::Arguments args( + {options.a_rows, options.n, options.a_cols}, + tensor_a.device_ref(), + tensor_b.device_ref(), + tensor_c.device_ref(), + tensor_d.device_ref(), + tensor_ell_idx.device_data(), + options.a_ell_num_columns, + options.a_ell_blocksize, + options.a_base, + epilogue_op + ); + + // Initialize the GEMM object + Gemm gemm; + + result.status = gemm.initialize(args); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize CUTLASS BlockedEll SpMM kernel." << std::endl; + return result; + } + + // Run the BlockedEll SpMM object + result.status = gemm.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS BlockedEll SpMM kernel." << std::endl; + return result; + } + + // Wait for completion + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // + // Verify correctness + // + result.passed = true; + + if (options.reference_check) { + result.passed = verify_(); + } + + // + // Warm-up run + // + result.status = gemm.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS BlockedEll SpMM kernel." 
<< std::endl; + return result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of GEMM operations + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < options.iterations; ++iter) { + gemm(); + } + + // + // Stop profiling loop + // + + // Record an event when the GEMM operations have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. + result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // + // Cleanup + // + + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + std::cout << std::endl; + std::cout << "ELL Block Sparse GEMM (CUTLASS):\n" + << "====================================================" << std::endl; + + std::cout << std::endl; + std::cout << " " << "Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " " << " GFLOPs: " << result.gflops << std::endl; + + return result; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // + // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. + // + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { + + // + // This example requires an NVIDIA Ampere-architecture GPU. 
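+ // (The check above requires both CUDA Toolkit 11.0 or newer at compile time, which first
+ // exposes Ampere's mma.sync instructions, and a device of compute capability 8.0 or newer
+ // at run time.)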
+ // + + std::cout + << "CUTLASS's BlockedEll SpMM example requires a GPU of NVIDIA's Ampere Architecture or " + << "later (compute capability 80 or greater).\n"; + + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Define the BlockedEll type + // + + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + constexpr int32_t kAlignmentA = 128 / cutlass::sizeof_bits::value; + constexpr int32_t kAlignmentB = 128 / cutlass::sizeof_bits::value; + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + constexpr int32_t kStages = 4; + using Gemm = typename cutlass::gemm::device::EllGemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementOutput, + LayoutC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + kStages, kAlignmentA, kAlignmentB>; + + // + // Profile it + // + + Testbed testbed(options); + + if (!testbed.sufficient()) { + std::cout << "The active CUDA device lacks sufficient hardware resources to execute this kernel.\n"; + return 0; + } + + Result result = testbed.profile(); + if (!result.passed) { + std::cout << "Profiling CUTLASS ELL block sparse GEMM has failed.\n"; + std::cout << "\nFailed\n"; + return -1; + } + + std::cout << "\nPassed\n"; + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/44_multi_gemm_ir_and_codegen/README.md b/examples/44_multi_gemm_ir_and_codegen/README.md new file mode 100644 index 00000000..6dc87312 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/README.md @@ -0,0 +1,63 @@ +This example provides utilities for generating back-to-back (B2B) GEMMs using CUTLASS. + +## Quick start +A configuration file containing the GEMMs to be fused together is located in [config.json](config.json). Edit +this to change the configuration that you would like to run. +```shell +cd ir_gen + +# Set up basic variables +out_dir=directory_to_emit_files +cutlass_dir=$(pwd)/../../.. +config_file=$(pwd)/../config.json + +# Generate code for GEMMs described in `config_file` +./generate.sh $config_file $out_dir $cutlass_dir + +# Build the generated code +cd $out_dir +mkdir build && cd build +cmake .. -DGPU_ARCHS="75;80" +make -j + +# Run the generated code with M=1024 K0=32 and Batch=1 +./sample 1024 32 1 +``` + +## Current restrictions +This experimental example has the following restrictions: +1. N tile should not exceed 256, or register spilling will occur. +2. Only FP16 is supported currently +3. Matrix A must be row major, matrix B must be column major, matrices C and D must be row major. + +## Copyright + +Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/examples/44_multi_gemm_ir_and_codegen/config.json b/examples/44_multi_gemm_ir_and_codegen/config.json new file mode 100644 index 00000000..bb8757c0 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/config.json @@ -0,0 +1,32 @@ +{ + "0": { + "A_tp": "fp16", "B_tp": "fp16", "C_tp": "fp16", "Acc_tp": "fp16", + "A_format": "Row", "B_format": "Col", "C_format": "Row", + "mnk": [15000, 256, 32], + "epilogue": { + "tp": "LeakyRelu", + "bias": {"addbias": false, "bias_tp": "mat"}, + "args": [["float", "leaky_alpha", 1.3]] + } + }, + "1": { + "A_tp": "fp16", "B_tp": "fp16", "C_tp": "fp16", "Acc_tp": "fp16", + "A_format": "Row", "B_format": "Col", "C_format": "Row", + "mnk": [15000, 128, 256], + "epilogue": { + "tp": "LeakyRelu", + "bias": {"addbias": false, "bias_tp": "mat"}, + "args": [["float", "leaky_alpha", 1.3]] + } + }, + "2": { + "A_tp": "fp16", "B_tp": "fp16", "C_tp": "fp16", "Acc_tp": "fp16", + "A_format": "Row", "B_format": "Col", "C_format": "Row", + "mnk": [15000, 64, 128], + "epilogue": { + "tp": "LeakyRelu", + "bias": {"addbias": false, "bias_tp": "mat"}, + "args": [["float", "leaky_alpha", 1.3]] + } + } +} diff --git a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h new file mode 100644 index 00000000..95eabb69 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h @@ -0,0 +1,154 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" +#include "cutlass/epilogue/thread/conversion_op.h" +#include "cutlass/epilogue/thread/reduction_op.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" + +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" +#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" + +// #include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" + +#include "fused_bias_act_epilogue.h" +#include "../warp/fused_bias_act_fragment_iterator_tensor_op.h" +#include "output_tile_thread_map_for_fused_bias.h" +#include "default_thread_map_tensor_op_for_fused_bias.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for TensorOps. 
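+///
+/// An illustrative (hypothetical) instantiation, assuming a warp-level tensor-op MMA type
+/// `WarpMma` is already defined elsewhere:
+///
+///   using FusedEpilogue = typename cutlass::epilogue::threadblock::DefaultFusedBiasActEpilogueTensorOp<
+///       cutlass::gemm::GemmShape<128, 128, 32>,                            // threadblock tile
+///       WarpMma,                                                           // warp-level MMA (tensor op)
+///       1,                                                                 // partitions along K
+///       cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8>,  // elementwise output op
+///       8                                                                  // elements per vectorized access
+///   >::Epilogue;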
+template < + typename Shape_, + typename WarpMmaTensorOp_, + int PartitionsK, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultFusedBiasActEpilogueTensorOp { + + using Shape = Shape_; + using WarpMmaTensorOp = WarpMmaTensorOp_; + static int const kPartitionsK = PartitionsK; + using OutputOp = OutputOp_; + static int const kElementsPerAccess = ElementsPerAccess; + using ElementOutput = typename OutputOp::ElementOutput; + using LayoutC = typename WarpMmaTensorOp::LayoutC; + using ElementAccumulator = typename WarpMmaTensorOp::ElementC; + + // + // Thread map + // + + using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOpForFusedBias< + Shape, + typename WarpMmaTensorOp::Shape, + kPartitionsK, + ElementOutput, + kElementsPerAccess + >::Type; + + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + OutputTileThreadMap, + ElementOutput + >; + + using AccumulatorFragmentIterator = typename std::conditional::value, + cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC>, + cutlass::epilogue::warp::FusedBiasActFragmentIteratorTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC> >::type; + + // + // Define the epilogue + // + using Epilogue = cutlass::epilogue::threadblock::FusedBiasActEpilogue< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputTileIterator, + AccumulatorFragmentIterator, + OutputOp + >; +}; + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h new file mode 100644 index 00000000..46464b41 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h @@ -0,0 +1,113 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + +*/ + +#pragma once + +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/pitch_linear.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines the optimal thread map for TensorOp accumulator layouts +template < + typename ThreadblockShape_, + typename WarpShape_, + int PartitionsK, + typename Element_, + int ElementsPerAccess +> +struct DefaultThreadMapTensorOpForFusedBias { + + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + static int const kPartitionsK = PartitionsK; + using Element = Element_; + static int const kElementsPerAccess = ElementsPerAccess; + + // + // Definitions + // + + struct Detail { + + /// Tensor Operations fundamentally perform operations on 8 rows + static int const kTensorOpRows = 8; + static int const kWarpSize = 32; + + static_assert( + !(ThreadblockShape::kM % WarpShape::kM) && + !(ThreadblockShape::kM % WarpShape::kM), "Divisibility"); + + /// Number of warps + using WarpCount = gemm::GemmShape< + ThreadblockShape::kM / WarpShape::kM, + ThreadblockShape::kN / WarpShape::kN, + kPartitionsK + >; + + /// Number of participating threads + static int const kThreads = WarpCount::kCount * kWarpSize; + }; + + // + // ThreadMap + // + + /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap + using Type = OutputTileOptimalThreadMapBiasAct < + OutputTileShape, + OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>, + Detail::kThreads, + kElementsPerAccess, + sizeof_bits::value + >; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h new file mode 100644 index 00000000..9e9f6928 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h @@ -0,0 +1,222 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
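+
+ Note, however, that the FusedBiasActEpilogue defined in this file skips the shared-memory and
+ global-memory stages entirely: it applies the elementwise output operator directly to the
+ warp-level accumulator fragments, optionally adding a source (bias) tile loaded through
+ OutputTileIterator, and stores the activated result into a second accumulator tile for the
+ next GEMM in the fused sequence to consume.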
+ +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/functional.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator without splitk +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename OutputOp_ ///< Output operator +> +class FusedBiasActEpilogue { + +public: + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using OutputOp = OutputOp_; + + /// Output layout is always row-major + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + +public: + + + static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + +public: + + /// Constructor + CUTLASS_DEVICE + FusedBiasActEpilogue( + ){ } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile + AccumulatorTile & fused_bias_act_accumlators, + OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + bool need_bias = output_op.is_source_needed(); + + if (need_bias) + compute_source_needed_(output_op, accumulators, fused_bias_act_accumlators, source_iterator); + else + compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumlators); + + + } + + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile + AccumulatorTile & fused_bias_act_accumlators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumlators); + } + + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const &output_op, 
///< Output operator + AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile + AccumulatorTile & fused_bias_act_accumlators, + OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + typename OutputTileIterator::Fragment source_fragment; + + + source_fragment.clear(); + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators); + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + source_iterator.load(source_fragment); + ++source_iterator; + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment; + fused_bias_act_fragment = output_op(accum_fragment, source_fragment); + + fused_bias_act_fragment_iterator.store(fused_bias_act_fragment); + ++fused_bias_act_fragment_iterator; + } + } + + CUTLASS_DEVICE + void compute_source_no_needed_( + OutputOp const &output_op, ///< Output operator + AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile + AccumulatorTile & fused_bias_act_accumlators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators); + + + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < AccumulatorFragmentIterator::kIterations; ++iter) { + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment; + fused_bias_act_fragment = output_op(accum_fragment); + + fused_bias_act_fragment_iterator.store(fused_bias_act_fragment); + ++fused_bias_act_fragment_iterator; + } + } + +}; + + + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h new file mode 100644 index 00000000..97878fd3 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h @@ -0,0 +1,311 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Metaprogram for determining the mapping of output elements to threads for epilogue tiles. + + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/fast_math.h" + +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// RowArrangement determines how one or more warps cover a region of consecutive rows. +template < + typename Shape, + int WarpsRemaining, + int ElementsPerAccess, + int ElementSize, + bool Is2dTile +> +struct RowArrangementBiasAct; + +/// RowArrangement in which each warp's access is a 1D tiled arrangement. +template < + typename Shape, + int WarpsRemaining, + int ElementsPerAccess, + int ElementSize +> +struct RowArrangementBiasAct { + static int const kWarpSize = 32; + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + static int const kIterationsRow = 1; + static int const kDeltaRow = 1; + static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize; + static int const kDeltaColumn = kWarpSize * kElementsPerAccess; + + static int const kAccessWidth = kWarpSize; + static int const kAccessRows = 1; + static int const kWarpPartitionsRow = 1; + static int const kWarpPartitionsColumn = WarpsRemaining; +}; + +/// RowArrangement in which each warp's access is a 2D tiled arrangement. 
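+///
+/// Worked example (illustrative numbers only, not taken from a specific kernel in this
+/// example): with Shape::kRow = 8, Shape::kColumn = 128, WarpsRemaining = 4,
+/// ElementsPerAccess = 1 and 16-bit elements, kMemoryAccessSize = 4 bytes gives
+/// kTargetMemoryAccessWidth = 2 and kTargetAccessRows = 16. Since 16 > kShapeRow (= 2),
+/// the warp is arranged as kAccessWidth = 16 lanes across by kAccessRows = 2 rows,
+/// yielding kIterationsRow = 1, kIterationsColumn = 8, kDeltaRow = 2, kDeltaColumn = 16.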
+template < + typename Shape, + int WarpsRemaining, + int ElementsPerAccess, + int ElementSize +> +struct RowArrangementBiasAct { + + static int const kMemoryAccessSize = 4;//128; + static int const kWarpSize = 32; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + struct Detail { + static int const kShapeRow = Shape::kRow / WarpsRemaining; + static int const kShapeWidth = Shape::kColumn / kElementsPerAccess; + + static int const kTargetMemoryAccessWidth = + kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8); + + static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth; + }; + + static int const kAccessWidth = + (Detail::kTargetAccessRows > Detail::kShapeRow ? + kWarpSize / Detail::kShapeRow + : const_min( + Detail::kShapeWidth, + const_min(kWarpSize, kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8)) + )); + + static int const kAccessRows = + (Detail::kTargetAccessRows > Detail::kShapeRow ? + Detail::kShapeRow + : const_min(Shape::kRow, kWarpSize / kAccessWidth)); + + static int const kIterationsRow = Detail::kShapeRow / kAccessRows; + static int const kDeltaRow = kAccessRows; + + static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth; + static int const kDeltaColumn = kAccessWidth * kElementsPerAccess; + + static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn, "Accessing too many elements per access"); + static_assert( kIterationsColumn > 0, "Iteration Count Column must be > 0" ); + static_assert( kIterationsRow > 0, "Iteration Count Row must be > 0" ); + + static int const kWarpPartitionsRow = 1; + static int const kWarpPartitionsColumn = 1; +}; + +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Template metaprogram for partitioning a 4D space across warps to achieve several performance +/// objectives: +/// +/// - coalesced memory accesses in units of 16 Byte lines +/// - minimal address arithmetic +/// - minimal predicate calculations +/// +template < + typename Shape_, + typename Count_, + int Threads, + int ElementsPerAccess, + int ElementSize +> +struct OutputTileOptimalThreadMapBiasAct { + + using Shape = Shape_; + using Count = Count_; + + static int const kWarpSize = 32; + static int const kThreads = Threads; + static int const kWarpCount = kThreads / kWarpSize; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + // + // Metaprogram computation + // + + struct Detail { + + // Clusters + static int const kIterationsCluster = + ((Shape::kCluster > kWarpCount) ? + Shape::kCluster / kWarpCount + : 1); + + static int const kDeltaCluster = + ((Shape::kCluster > kWarpCount) ? + Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster + : 1); + + static int const kCompactedDeltaCluster = + ((Shape::kCluster > kWarpCount) ? + Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster + : 1); + + static int const kWarpPartitionsCluster = + ((Shape::kCluster > kWarpCount) ? + kWarpCount + : kWarpCount / Shape::kCluster); + + static int const kWarpsRemainingForGroups = + ((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster); + + // Groups + static int const kIterationsGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + Shape::kGroup / kWarpsRemainingForGroups + : 1); + + static int const kDeltaGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? 
+ Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup + : 1); + + static int const kCompactedDeltaGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + Shape::kRow * Shape::kGroup / kIterationsGroup + : 1); + + static int const kWarpPartitionsGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + 1 + : kWarpsRemainingForGroups / Shape::kGroup); + + static int const kWarpsRemainingForRows = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + 1 + : kWarpsRemainingForGroups / Shape::kGroup); + + // Rows + using RowArrangement = detail::RowArrangementBiasAct< + Shape, + kWarpsRemainingForRows, + kElementsPerAccess, + kElementSize, + (Shape::kRow > kWarpsRemainingForRows) + >; + + // Warp partitions + using WarpPartitions = OutputTileShape< + RowArrangement::kWarpPartitionsColumn, + RowArrangement::kWarpPartitionsRow, + kWarpPartitionsGroup, + kWarpPartitionsCluster, + 1>; + + static int const kAccessWidth = RowArrangement::kAccessWidth; + static int const kAccessRows = RowArrangement::kAccessRows; + }; + + // + // Output + // + + using Iterations = OutputTileShape< + Detail::RowArrangement::kIterationsColumn, + Detail::RowArrangement::kIterationsRow, + Detail::kIterationsGroup, + Detail::kIterationsCluster, + 1>; + + using Delta = OutputTileShape< + Detail::RowArrangement::kDeltaColumn, + Detail::RowArrangement::kDeltaRow, + Detail::kDeltaGroup, + Detail::kDeltaCluster, + 1>; + + /// Initial offset function + CUTLASS_HOST_DEVICE + static MatrixCoord initial_offset(int thread_idx) { + + int warp_idx = thread_idx / kWarpSize; + int lane_idx = thread_idx % kWarpSize; + + // Compute warp location + int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster; + int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster; + + int group_idx = residual_cluster / Detail::WarpPartitions::kGroup; + int residual_group = residual_cluster % Detail::WarpPartitions::kGroup; + + int row_idx = residual_group / Detail::WarpPartitions::kRow; + int col_idx = residual_group % Detail::WarpPartitions::kRow; + + // Compute per-lane offset + int lane_row_offset = lane_idx / Detail::kAccessWidth; + int lane_col_offset = lane_idx % Detail::kAccessWidth; + + // Compute coordinate in output space + int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup; + int group_offset = group_idx * Shape::kRow * Count::kRow; + int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows; + int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess; + + return MatrixCoord( + cluster_offset + group_offset + row_offset + lane_row_offset, + (column_offset + lane_col_offset) * kElementsPerAccess + ); + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h new file mode 100644 index 00000000..c636841e --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h @@ -0,0 +1,189 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile + that participate in one warp-level store operation. + + Typically, the accumulator tile is the largest single block of register-backed storage + within the kernel. Storing it to memory is best accomplished by partitioning it into + smaller tiles and storing these sequentially. + + Round trips through shared memory during the Epilogue phase require partitioning, as + shared memory capacity is typically insufficient for a threadblock's total accumulator + size. 
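+
+ The iterator defined below is typically driven by the fused epilogue, roughly as follows
+ (an illustrative sketch; the tile and functor names are placeholders):
+
+   FragmentIterator accum_it(accumulators);          // source accumulator tile
+   FragmentIterator fused_it(fused_accumulators);    // destination accumulator tile
+   for (int iter = 0; iter < FragmentIterator::kIterations; ++iter) {
+     FragmentIterator::Fragment frag, result;
+     accum_it.load(frag);        // read one fragment of the accumulator tile
+     result = output_op(frag);   // apply the bias / activation functor
+     fused_it.store(result);     // write the result into the second tile
+     ++accum_it; ++fused_it;
+   }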
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/epilogue/warp/tensor_op_policy.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +/// +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) + typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array) + typename Layout ///< target shared memory layout +> +class FusedBiasActFragmentIteratorTensorOp; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for row-major shared memory +template < + typename WarpShape_, ///< shape of the warp-level GEMM tile + typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC_, ///< matrix multiply operation data type (concept: data type) + typename OperatorFragmentC_ ///< matrix multiply operation fragment (concept: Array) +> +class FusedBiasActFragmentIteratorTensorOp { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using OperatorElementC = OperatorElementC_; + using OperatorFragmentC = OperatorFragmentC_; + using Layout = layout::RowMajor; + + using Policy = TensorOpPolicy; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + OperatorElementC, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. 
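+  /// For example (illustrative numbers only): a 64x64 warp tile computed with a 16x8x8
+  /// tensor op instruction holds OperatorCount = 4 x 8 instruction tiles, each contributing
+  /// OperatorFragmentC::kElements = 4 accumulators per thread, i.e. 4 * 4 * 8 = 128 elements
+  /// per thread in this tile.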
+ using AccumulatorTile = Array< + OperatorElementC, + OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn>; + + using OutputAccumulatorTile = AccumulatorTile; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + +private: + + /// Internal access type + using AccessType = Array; + +private: + + // + // Data members + // + + /// Accumulator tile + AccessType *accumulators_; + + /// Internal index + int index_; + +public: + + /// Constructs an iterator + CUTLASS_HOST_DEVICE + FusedBiasActFragmentIteratorTensorOp(AccumulatorTile &accum): + accumulators_(reinterpret_cast(&accum)), + index_(0) { + } + + /// Increments + CUTLASS_HOST_DEVICE + FusedBiasActFragmentIteratorTensorOp &operator++() { + ++index_; + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + FusedBiasActFragmentIteratorTensorOp &operator--() { + --index_; + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag, int index_offset = 0) const { + + int index = index_ + index_offset; + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int accumulator_access_offset = + index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; + + frag_ptr[n] = accumulators_[accumulator_access_offset]; + } + } + /// Stores a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void store(Fragment &frag, int index_offset = 0) const { + + int index = index_ + index_offset; + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int accumulator_access_offset = + index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; + + accumulators_[accumulator_access_offset] = frag_ptr[n]; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h new file mode 100644 index 00000000..f7a98282 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h @@ -0,0 +1,427 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_conversion.h" + +namespace cutlass { +namespace gemm { +namespace warp { + + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Size of the accumulation tile shape (concept: MatrixShape) + typename AccumulatorShape_, + /// KBlocks columns to compute residual + int KBlocksColumn_, + /// Accumulator Element type + typename ElementAccumulator_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Whether beta is zero + bool IsBetaZero_ > +class MmaTensorOpPureFragmentIterator; + + +// Partial specialization for col-major accumulator tile +// And Element type is the same as Accumulator Element type + +template < + /// Shape of warp tile to load (concept: MatrixShape) + typename Shape_, + /// Shape of the warp accumulation tile (concept: MatrixShape) + typename AccumulatorShape_, + /// KBlocks columns to compute residual + int KBlocksColumn_, + /// Element type + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_> +class MmaTensorOpPureFragmentIterator { + public: + + /// Shape of warp tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of the warp accumulation tile (concept: MatrixShape) + using AccumulatorShape = AccumulatorShape_; + + /// KBlocks columns to compute residual + static int const kKBlockColumn = KBlocksColumn_; + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::ColumnMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Whether beta is zero + static bool const IsBetaZero = true; + + /// Number of participating threads + static int const kThreads = 32; + + /// Internal structure of iterator - made public to enable introspection + struct Policy { + static_assert( + !(Shape::kRow % InstructionShape::kM) && + !(Shape::kColumn % InstructionShape::kN), + "Shape of warp-level Mma must be divisible by operator 
shape."); + static_assert( + !(AccumulatorShape::kRow % Shape::kRow) && + !(AccumulatorShape::kColumn % Shape::kColumn), + "Shape of Warp Accumulator must be divisible by warp shape."); + static_assert( + !(kKBlockColumn % Shape::kColumn), + "KBlock size must be divisible by warp shape."); + + /// Number of times this iterator can be incremented + static int const kIterations = AccumulatorShape::kCount / Shape::kCount; + }; + +private: + + static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads; + + /// Number of mma operations performed by a warp + using MmaIterations = MatrixShape; + /// Number of mma operations performed by the entire accumulator + using AccumulatorIterations = MatrixShape; + + /// Number of K iterations + static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; + static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; + static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + static int const kResidualIndex = kResidualColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array; + + /// Accumulator Fragment object + using AccumulatorFragment = Array; + + +private: + + /// Internal access type + using AccessType = Array; + +private: + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + + /// Used to access residual tile first + bool is_residual_tile_; + +public: + /// Constructs an iterator + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum) + : accumulators_(reinterpret_cast(&accum)), + index_(0), is_residual_tile_(true) {} + + /// Add offset + CUTLASS_HOST_DEVICE + void add_offset(int index_offset) { + index_ += index_offset; + if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { + index_ = index_ - kKBlockColumnIterations + kResidualIndex; + is_residual_tile_ = false; + } + } + + /// Increments + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator &operator++() { + add_offset(1); + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator &operator--() { + add_offset(-1); + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + AccessType src_fragment; + src_fragment.clear(); + + + AccessType *frag_ptr = reinterpret_cast(&frag); + + int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow; + int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow + * MmaIterations::kColumn; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; n++) { + for (int m = 0; m < MmaIterations::kRow; m++) { + int accumulator_access_offset = + (n + index_n) * AccumulatorIterations::kRow + m + index_m; + + frag_ptr[n * MmaIterations::kRow + m].clear(); + if(!(is_residual_tile_ && index_ >= kResidualIndex)) + frag_ptr[n * MmaIterations::kRow + m] = accumulators_[accumulator_access_offset]; + // frag_ptr[n * MmaIterations::kRow + m] = output_op(accumulators_[accumulator_access_offset], src_fragment); + } + } + } + +}; + +// Partial specialization for row-major accumulator tile + +template < + /// 
Shape of warp tile to load (concept: MatrixShape) + typename Shape_, + /// Shape of the warp accumulation tile (concept: MatrixShape) + typename AccumulatorShape_, + /// KBlocks columns to compute residual + int KBlocksColumn_, + /// Accumulator Element type + typename ElementAccumulator_, + /// Element type + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_> +class MmaTensorOpPureFragmentIterator { + public: + + /// Shape of warp tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of the warp accumulation tile (concept: MatrixShape) + using AccumulatorShape = AccumulatorShape_; + + /// KBlocks columns to compute residual + static int const kKBlockColumn = KBlocksColumn_; + + /// Accumulator Element type + using ElementAccumulator = ElementAccumulator_; + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Whether beta is zero + static bool const IsBetaZero = true; + + /// Number of participating threads + static int const kThreads = 32; + + /// Internal structure of iterator - made public to enable introspection + struct Policy { + static_assert( + !(Shape::kRow % InstructionShape::kM) && + !(Shape::kColumn % InstructionShape::kN), + "Shape of warp-level Mma must be divisible by operator shape."); + static_assert( + !(AccumulatorShape::kRow % Shape::kRow) && + !(AccumulatorShape::kColumn % Shape::kColumn), + "Shape of Warp Accumulator must be divisible by warp shape."); + static_assert( + !(kKBlockColumn % Shape::kColumn), + "KBlock size must be divisible by warp shape."); + + /// Number of times this iterator can be incremented + static int const kIterations = AccumulatorShape::kCount / Shape::kCount; + }; + +private: + + static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads; + + /// Number of mma operations performed by a warp + using MmaIterations = MatrixShape; + /// Number of mma operations performed by the entire accumulator + using AccumulatorIterations = MatrixShape; + + /// Number of K iterations + static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; + static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; + static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + static int const kResidualIndex = kResidualColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + /// This is the fragment size produced by one access of the iterator. 
+ using Fragment = Array; + + /// Accumulator Fragment object + using AccumulatorFragment = Array; + + +private: + + /// Internal access type + using AccessType = Array; + using FragmentAccessType = Array; + +private: + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + + /// Used to access residual tile first + bool is_residual_tile_; + +public: + /// Constructs an iterator + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum) + : accumulators_(reinterpret_cast(&accum)), + index_(0), is_residual_tile_(true) {} + + /// Add offset + CUTLASS_HOST_DEVICE + void add_offset(int index_offset) { + index_ += index_offset; + if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { + index_ = index_ - kKBlockColumnIterations + kResidualIndex; + is_residual_tile_ = false; + } + } + + /// Increments + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator &operator++() { + add_offset(1); + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator &operator--() { + add_offset(-1); + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + + FragmentAccessType src_fragment; + src_fragment.clear(); + + FragmentAccessType *frag_ptr = reinterpret_cast(&frag); + + int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow; + int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow + * MmaIterations::kColumn; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; m++) { + for (int n = 0; n < MmaIterations::kColumn; n++) { + int accumulator_access_offset = + (m + index_m) * AccumulatorIterations::kColumn + n + index_n; + + frag_ptr[m * MmaIterations::kColumn + n].clear(); + if(!(is_residual_tile_ && index_ >= kResidualIndex)) + frag_ptr[m * MmaIterations::kColumn + n] = (accumulators_[accumulator_access_offset]); + } + } + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py new file mode 100644 index 00000000..9f1cbf80 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py @@ -0,0 +1,129 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import gen_turing_and_volta as api_generator +import gen_sample as sample_creater +import gen_cmake as cmake_creater +import gen_verify as verify_creater +import gen_device as b2b_fused_generator +import replace_fix_impl_header + +import argparse +import os +import json + + +parser = argparse.ArgumentParser(description="Generates Fused Multi-GEMM CUTLASS Kernels") +parser.add_argument("--config-file", default="config.json", help="JSON file containing configuration to generate") +parser.add_argument("--gen-name", default="FusedMultiGemmForward", help="Specific the output name") +parser.add_argument("--output-dir", default="", help="Specifies the output dir") +parser.add_argument("--cutlass-dir", default="", help="Specifies the dependent CUTLASS repo dir") +parser.add_argument("--gen-include-cutlass-dir", default="", help="Specifies the generated CUTLASS code include dir, if needed.") +args = parser.parse_args() + +gen_name = args.gen_name + +cutlass_deps_dir = args.cutlass_dir + +output_dir = args.output_dir +output_dir += "/" + +cutlass_deps_root = args.gen_include_cutlass_dir +if cutlass_deps_root == '': + cutlass_deps_root = cutlass_deps_dir + "/include/" +cutlass_deps_root +='/' + + +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +if not os.path.exists(output_dir + "/" + "auto_gen"): + os.mkdir(output_dir + "/" + "auto_gen") + +if not os.path.exists(output_dir + "/" + "fixed_impl"): + os.mkdir(output_dir + "/" + "fixed_impl" ) + +if not os.path.exists(output_dir + "/" + "sample"): + os.mkdir(output_dir + "/" + "sample" ) + +if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "device"): + os.mkdir(output_dir + "/" + "auto_gen" + "/" + "device") +if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "kernel"): + os.mkdir(output_dir + "/" + "auto_gen" + "/" + "kernel") +if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "threadblock"): + os.mkdir(output_dir + "/" + "auto_gen" + "/" + "threadblock") + +with open(args.config_file, 'r') as infile: + gemm_info_dict = json.load(infile) + +keys = sorted(gemm_info_dict.keys()) +fuse_gemm_info = [gemm_info_dict[k] for k in keys] + + +for_cutlass_gen_user_include_header_file = [ + cutlass_deps_root + "cutlass/epilogue/thread/linear_combination_leaky_relu.h", + cutlass_deps_root + "cutlass/epilogue/thread/linear_combination.h", +] + +for_fused_wrapper = [ + cutlass_deps_root + "cutlass/epilogue/thread/linear_combination_leaky_relu.h", + cutlass_deps_root + "cutlass/epilogue/thread/linear_combination.h", + "auto_gen/device/" + gen_name + ".h", + cutlass_deps_root + "cutlass/gemm/device/gemm_batched.h", + cutlass_deps_root + 
"cutlass/cutlass.h", +] + +# Copy fixed implementation to the output directory +fix_impl = replace_fix_impl_header.replace_fix_impl("../fixed_impl/", output_dir +"/fixed_impl/", cutlass_deps_root) +fix_impl.gen_code() + +auto_gen_output_dir = output_dir + "/auto_gen/" +project_root = "" +turing_plus = b2b_fused_generator.gen_device(fuse_gemm_info, gen_name, for_cutlass_gen_user_include_header_file, cutlass_deps_root, project_root, auto_gen_output_dir) +turing_plus.gen_code(75, 'hmma1688', False) + +api = api_generator.gen_one_API(fuse_gemm_info, gen_name, for_fused_wrapper, output_dir) +api.gen_code() + +# Generate C++ sample +os.system("cp ../leaky_bias.h " + output_dir + "/sample/") +os.system("cp ../utils.h " + output_dir + "/sample/") + +sample_dir = output_dir + "/sample/" +sample = sample_creater.gen_test(fuse_gemm_info, gen_name, for_cutlass_gen_user_include_header_file, sample_dir) +sample.gen_cpp_sample() + +cmake_gen = cmake_creater.gen_build_sys(cutlass_deps_dir, output_dir) +cmake_gen.gen_code() + +verify = verify_creater.gen_verify(fuse_gemm_info, gen_name, for_fused_wrapper, output_dir) +verify.gen_code() diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py new file mode 100644 index 00000000..aee17d05 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py @@ -0,0 +1,131 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +class gen_build_sys: + def __init__(self, cutlass_deps_dir, output_dir = "../"): + self.output_dir = output_dir + self.cutlass_deps_dir = cutlass_deps_dir + + def gen_top(self): + code = "" + code += '''\ +# Auto Generated code - Do not edit. 
+ +cmake_minimum_required(VERSION 3.8) +project(CUTLASS_MULTI_GEMMS LANGUAGES CXX CUDA) +find_package(CUDAToolkit) +set(CUDA_PATH ${{CUDA_TOOLKIT_ROOT_DIR}}) +set(CUTLASS_PATH \"{cutlass_deps_dir}/include\") +set(CUTLASS_UTIL_PATH \"{cutlass_deps_dir}/tools/util/include\") +list(APPEND CMAKE_MODULE_PATH ${{CUDAToolkit_LIBRARY_DIR}}) +'''.format(cutlass_deps_dir=self.cutlass_deps_dir) + + code += '''\ +set(GPU_ARCHS \"\" CACHE STRING + \"List of GPU architectures (semicolon-separated) to be compiled for.\") + +if(\"${GPU_ARCHS}\" STREQUAL \"\") + set(GPU_ARCHS \"70\") +endif() + +foreach(arch ${GPU_ARCHS}) + set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -gencode arch=compute_${arch},code=sm_${arch}\") + if(SM STREQUAL 70 OR SM STREQUAL 75) + set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -DWMMA\") + set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -DWMMA\") + set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -DWMMA\") + endif() +endforeach() + +set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS}\") +set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS}\") +set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler -Wall\") + +set(CMAKE_C_FLAGS_DEBUG \"${CMAKE_C_FLAGS_DEBUG} -Wall -O0\") +set(CMAKE_CXX_FLAGS_DEBUG \"${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0\") +set(CMAKE_CUDA_FLAGS_DEBUG \"${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall\") + +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +if(CMAKE_CXX_STANDARD STREQUAL \"11\") + set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-extended-lambda\") + set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr\") +endif() + +set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -g -O3\") +set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler -O3\") +set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler=-fno-strict-aliasing\") + +set(COMMON_HEADER_DIRS + ${PROJECT_SOURCE_DIR} + ${CUDAToolkit_INCLUDE_DIRS} +) + +set(COMMON_LIB_DIRS + ${CUDAToolkit_LIBRARY_DIR} +) +list(APPEND COMMON_HEADER_DIRS ${CUTLASS_PATH}) +list(APPEND COMMON_HEADER_DIRS ${CUTLASS_UTIL_PATH}) +''' + code += '''\ +include_directories( + ${COMMON_HEADER_DIRS} +) + +link_directories( + ${COMMON_LIB_DIRS} +) + +add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) +add_definitions(-DGOOGLE_CUDA=1) + +add_executable(sample + sample/sample.cu + one_api.cu +) +target_link_libraries(sample PRIVATE + -lcudart + -lnvToolsExt + ${CMAKE_THREAD_LIBS_INIT} +) + +if(NOT DEFINED LIB_INSTALL_PATH) + set(LIB_INSTALL_PATH ${CMAKE_CURRENT_BINARY_DIR}) +endif() +''' + return code + + def gen_code(self): + top_code = self.gen_top() + with open(self.output_dir + "CMakeLists.txt", "w") as f: + f.write(top_code) diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py new file mode 100644 index 00000000..7aeb5146 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py @@ -0,0 +1,120 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import ast + +fuse_gemm_info = [ + { + 'epilogue': { + 'tp': 'LeakyRelu', #'CustomizedLeaky_RELU' + 'bias': {'addbias': False, 'bias_tp': 'mat'}, + 'args': [('float', 'leaky_alpha', 1.3), ], + 'func': ''' +y = max(leaky_alpha * x, x) +y = y * x + ''' + } + }, + +] +class AnalysisNodeVisitor(ast.NodeVisitor): + def visit_Import(self,node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_ImportFrom(self,node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Assign(self,node): + print('Node type: Assign and fields: ', node._fields) + # print('Node type: Assign and targets value: ', node.targets, node.value) + + ast.NodeVisitor.generic_visit(self, node) + + def visit_BinOp(self, node): + print('Node type: BinOp and fields: ', node._fields) + print('node op: ', type(node.op).__name__) + ast.NodeVisitor.generic_visit(self, node) + + def visit_Expr(self, node): + print('Node type: Expr and fields: ', node._fields) + ast.NodeVisitor.generic_visit(self, node) + + def visit_Num(self,node): + print('Node type: Num and fields: ', node._fields) + print('Node type: Num: ', node.n) + + def visit_Name(self,node): + print('Node type: Name and fields: ', node._fields) + print('Node type: Name and fields: ', type(node.ctx).__name__, node.id) + + ast.NodeVisitor.generic_visit(self, node) + + def visit_Str(self, node): + print('Node type: Str and fields: ', node._fields) + +class CodeVisitor(ast.NodeVisitor): + def visit_BinOp(self, node): + if isinstance(node.op, ast.Add): + node.op = ast.Sub() + self.generic_visit(node) + + def visit_Assign(self, node): + print('Assign %s' % node.value) + self.generic_visit(node) + + def visit_Name(self, node): + print("Name:", node.id) + self.generic_visit(node) + + + def visit_FunctionDef(self, node): + print('Function Name:%s'% node.name.op) + self.generic_visit(node) + func_log_stmt = ast.Print( + dest = None, + values = [ast.Str(s = 'calling func: %s' % node.name, lineno = 0, col_offset = 0)], + nl = True, + lineno = 0, + col_offset = 0, + ) + node.body.insert(0, func_log_stmt) + +visitor = AnalysisNodeVisitor() + +code = \ +''' + +a=max(leaky_alpha * x, x +1) + +''' + +visitor.visit(ast.parse(code)) diff --git 
a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py new file mode 100644 index 00000000..eb501464 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py @@ -0,0 +1,477 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +from typing import * + +import helper +import gen_ir + +import gen_kernel as gen_ker + + +class gen_device: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, cutlass_deps_root, project_root, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.raw_gemm_info = fuse_gemm_info + self.b2b_num = len(fuse_gemm_info) + self.user_header_file = user_header_file + self.args = {} + # device arg struct memebr + self.arg_member = [] + self.gen_class_name = gen_class_name + self.gen_kernel_name = gen_class_name + "Kernel" + self.tempalte_args = [] + self.__tempalate_arg_list = {'Stages': int, 'SplitKSerial': bool, 'IsBetaZero': bool, 'AlignmentA': int, 'AlignmentB': int} + + self.file_name = output_dir + "/device/" +gen_class_name +".h" + self.sample_dir = output_dir + + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + self.this_file_root = output_dir + "/device/" + + self.first_use_1stage = False + + ## gen kernel + self.gen_kernel = gen_ker.gen_kernel(self.tempalte_args, self.gen_class_name, self.b2b_num, output_dir, cutlass_deps_root, project_root) + + + def __check_arg_type(self, temp_arg): + if temp_arg in self.__tempalate_arg_list.keys(): + return self.__tempalate_arg_list[temp_arg] + + find_sub = False + for candidate_arg in self.__tempalate_arg_list.keys(): + if (temp_arg.find(candidate_arg) != -1): + return self.__tempalate_arg_list[candidate_arg] + + return 'typename' + + # def gen_B2b2bGemm_class(): + def set_arch(self, sm_cap, mma_tp): + if sm_cap == 75 or sm_cap == 80 or sm_cap == 86: + self.arch = "cutlass::arch::Sm" + str(sm_cap) + + if mma_tp is 'hmma1688': + self.mma_shape = [16, 8, 8] + self.mma_tp = 'hmma' + elif mma_tp is 'imma8816': + self.mma_tp = 'imma' + self.mma_shape = [8, 8, 16] + else: + return 0 + + def gen_include_header(self): + code = '''\ +/* Auto Generated code - Do not edit.*/ + +#pragma once + +#include \"{cutlass_root}cutlass/cutlass.h\" +#include \"{cutlass_root}cutlass/numeric_types.h\" +#include \"{cutlass_root}cutlass/arch/arch.h\" +#include \"{cutlass_root}cutlass/device_kernel.h\" + +#include \"{cutlass_root}cutlass/gemm/threadblock/threadblock_swizzle.h\" + +#include \"{cutlass_root}cutlass/gemm/device/default_gemm_configuration.h\" +#include \"{cutlass_root}cutlass/epilogue/thread/linear_combination_relu.h\" +#include \"{cutlass_root}cutlass/epilogue/thread/linear_combination.h\" + +#include \"{project_root}../kernel/b2b_gemm.h\" +#include \"{project_root}../kernel/default_b2b_gemm.h\" +'''.format(cutlass_root=self.cutlass_deps_root, project_root=self.project_root, this_file_root=self.this_file_root) + include_user_header = "" + for header in self.user_header_file: + include_user_header += "#include \"" + header + "\"\n" + return code + include_user_header + + def gen_code(self, sm_cap, mma_tp, ifprint = True): + self.set_arch(sm_cap, mma_tp) + + self.update_b2b_args() + print(self.fuse_gemm_info) + self.update_b2b_class_template_args() + + func_code = self.gen_all_func() + member_var_code = "private:\n typename B2bGemmKernel::Params params_;\n" + + gen_code = gen_ir.gen_template_class(self.gen_class_name, self.tempalte_args, func_code + member_var_code) + code = self.gen_include_header() + gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("device", gen_code))) + + if ifprint: + print(code) + + print("[INFO]: Gen device code output Dir: is ", self.file_name) 
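+        # Write the generated device-level wrapper to <output_dir>/device/<gen_class_name>.h,
+        # then emit the matching kernel-level code through the gen_kernel helper.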
+ with open(self.file_name, 'w+') as f: + f.write(code) + + + gen_kernel = self.gen_kernel.gen_code(self.first_use_1stage) + print(gen_kernel) + + def update_b2b_class_template_args(self): + for arg in self.args.keys(): + self.tempalte_args.append([self.__check_arg_type(arg), arg, self.args[arg]]) + + def update_b2b_args(self): + + self.args['ElementA'] = helper.type_2_cutlass_type(self.fuse_gemm_info[0]['A_tp']) + self.args['LayoutA'] = helper.type_2_cutlass_type(self.fuse_gemm_info[0]['A_format']) + + cnt = 0 + + warp_M_tile = 32 + + # Determine maxmimum N_tile + Max_Ntile = 0 + for layer in self.fuse_gemm_info: + n_tile = layer['mnk'][1] + if n_tile > Max_Ntile: + Max_Ntile = n_tile + if Max_Ntile >= 256: + warp_M_tile = 16 + + stages_temp = [] + + for layer in self.fuse_gemm_info: + cnt_str = str(cnt) + B_tp_str= 'ElementB' + cnt_str + B_format_str = 'LayoutB' + cnt_str + C_tp_str= 'ElementC' + cnt_str + C_format_str = 'LayoutC' + cnt_str + Acc_str = 'ElementAccumulator' + cnt_str + + self.args[B_tp_str] = helper.type_2_cutlass_type(layer['B_tp']) + self.args[B_format_str] = helper.type_2_cutlass_type(layer['B_format']) + self.args[C_tp_str] = helper.type_2_cutlass_type(layer['C_tp']) + self.args[C_format_str] = helper.type_2_cutlass_type(layer['C_format']) + self.args[Acc_str] = helper.type_2_cutlass_type(layer['Acc_tp']) + + + mnk = layer['mnk'][:] + + tile_mnk = mnk[:] + + tile_mnk[2] = 32 # force the ktile is 32 + + #N tile gen + if mnk[1] > 1024: + assert(0) + elif mnk[1] > 512: + tile_mnk[1] = 1024 + elif mnk[1] > 256: + tile_mnk[1] = 512 + elif mnk[1] > 128: + tile_mnk[1] = 256 + elif mnk[1] > 64: + tile_mnk[1] = 128 + elif mnk[1] > 32: + tile_mnk[1] = 64 + else : + tile_mnk[1] = 32 + + if tile_mnk[1] == 512: + stages_temp.append(1) + else: + stages_temp.append(2) + + tile_mnk[0] = 4 * warp_M_tile + + + + epilogue_setted_type = helper.get_epilogue_tp(layer) + cutlass_epilogue_name = "LinearCombinationRelu" + if epilogue_setted_type.lower() == 'leakyrelu': + cutlass_epilogue_name = "LinearCombinationLeakyRelu" + elif epilogue_setted_type.lower() == 'identity': + cutlass_epilogue_name = "LinearCombination" + + epilogue_str = 'EpilogueOutputOp' + cnt_str + if cnt != len(self.fuse_gemm_info) - 1: + n = layer['mnk'][1] + Fragments = tile_mnk[1] // 8 * 2 + self.args[epilogue_str] = "cutlass::epilogue::thread::" + cutlass_epilogue_name + "" + else: + n = layer['mnk'][1] + n_mod_8 = n % 4 + N_align_elements = 1 + if n_mod_8 == 0: + N_align_elements = 8 + elif n_mod_8 == 4: + N_align_elements = 4 + elif n_mod_8 == 2 or n_mod_8 == 6: + N_align_elements = 2 + + self.args[epilogue_str] = "cutlass::epilogue::thread::" + cutlass_epilogue_name+ "" + + + + ThreadBlockShape_str = 'ThreadblockShape' + cnt_str + + self.args[ThreadBlockShape_str] = helper.cvt_2_cutlass_shape(tile_mnk) + + WarpShape_str = 'WarpShape' + cnt_str + tile_mnk[0] = warp_M_tile + self.args[WarpShape_str] = helper.cvt_2_cutlass_shape(tile_mnk) + cnt += 1 + + + self.args['ElementD'] = helper.type_2_cutlass_type(self.fuse_gemm_info[self.b2b_num - 1]['C_tp']) + self.args['LayoutD'] = helper.type_2_cutlass_type(self.fuse_gemm_info[self.b2b_num - 1]['C_format']) + + self.args['InstructionShape'] = helper.cvt_2_cutlass_shape(self.mma_shape) + self.args['OperatorClass'] = 'arch::OpClassTensorOp' + self.args['ArchTag'] = self.arch + self.args['ThreadblockSwizzle'] = 'threadblock::GemmBatchedIdentityThreadblockSwizzle' + + + for i in range(self.b2b_num): + self.args[helper.var_idx('Stages', i)] = "2" + + self.args['AlignmentA'] = 
str(8) + self.args['AlignmentB'] = str(8) + self.args['SplitKSerial'] = 'false' + self.args['Operator'] = 'typename DefaultGemmConfiguration::Operator' + self.args['IsBetaZero'] = 'false' + + + def gen_using_kernel(self): + code = "using B2bGemmKernel = typename kernel::DefaultB2bGemm<\n" + code += " " + "ElementA,\n" + code += " " + "LayoutA,\n" + + for i in range(self.b2b_num): + code += " " + helper.var_idx("ElementB", i) + ",\n" + code += " " + helper.var_idx("LayoutB", i) + ",\n" + code += " " + helper.var_idx("ElementC", i) + ",\n" + code += " " + helper.var_idx("LayoutC", i) + ",\n" + code += " " + helper.var_idx("ElementAccumulator", i) + ",\n" + code += " " + helper.var_idx("EpilogueOutputOp", i) + ",\n" + code += " " + helper.var_idx("ThreadblockShape", i) + ",\n" + code += " " + helper.var_idx("WarpShape", i) + ",\n" + + code += " " + "ElementD,\n" + code += " " + "LayoutD,\n" + code += " " + "InstructionShape,\n" + code += " " + "OperatorClass,\n" + code += " " + "ArchTag,\n" + code += " " + "ThreadblockSwizzle,\n" + + for i in range(self.b2b_num): + code += " " + helper.var_idx("Stages", i) + ",\n" + + + code += " " + "AlignmentA,\n" + code += " " + "AlignmentB,\n" + code += " " + "SplitKSerial,\n" + code += " " + "Operator,\n" + code += " " + "IsBetaZero_\n" + + code += ">::B2bGemmKernel;\n\n" + + return code + + def gen_args(self): + + def gen_arg_member(b2b_num): + data_members = [] + + for i in range(b2b_num): + member_type = "GemmCoord" + member_name = "problem_size_" + str(i) + data_members.append((member_type, member_name)) + + member_type = "TensorRef" + member_name = "ref_A0" + data_members.append((member_type, member_name)) + + for i in range(b2b_num): + member_type = "TensorRef" + member_name = "ref_B" + str(i) + data_members.append((member_type, member_name)) + member_type = "TensorRef" + member_name = "ref_C" + str(i) + data_members.append((member_type, member_name)) + + member_type = "TensorRef" + member_name = helper.var_idx("ref_D", b2b_num - 1) + data_members.append((member_type, member_name)) + + for i in range(b2b_num): + member_type = "typename EpilogueOutputOp" + str(i) + "::Params" + member_name = "epilogue" + str(i) + data_members.append((member_type, member_name)) + + data_members.append(('int', 'batch_count')) + + return data_members + + def gen_arg_struct_default_ctor(struct_name, data_members, inital_param_num, inital_value): + constructs_code = gen_ir.indentation + "CUTLASS_HOST_DEVICE\n" + \ + gen_ir.indentation + struct_name + " (): " + for i in range(inital_param_num): + final_param = ',' + if i == inital_param_num - 1: + final_param = '{ }' + constructs_code += data_members[i][1] + inital_value + final_param + + constructs_code += "\n" + return constructs_code + + def gen_arg_struct_ctor(struct_name, data_members): + constructs_code = gen_ir.indentation + "CUTLASS_HOST_DEVICE\n" + \ + gen_ir.indentation + struct_name + " (\n" + cnt = 0 + param_num = len(data_members) + for param in data_members: + final = ',\n' + if cnt == param_num - 1: + final = '\n):\n' + constructs_code += gen_ir.indentation + param[0] + " " + param[1] + "_" + final + cnt += 1 + + cnt = 0 + for param in data_members: + final = '),\n' + if cnt == param_num - 1: + final = ") { }\n" + constructs_code += gen_ir.indentation + param[1] + "(" + param[1] + "_" + final + cnt += 1 + + constructs_code += "\n" + return constructs_code + + # (variable type, variable name) + struct_member = gen_arg_member(self.b2b_num) + self.arg_member = struct_member + + codeBody = "" + for each_member 
in struct_member: + codeBody += gen_ir.indentation + each_member[0] + " " + each_member[1] + ";\n" + + codeBody += gen_arg_struct_default_ctor("Arguments", struct_member, self.b2b_num, "(0,0,0)") + "\n" + codeBody += gen_arg_struct_ctor("Arguments", struct_member) + "\n" + struct_code = gen_ir.gen_struct("Arguments", codeBody) + return struct_code + + def gen_func_constructs(self): + code = self.gen_class_name +"() {}" + return code + + def gen_func_initialize(self): + code = "Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {\n" + \ + "// Determine grid shape\n" + \ + "ThreadblockSwizzle threadblock_swizzle;\n" + \ + "cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(\n" + \ + " args.problem_size_0, \n" + \ + " { ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK },\n" + \ + " args.batch_count);\n" + \ + "// Initialize the Params structure\n" + \ + "params_ = typename B2bGemmKernel::Params{\n" + for i in range(self.b2b_num): + code += helper.var_idx(" args.problem_size_", i) + ",\n" + code += " grid_shape,\n" + \ + " args.ref_A0.non_const_ref(),\n" + for i in range(self.b2b_num): + code += helper.var_idx(" args.ref_B", i) + ".non_const_ref(),\n" + code += helper.var_idx(" args.ref_C", i) + ".non_const_ref(),\n" + + code += helper.var_idx(" args.ref_D", self.b2b_num - 1) + ",\n" + for i in range(self.b2b_num): + code += helper.var_idx(" args.epilogue", i) + ",\n" + + code += " args.batch_count\n" + code += "};\n" + \ + "return Status::kSuccess;\n" + \ + "}\n" + return code + + def gen_func_run(self): + code = "Status run(cudaStream_t stream = nullptr) {\n" + \ + "\n" + \ + " ThreadblockSwizzle threadblock_swizzle;\n" + \ + "\n" + \ + " dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);\n" + \ + " dim3 block(B2bGemmKernel::kThreadCount, 1, 1);\n" + \ + "\n" + \ + " cudaError_t result;\n" + \ + "\n" + \ + " int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage));\n" + \ + " if (smem_size >= (48 << 10)) {\n" + \ + " result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);\n" + \ + "\n" + \ + " if (result != cudaSuccess) {\n" + \ + " return Status::kErrorInternal;\n" + \ + " }\n" + \ + "\n" + \ + " result = cudaFuncSetAttribute(\n" + \ + " Kernel,\n" + \ + " cudaFuncAttributePreferredSharedMemoryCarveout, 100);\n" + \ + "\n" + \ + " if (result != cudaSuccess) {\n" + \ + " return Status::kErrorInternal;\n" + \ + " }\n" + \ + " }\n" + \ + " cutlass::Kernel<<>>(params_);\n" + \ + " result = cudaGetLastError();\n" + \ + " return result == cudaSuccess ? 
Status::kSuccess : Status::kErrorInternal;\n" + \ + " }\n" + + return code + def gen_func_operator(self): + opeartor_with_arg_code = "Status operator()(\n" + \ + " Arguments const &args,\n" + \ + " void *workspace = nullptr,\n" + \ + " cudaStream_t stream = nullptr) {\n" + \ + " Status status = initialize(args, workspace);\n" + \ + " \n" + \ + " if (status == Status::kSuccess) {\n" + \ + " status = run(stream);\n" + \ + " }\n" + \ + " return status;\n" + \ + "}\n" + operator_code = "Status operator()(\n" + \ + " cudaStream_t stream = nullptr) {\n" + \ + " Status status = run(stream);\n" + \ + " return status;\n" + \ + "}\n" + return opeartor_with_arg_code + "\n" + operator_code + + def gen_all_func(self): + return self.gen_using_kernel() + "\n" + \ + self.gen_args() + "\n" + \ + self.gen_func_constructs() + "\n" + \ + self.gen_func_initialize() + "\n" + \ + self.gen_func_run() + "\n" + \ + self.gen_func_operator() diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py new file mode 100644 index 00000000..c11ea9cc --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py @@ -0,0 +1,249 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
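+#
+# gen_ir.py collects small string-emission helpers (C++ namespaces, classes,
+# template heads, structs and functions) that the other ir_gen generators
+# compose into full headers. For example, using the helpers defined below:
+#
+#   gen_namespace("cutlass", "struct Foo;\n")
+#   # -> "namespace cutlass {\nstruct Foo;\n} // namespace cutlass\n"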
+# +################################################################################################# + +import helper + + +indentation = " " + + +def append_word(word): + code = "" + code += word + code += " " + return code + + +def gen_namespace(namespace, codeBody): + code_gen = "namespace " + namespace + " {\n" + code_gen += codeBody + code_gen += "} // namespace " + namespace + "\n" + return code_gen + + +def gen_expression(type, lval, rval = None): + code_gen = "" + code_gen += append_word(type) + code_gen += append_word(lval) + if rval is not None: + code_gen += append_word("=") + code_gen += append_word(rval) + return code_gen + + +def gen_class(name, codeBody, inheritance_code = None): + code_gen = "" + if inheritance_code is None: + code_gen = "class " + name + "{\n" + else: + code_gen = "class " + name + " : "+ inheritance_code + "{\n" + code_gen += codeBody + code_gen += "}; // class " + name + "\n" + return code_gen + + +def gen_struct(name, codeBody, specialized = None): + specialized_code = "" + if specialized is not None: + specialized_code = "<" + specialized + ">" + code_gen = "struct " + name + specialized_code + "{\n" + code_gen += codeBody + code_gen += "}; // struct " + name + "\n" + return code_gen + + +def gen_template_arg(arg_type, arg_name, default_val = None): + rval = None + if default_val is not None: + rval = str(default_val) + + arg_typename = "" + if arg_type is int: + arg_typename = "int" + elif arg_type is bool: + arg_typename = "bool" + else: + arg_typename = "typename" + + internal_arg_name = arg_name + "_" + + code_gen = indentation + code_gen += gen_expression(arg_typename, internal_arg_name, rval) + + return code_gen + + +def gen_template_args(args, set_default = True): + arg_len = len(args) + cnt = 1 + code_gen = "" + for arg_tuple in args: + arg_type = arg_tuple[0] + arg_name = arg_tuple[1] + arg_default_val = None + if len(arg_tuple) == 3 and set_default: + arg_default_val = arg_tuple[2] + + code_gen += gen_template_arg(arg_type, arg_name, arg_default_val) + if cnt != arg_len: + code_gen += ",\n" + cnt += 1 + + return code_gen + + +def gen_template_head(args, set_default = True): + code_gen = "template <\n" + code_gen += gen_template_args(args, set_default) + code_gen += ">\n" + return code_gen + + +def export_template_args(args): + code_gen = "public:\n" + for arg_tuple in args: + code_gen += indentation + arg_type = arg_tuple[0] + arg_name = arg_tuple[1] + internal_arg_name = arg_name + "_" + + typename = "" + if arg_type is int: + typename = "static int const" + elif arg_type is bool: + typename = "static bool const" + else: + typename = "using" + + code_gen += gen_expression(typename, arg_name, internal_arg_name) + code_gen += ";\n" + return code_gen + + +def gen_template_class(class_name, args, codeBody, set_default = True, inheritance_code = None): + code_gen = "" + + code_gen += gen_template_head(args, set_default) + code_gen += gen_class(class_name, export_template_args(args) + codeBody, inheritance_code) + + return code_gen + + +def gen_template_struct(struct_name, args, codeBody, speicalized = None, set_default = True, export_args = True): + code_gen = "" + code_gen += gen_template_head(args, set_default) + code = export_template_args(args) + codeBody + if export_args is False: + code = codeBody + code_gen += gen_struct(struct_name, code , speicalized) + + return code_gen + + +def gen_declare_template_struct(name, *params): + code = name + "<" + cnt = 0 + param_num = len(params) + for param in params: + final = ", " + if cnt == param_num 
- 1: + final = "" + code += param + final + cnt += 1 + code += ">;\n" + return code + + +def filtered_param(params, name_and_value_pair, keep_ = False): + rtn_template_args = [] + speicalized_template_args = [] + + for param in params: + param_name = "" + if len(param) >= 1: + param_name = param[1] + else: + param_name = param[0] + + hit_flag = False + set_value = "" + for n_v_pair in name_and_value_pair: + + filter_name = n_v_pair[0] + set_value = n_v_pair[1] + + if param_name == (filter_name + "_") or param_name == filter_name : + hit_flag = True + break + + + if hit_flag is False: + rtn_template_args.append(param) + + if hit_flag is True: + speicalized_template_args.append(set_value) + else: + if keep_ is True: + speicalized_template_args.append(param_name + "_") + else: + speicalized_template_args.append(param_name) + + + specialized_template_arg_str = helper.list_2_string(speicalized_template_args) + + return rtn_template_args, specialized_template_arg_str + + +def gen_func(func_name, arg_lists, code_body, only_declare = False, with_cudaStream = True): + code = "void " + func_name + "(\n" + for arg in arg_lists: + arg_tp = arg[0] + arg_nm = arg[1] + code += " " + arg_tp + " " + arg_nm + ",\n" + code += "cudaStream_t stream)" + if only_declare : + return code + code += "{\n" + + code += code_body + "\n" + code += "}\n" + return code + + +def indent_level(code, level = 0): + rtn_code = "" + for i in range(level): + rtn_code += " " + + rtn_code += code + + return rtn_code diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py new file mode 100644 index 00000000..47a82b6f --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py @@ -0,0 +1,476 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
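+#
+# gen_kernel.py emits the kernel-level headers for the fused back-to-back GEMM:
+# default_b2b_gemm.h (the DefaultB2bGemm specialization) and b2b_gemm.h (the
+# B2bGemm kernel struct), both written to <output_dir>/kernel/, and then drives
+# the threadblock-level generator in gen_threadblock.py.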
+# +################################################################################################# + +import gen_ir +import helper +import gen_threadblock as gen_tb + + +class gen_default_Gemm: + def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root): + self.gen_class_name = "B2bGemm" + self.template_param = template_param + self.b2b_num = b2b_num + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + def gen_B2bMma(self, specialized_template_args): + code = "using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<\n" + code += specialized_template_args + code += ">::ThreadblockB2bMma;\n" + + # print(code) + return code + + def gen_epilogue(self): + epilogue_code = "" + epilogue_code += helper.var_idx("static const int kPartitionsK", self.b2b_num - 1) + helper.var_idx(" = ThreadblockShape", self.b2b_num - 1) + helper.var_idx("::kK / WarpShape", self.b2b_num - 1) + "::kK;\n" + + epilogue_code += "using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<\n" + epilogue_code += " " + helper.var_idx("ThreadblockShape", self.b2b_num - 1) + ",\n" + epilogue_code += " " + helper.var_idx("typename B2bMma::Operator", self.b2b_num - 1) + ",\n" + epilogue_code += " " + helper.var_idx("kPartitionsK", self.b2b_num - 1) + ",\n" + epilogue_code += " " + helper.var_idx("EpilogueOutputOp", self.b2b_num - 1) + ",\n" + epilogue_code += " " + helper.var_idx("EpilogueOutputOp", self.b2b_num - 1) + "::kCount\n" + epilogue_code += ">::Epilogue;\n" + + epilogue_code += "using B2bGemmKernel = kernel::B2bGemm;\n\n" + + return epilogue_code + + + def gen_include_header(self): + code = ''' +/* Auto Generated code - Do not edit.*/ + +#pragma once +#include \"{cutlass_dir}cutlass/cutlass.h\" + +#include \"{cutlass_dir}cutlass/layout/matrix.h\" +#include \"{cutlass_dir}cutlass/numeric_types.h\" + +#include \"{cutlass_dir}cutlass/epilogue/threadblock/epilogue.h\" +#include \"{cutlass_dir}cutlass/epilogue/thread/linear_combination.h\" + +#include \"{cutlass_dir}cutlass/gemm/gemm.h\" +#include \"{cutlass_dir}cutlass/gemm/kernel/gemm_pipelined.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm75.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm70.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm80.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_simt.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/threadblock_swizzle.h\" +#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_tensor_op.h\" +#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h\" +#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_simt.h\" + +#include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator.h\" + +#include \"../kernel/b2b_gemm.h\" +#include \"../threadblock/default_b2b_mma.h\" +'''.format(cutlass_dir=self.cutlass_deps_root) + return code + + def gen_code(self): + gen_using = '' + # Generate default template struct + gen_code = gen_ir.gen_template_struct("Default" + self.gen_class_name, self.template_param,"", speicalized = None, set_default=False) + + + filter_list = [] + filter_list.append(('Stages', 2)) + filter_list.append(("OperatorClass", "arch::OpClassTensorOp")) + filter_list.append(("ArchTag", "arch::Sm75")) + + for i in range(self.b2b_num): + filter_list.append((helper.var_idx("LayoutC", i), "layout::RowMajor")) + + + rtn_template_args, 
speicalized_template_args = gen_ir.filtered_param(self.template_param, filter_list, keep_= True) + + + B2bMma_code = self.gen_B2bMma(speicalized_template_args) + epilogue_and_rest_code = self.gen_epilogue() + + gen_special_code = gen_ir.gen_template_struct("Default" + self.gen_class_name, rtn_template_args, B2bMma_code + epilogue_and_rest_code, speicalized = speicalized_template_args, set_default=False) + + code = gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("kernel", gen_code + gen_special_code))) + + return self.gen_include_header() + code + + +class gen_Kernel: + def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root): + self.gen_class_name = "B2bGemm" + self.template_param = template_param + self.b2bnum = b2b_num + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + def gen_include_header(self): + code = ''' +#pragma once + +#include \"{cutlass_dir}cutlass/cutlass.h\" +#include \"{cutlass_dir}cutlass/gemm/gemm.h\" +#include \"{cutlass_dir}cutlass/matrix_coord.h\"\n'''.format(cutlass_dir=self.cutlass_deps_root) + return code + + def gen_Params(self): + gen_param = "" + for i in range(self.b2bnum): + gen_param += " " + helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + ";\n" + gen_param += " " + "cutlass::gemm::GemmCoord grid_tiled_shape;\n" + gen_param += " " + "typename B2bMma::IteratorA0::Params params_A0;\n" + gen_param += " " + "typename B2bMma::IteratorA0::TensorRef ref_A0;\n" + + for i in range(self.b2bnum): + gen_param += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::Params params_B", i) + ";\n" + gen_param += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::TensorRef ref_B", i) + ";\n" + if i == self.b2bnum - 1: + gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::Params params_C", i) + ";\n" + gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_C", i) + ";\n" + + else: + gen_param += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::Params params_C", i) + ";\n" + gen_param += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::TensorRef ref_C", i) + ";\n" + + + + + gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::Params params_D", self.b2bnum - 1) + ";\n" + gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_D", self.b2bnum - 1) + ";\n" + + for i in range(self.b2bnum): + gen_param += " " + helper.var_idx("typename OutputOp", i) + helper.var_idx("::Params output_op_", i) + ";\n" + + gen_param += " " + 'int batch_count' + ";\n" + gen_param += " " + 'int gemm_k_iterations_0' + ";\n" + + + return gen_param + + def gen_Memberfunc(self): + code_default = "\nCUTLASS_HOST_DEVICE\n" + code_default += "Params()" + + code_default += " { } \n\n" + + code_construct = "\nCUTLASS_HOST_DEVICE\n" + code_construct += "Params(\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("cutlass::gemm::GemmCoord const & problem_size_", i) + ",\n" + + code_construct += " " + "cutlass::gemm::GemmCoord const & grid_tiled_shape,\n" + + code_construct += " " + "typename B2bMma::IteratorA0::TensorRef ref_A0,\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::TensorRef ref_B", i) + ",\n" + if i == self.b2bnum - 1: + code_construct += " 
" + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_C", i) + ",\n" + else: + code_construct += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::TensorRef ref_C", i) + ",\n" + + code_construct += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_D", self.b2bnum - 1) + ",\n" + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("typename OutputOp", i) + helper.var_idx("::Params output_op_", i) + helper.var_idx(" = typename OutputOp", i) + "::Params(),\n" + + code_construct += " " + "int batch_count = 1\n" + + code_construct += "):\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("problem_size_", i) + helper.var_idx("(problem_size_", i) + "),\n" + + code_construct += " " + "grid_tiled_shape(grid_tiled_shape),\n" + code_construct += " " + "params_A0(ref_A0.layout()),\n" + code_construct += " " + "ref_A0(ref_A0),\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("params_B", i) + helper.var_idx("(ref_B", i) + ".layout()),\n" + code_construct += " " + helper.var_idx("ref_B", i) + helper.var_idx("(ref_B", i) + "),\n" + code_construct += " " + helper.var_idx("params_C", i) + helper.var_idx("(ref_C", i) + ".layout()),\n" + code_construct += " " + helper.var_idx("ref_C", i) + helper.var_idx("(ref_C", i) + "),\n" + + code_construct += " " + helper.var_idx("params_D", self.b2bnum - 1) + helper.var_idx("(ref_D", self.b2bnum - 1) + ".layout()),\n" + code_construct += " " + helper.var_idx("ref_D", self.b2bnum - 1) + helper.var_idx("(ref_D", self.b2bnum - 1) + "),\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("output_op_", i) + helper.var_idx("(output_op_", i) + "), \n" + + code_construct += " " + "batch_count(batch_count) {\n" + code_construct += " " + helper.var_idx("gemm_k_iterations_", 0) + helper.var_idx(" = (problem_size_", 0) + helper.var_idx(".k() + B2bMma::Shape", 0) + helper.var_idx("::kK - 1) / B2bMma::Shape", 0) + "::kK;\n" + + code_construct += "}\n" + + return code_default + code_construct + + def gen_using(self): + code_using = "" + + for i in range(self.b2bnum - 1): + code_using += " " + helper.var_idx("using OutputOp", i) + helper.var_idx(" = typename B2bMma::OutputOp", i) + ";\n" + + code_using += " " + helper.var_idx("using OutputOp", self.b2bnum - 1) + " = typename Epilogue::OutputOp;\n" + + for i in range(self.b2bnum - 1): + code_using += " " + helper.var_idx("using FusedAddBiasEpilogue", i) + helper.var_idx(" = typename B2bMma::FusedAddBiasEpilogue", i) +";\n" + + + code_using += " " + "using WarpCount0 = typename B2bMma::WarpCount0;\n" + code_using += " " + "static int const kThreadCount = 32 * WarpCount0::kCount;\n" + + code_using += gen_ir.gen_struct("Params", self.gen_Params() + self.gen_Memberfunc()) + + code_using += "union SharedStorage {\n" + code_using += " " + "typename B2bMma::B2bMmaSharedStorage main_loop;\n" + code_using += " " + "typename Epilogue::SharedStorage epilogue;\n" + code_using += "};\n" + + return code_using + + def gen_can_implement(self): + gen_code = "" + return gen_code + + def gen_operator_and_constr(self): + ctr_code = "CUTLASS_HOST_DEVICE\n" + ctr_code += self.gen_class_name + "() { } \n\n" + operator_code = "CUTLASS_DEVICE\n" + operator_code += "void operator()(Params const ¶ms, SharedStorage &shared_storage) {\n" + operator_code += " " + "ThreadblockSwizzle threadblock_swizzle;\n" + operator_code += " " + "cutlass::gemm::GemmCoord threadblock_tile_offset = 
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);\n" + operator_code += " " + "int batch_idx = threadblock_tile_offset.k();\n" + operator_code += " " + "if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||\n" + operator_code += " " + "params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {\n" + operator_code += " " + " " + "return;\n" + operator_code += " " + "}\n" + + operator_code += " " + "cutlass::MatrixCoord tb_offset_A0{\n" + operator_code += " " + " " + "threadblock_tile_offset.m() * B2bMma::Shape0::kM,\n" + operator_code += " " + " " + "0\n" + operator_code += " " + "};\n" + + for i in range(self.b2bnum): + operator_code += " " + helper.var_idx("cutlass::MatrixCoord tb_offset_B", i) + "{\n" + operator_code += " " + " " + "0,\n" + operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.n() * B2bMma::Shape", i) + "::kN\n" + operator_code += " " + "};\n" + + operator_code += " " + "int thread_idx = threadIdx.x;\n\n" + + operator_code += " " + "MatrixCoord threadblock_offset(\n" + operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.m() * B2bMma::Shape", self.b2bnum - 1) + "::kM,\n" + operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.n() * B2bMma::Shape", self.b2bnum - 1) + "::kN\n" + operator_code += " " + ");\n" + + operator_code += " " + "typename B2bMma::IteratorA0 iterator_A0(\n" + operator_code += " " + " " + "params.params_A0,\n" + operator_code += " " + " " + "params.ref_A0.data(),\n" + operator_code += " " + " " + "params.problem_size_0.mk(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "tb_offset_A0);\n" + + operator_code += " " + "iterator_A0.add_pointer_offset(batch_idx * params.problem_size_0.m() * params.problem_size_0.k());\n\n" + + + for i in range (self.b2bnum): + operator_code += " " + helper.var_idx("typename B2bMma::IteratorB", i ) + helper.var_idx(" iterator_B", i) + "(\n" + operator_code += " " + " " + helper.var_idx("params.params_B", i) + ",\n" + operator_code += " " + " " + helper.var_idx("params.ref_B", i) + ".data(),\n" + operator_code += " " + " " + helper.var_idx("params.problem_size_", i) + ".kn(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + helper.var_idx("tb_offset_B", i) + ");\n" + operator_code += " " + helper.var_idx("iterator_B", i) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", i) + helper.var_idx(".n() * params.problem_size_", i) + ".k());\n\n" + + + for i in range (self.b2bnum - 1): + operator_code += " " + helper.var_idx("typename FusedAddBiasEpilogue", i ) + helper.var_idx("::OutputTileIterator iterator_C", i) + "(\n" + operator_code += " " + " " + helper.var_idx("params.params_C", i) + ",\n" + operator_code += " " + " " + helper.var_idx("params.ref_C", i) + ".data(),\n" + operator_code += " " + " " + helper.var_idx("params.problem_size_" , i) + ".mn(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "threadblock_offset" + ");\n" + operator_code += " " + helper.var_idx("int ref_C", i) + helper.var_idx("_stride = params.ref_C", i) + ".stride()[0];\n" + operator_code += " " + helper.var_idx("iterator_C", i) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", i) + helper.var_idx(".n() * (ref_C", i) + helper.var_idx("_stride == 0 ? 
1 : params.problem_size_", i) + ".m()));\n\n" + + + for i in range (self.b2bnum - 1): + operator_code += " " + helper.var_idx("FusedAddBiasEpilogue", i ) + helper.var_idx(" epilogue_", i ) + ";\n" + + + operator_code += " " + "int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);\n" + operator_code += " " + "int lane_idx = threadIdx.x % 32;\n" + + for i in range (self.b2bnum - 1): + operator_code += " " + helper.var_idx("OutputOp", i) + helper.var_idx(" output_op_", i) + helper.var_idx("(params.output_op_", i) + ");\n" + + operator_code += " " + "B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);\n" + + operator_code += " " + "typename B2bMma::FragmentC0 src_accum;\n" + operator_code += " " + helper.var_idx("typename B2bMma::FragmentC", self.b2bnum - 1)+ " accumulators;\n" + + operator_code += " " + "src_accum.clear();\n" + operator_code += " " + "accumulators.clear();\n" + operator_code += " " + "b2bMma(params.gemm_k_iterations_0, accumulators, iterator_A0, " + + for i in range(self.b2bnum): + operator_code += helper.var_idx("iterator_B", i) + ", " + + operator_code += "src_accum" + if self.b2bnum != 1: + operator_code += ", " + for i in range(self.b2bnum - 1): + operator_code += helper.var_idx("output_op_", i) + ", " + + for i in range(self.b2bnum - 1): + operator_code += helper.var_idx("epilogue_", i) + ", " + + for i in range(self.b2bnum - 1): + final = ", " + if i == self.b2bnum - 2: + final ="" + operator_code += helper.var_idx("iterator_C", i) + final + operator_code += ");\n" + + operator_code += " " + helper.var_idx("OutputOp", self.b2bnum - 1) + helper.var_idx(" output_op_", self.b2bnum - 1) + helper.var_idx("(params.output_op_", self.b2bnum - 1) + ");\n" + operator_code += " " + "threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);\n" + + + + operator_code += " " + helper.var_idx("typename Epilogue::OutputTileIterator iterator_C", self.b2bnum - 1) + "(\n" + operator_code += " " + " " + helper.var_idx("params.params_C", self.b2bnum - 1) + ",\n" + operator_code += " " + " " + helper.var_idx("params.ref_C", self.b2bnum - 1) + ".data(),\n" + operator_code += " " + " " + helper.var_idx("params.problem_size_", self.b2bnum - 1) + ".mn(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "threadblock_offset\n" + operator_code += " " + ");\n" + operator_code += " " + helper.var_idx("int ref_C", self.b2bnum - 1) + helper.var_idx("_stride = params.ref_C", self.b2bnum - 1) + ".stride()[0];\n" + + operator_code += " " + helper.var_idx("iterator_C", self.b2bnum - 1) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", self.b2bnum - 1) + helper.var_idx(".n() * (ref_C", self.b2bnum - 1) + helper.var_idx("_stride == 0 ? 
1 : params.problem_size_", self.b2bnum - 1) + ".m()));\n\n" + + operator_code += " " + helper.var_idx("typename Epilogue::OutputTileIterator iterator_D", self.b2bnum - 1) + "(\n" + operator_code += " " + " " + helper.var_idx("params.params_D", self.b2bnum - 1) + ",\n" + operator_code += " " + " " + helper.var_idx("params.ref_D", self.b2bnum - 1) + ".data(),\n" + operator_code += " " + " " + helper.var_idx("params.problem_size_", self.b2bnum - 1) + ".mn(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "threadblock_offset\n" + operator_code += " " + ");\n" + operator_code += " " + helper.var_idx("iterator_D", self.b2bnum - 1) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", self.b2bnum - 1) + helper.var_idx(".n() * params.problem_size_", self.b2bnum - 1) + ".m());\n\n" + + + operator_code += " " + "Epilogue epilogue(\n" + operator_code += " " + " " + "shared_storage.epilogue,\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "warp_idx,\n" + operator_code += " " + " " + "lane_idx\n" + operator_code += " " + ");\n" + + operator_code += " " + "epilogue(" + operator_code += helper.var_idx("output_op_", self.b2bnum - 1) + ", " + operator_code += helper.var_idx("iterator_D", self.b2bnum - 1) + ", " + operator_code += "accumulators, " + operator_code += helper.var_idx("iterator_C", self.b2bnum - 1) + ");\n" + operator_code += "}\n" + + return ctr_code + operator_code + + def gen_include_header(self): + code = ''' +#pragma once + +#include \"{cutlass_dir}cutlass/cutlass.h\" + +#include \"{cutlass_dir}cutlass/gemm/gemm.h\" +#include \"{cutlass_dir}cutlass/matrix_coord.h\" +#include \"{cutlass_dir}cutlass/semaphore.h\" +'''.format(cutlass_dir=self.cutlass_deps_root) + return code + def gen_code(self): + + template_param = [] + template_param.append(("typename", "B2bMma")) + template_param.append(("typename", "Epilogue")) + template_param.append(("typename", "ThreadblockSwizzle")) + template_param.append((bool, "SplitKSerial")) + + code_body = "" + code_body += self.gen_using() + code_body += self.gen_operator_and_constr() + + struct_code = gen_ir.gen_template_struct(self.gen_class_name, template_param, code_body) + code = self.gen_include_header() + code += gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("kernel", struct_code))) + + return self.gen_include_header() + code + + + +class gen_kernel: + def __init__(self, template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root): + self.template_param = template_param + + self.gen_class_name = "B2bGemm" + self.gen_kernel_name = gen_class_name + "Kernel" + self.tempalte_args = [] + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + self.gen_default_b2b_gemm = gen_default_Gemm(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + self.gen_Kerenl = gen_Kernel(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + + # Include gen_threadBlock + self.gen_threadBlock = gen_tb.gen_threadblock(template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root) + + self.file_dir = output_dir + "/kernel/" + + def gen_code(self, first_use_1stage): + + default_b2b_gemm = self.gen_default_b2b_gemm.gen_code() + + print("[INFO]: Gen kernel code [default_b2b_gemm.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "default_b2b_gemm.h", "w+") as f: + f.write(default_b2b_gemm) + + kernel = self.gen_Kerenl.gen_code() + 
print("[INFO]: Gen kernel code [b2b_gemm.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "b2b_gemm.h", "w+") as f: + f.write(kernel) + + # Call code to gen threadblock + self.gen_threadBlock.gen_code(first_use_1stage) diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py new file mode 100644 index 00000000..5c456518 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py @@ -0,0 +1,232 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +import helper +import gen_ir as ir + +class gen_test: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.gen_class_name = gen_class_name + self.user_header_file = user_header_file + self.sample_dir = output_dir + self.b2b_num = len(fuse_gemm_info) + + def gen_cpp_sample(self): + code = "/* Auto Generated code - Do not edit.*/\n" + code += "#include \n" + + code += "#include \"cutlass/gemm/device/gemm_batched.h\" \n" + code += "#include \"cutlass/cutlass.h\" \n" + + code += "#include \"../cutlass_irrelevant.h\" \n" + code += "#include \"../cutlass_verify.h\" \n" + + code += "#include \"leaky_bias.h\" \n" + + code += "#include \"utils.h\" \n" + + + + code += "int main(int args, char * argv[]) {\n" + code += " " + "int M = atoi(argv[1]);\n" + code += " " + "int K0 = " + str(self.fuse_gemm_info[0]['mnk'][0]) + ";\n" + code += " " + "if(args == 3);\n" + code += " " + " " + "K0 = atoi(argv[2]);\n" + code += " " + "int B = 1;\n" + code += " " + "if(args == 4);\n" + code += " " + " " + "B = atoi(argv[3]);\n" + + code += " " + "srand(1234UL);\n" + code += " " + "int device_id = 0;\n" + code += " " + "cudaGetDevice(&device_id);\n" + code += " " + "cudaDeviceProp prop;\n" + code += " " + "cudaGetDeviceProperties(&prop, device_id);\n" + code += " " + "int sm = prop.major *10 + prop.minor;\n" + code += "using ElementCompute = cutlass::half_t;\n" + + for i in range(self.b2b_num): + code += " " + helper.var_idx("ElementCompute alpha", i) + " = ElementCompute(1);\n" + addbias = helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i]) + if addbias: + code += " " + helper.var_idx("ElementCompute beta", i) + " = ElementCompute(1);\n" + else: + code += " " + helper.var_idx("ElementCompute beta", i) + " = ElementCompute(0);\n" + + code += " " + "size_t flops = 0;\n" + + for i in range(self.b2b_num): + m = self.fuse_gemm_info[i]['mnk'][0] + n = self.fuse_gemm_info[i]['mnk'][1] + k = self.fuse_gemm_info[i]['mnk'][2] + + bias_shape = helper.get_epilogue_bias_shape(self.fuse_gemm_info[i]) + + this_k = "K0" + if (i > 0): + this_k = str(k) + + code += " " + "flops += size_t(2) * size_t(M) * size_t(B) * " + "size_t(" + str(n) + ") * size_t(" + this_k + ");\n" + + code += " " + helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(" + "M" + ", " + str(n) + ", " + this_k + ");\n" + + code += " " + helper.var_idx("memory_unit Mat_A", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_", i) + ".k());\n" + code += " " + helper.var_idx("memory_unit Mat_B", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".n() * problem_size_", i) + ".k());\n" + code += " " + helper.var_idx("memory_unit Mat_C", i) + "(B * " + str(bias_shape[0]) + " * " + str(bias_shape[1]) + ");\n" + code += " " + helper.var_idx("memory_unit Mat_D_cutlass_ref", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_", i) + ".n());\n" + + code += " " + helper.var_idx("Mat_A", i) + ".init();\n" + code += " " + helper.var_idx("Mat_B", i) + ".init();\n" + code += " " + helper.var_idx("Mat_C", i) + ".init();\n" + + + + code += " " + helper.var_idx("memory_unit Mat_D", self.b2b_num - 1) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_",self.b2b_num - 1) + ".n());\n" + + params = [] + params.append("M") + params.append("B") + + 
params.append("Mat_A0.device_ptr") + for i in range(self.b2b_num): + params.append(helper.var_idx("Mat_B", i) + ".device_ptr") + params.append(helper.var_idx("Mat_C", i) + ".device_ptr") + if i != self.b2b_num-1: + params.append(helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr") + params.append(helper.var_idx("Mat_D", self.b2b_num - 1) + ".device_ptr") + + code += " " + "Param arguments = {\n" + code += " " + " " + "M,\n" + code += " " + " " + "K0,\n" + code += " " + " " + "B,\n" + + code += " " + " " + "reinterpret_cast(Mat_A0.device_ptr),\n" + cnt = 1 + for i in range(self.b2b_num): + bias_flag = helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i]) + code += " " + " " + "reinterpret_cast(" + helper.var_idx("Mat_B", i) + ".device_ptr" + "),\n" + cnt += 1 + if bias_flag: + code += " " + " " + "reinterpret_cast(" + helper.var_idx("Mat_C", i) + ".device_ptr" + "),\n" + cnt += 1 + else: + code += " " + " " + "reinterpret_cast(NULL),\n" + + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_value = str(arg[2]) + + code += " " + " " + helper.type_2_cutlass_type(acc_tp) + "(" + arg_value + "),\n" + + if i != self.b2b_num - 1: + code += " " + " " + "reinterpret_cast(" + helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr" + "),\n" + else: + code += " " + " " + "reinterpret_cast(" + helper.var_idx("Mat_D", i) + ".device_ptr" + ")};\n" + + + + + code += " " + "TI(FUSED_CUTLASS);\n" + code += " " + "for(int i = 0; i < 100; i++){\n" + code += " " + " " + "one_api(arguments, sm, NULL);\n" + + code += " " + "}\n" + code += " " + "TO(FUSED_CUTLASS, \"FUSED_CUTLASS\", 100);\n" + + code += "\n" + + for i in range(self.b2b_num): + code_this = "" + + N_str = str(self.fuse_gemm_info[i]['mnk'][1]) + + code_this += " " + helper.var_idx("typename Gemm", i) + helper.var_idx("::Arguments arguments_", i) + "{\n" + code_this += " " + " " + helper.var_idx("problem_size_", i) + ",\n" + ldmA = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + ldmA = "K0" + ldmB = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + ldmB = "K0" + ldmC = str(self.fuse_gemm_info[i]['mnk'][1]) + + ldmBias = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i])) + + if self.fuse_gemm_info[i]['A_format'] is 'Col': + ldmA = "M" + if self.fuse_gemm_info[i]['B_format'] is 'Row': + ldmB = str(self.fuse_gemm_info[i]['mnk'][1]) + if self.fuse_gemm_info[i]['C_format'] is 'Col': + ldmC = "M" + + if i == 0: + code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("Mat_A", i) + ".device_ptr), " + ldmA + "}, " + "M * " + ldmA + ",\n" + else: + code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("Mat_D_cutlass_ref", i - 1) + ".device_ptr), " + ldmA + "}, " + "M * " + ldmA + ",\n" + + code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("Mat_B", i) + ".device_ptr), " + ldmB + "}, " + N_str + " * " + ldmB + ",\n" + + M_bias = str(helper.get_epilogue_bias_shape(self.fuse_gemm_info[i])[0]) + + code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("Mat_C", i) + ".device_ptr), " + ldmBias + "}, " + M_bias + " * " + N_str + ",\n" + code_this += " " + " " + "{reinterpret_cast<" + 
helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr), " + ldmC + "}, " + "M * " + ldmC + ",\n" + code_this += " " + " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i) + for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]): + arg_value = str(epilogue_arg[2]) + code_this += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_value) + ")" + code_this += " " + " },\n" + code_this += " " + " " + "B};\n" + + code += code_this + + + + code += " " + "TI(UNFUSED_CUTLASS);\n" + code += " " + "for(int i = 0; i < 100; i++){\n" + code += " " + " " + self.gen_class_name + "_verify(\n" + for i in range(self.b2b_num): + code += " " + " " + " " + helper.var_idx("arguments_", i) + ",\n" + code += " " + " " + " " + "NULL);\n" + + code += " " + "}\n" + code += " " + "TO(UNFUSED_CUTLASS, \"UNFUSED_CUTLASS\", 100);\n" + + code += " " + helper.var_idx("Mat_D_cutlass_ref", self.b2b_num - 1) + ".d2h();\n" + code += " " + helper.var_idx("Mat_D", self.b2b_num - 1) + ".d2h();\n" + code += " " + helper.var_idx("check_result(Mat_D_cutlass_ref", self.b2b_num - 1) + helper.var_idx(".host_ptr, Mat_D", self.b2b_num - 1) \ + + helper.var_idx(".host_ptr, Mat_D", self.b2b_num - 1) + ".elements);\n" + + code += "\n\n}\n" + + with open(self.sample_dir + "sample.cu", "w+") as f: + f.write(code) diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py new file mode 100644 index 00000000..727a737c --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py @@ -0,0 +1,1013 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
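+#
+# gen_threadblock.py emits the threadblock-level headers used by the generated
+# kernel: the DefaultB2bMma specialization (MmaCore, iterator and fused-epilogue
+# typedefs) and the B2bMmaPipelined mainloop, with 1-stage and 2-stage variants
+# for the first GEMM.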
+# +################################################################################################# + +import gen_ir +import helper + + +class gen_default_b2b_mma: + def __init__(self, template_param, gen_class_name, b2b_num,cutlass_deps_root, project_root): + self.gen_class_name = "DefaultB2bMma" + self.template_param = template_param + self.b2b_num = b2b_num + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + def gen_include_header(self): + code = ''' +/* Auto Generated code - Do not edit.*/ + +#pragma once + +#include \"{cutlass_dir}cutlass/cutlass.h\" +#include \"{cutlass_dir}cutlass/numeric_types.h\" +#include \"{cutlass_dir}cutlass/arch/arch.h\" + +#include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator.h\" +#include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm70.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm75.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm80.h\" + +#include \"../threadblock/b2b_mma_pipelined.h\" +#include \"../../fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h\" +#include \"../../fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h\" +#include \"../../fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h\" +'''.format(cutlass_dir=self.cutlass_deps_root) + return code + + + def gen_using_MmaCore(self, stage): + threadBlockShape = "ThreadblockShape" + warpShape = "WarpShape" + instrunctionShape = "InstructionShape" + Mma_typename = "typename cutlass::gemm::threadblock::DefaultMmaCore" + + + gen_code = "" + + for i in range(self.b2b_num): + code_using = "using MmaCore" + str(i) + gen_code += code_using + " = " + gen_ir.gen_declare_template_struct(Mma_typename, \ + helper.var_idx(threadBlockShape, i), helper.var_idx(warpShape, i), instrunctionShape, \ + "ElementA", "LayoutA", \ + helper.var_idx("ElementB", i), helper.var_idx("LayoutB", i), \ + helper.var_idx("ElementAccumulator", i), "layout::RowMajor", \ + "OperatorClass", str(stage), "Operator") + return gen_code + + def gen_using_FusedAddBiasEpilouge(self): + gen_code = "" + for i in range(self.b2b_num - 1): + code_using = helper.var_idx("using FusedAddBiasEpilouge", i) + epilouge_name = "typename cutlass::epilogue::threadblock::DefaultFusedBiasActEpilogueTensorOp" + template_args = helper.var_idx("::Epilogue" + + gen_code += code_using + " = " + epilouge_name + template_args + ";\n" + + return gen_code + + + def gen_using_Iterator(self): + code_using = "using IteratorA0" + iterator_typename = "cutlass::transform::threadblock::PredicatedTileIterator" + MmaCore = "MmaCore0" + matrix_shape = "cutlass::MatrixShape<" + MmaCore + "::Shape::kM, " + MmaCore + "::Shape::kK>" + iterator_map = "typename " + MmaCore + "::IteratorThreadMapA" + gen_code = code_using + " = " + gen_ir.gen_declare_template_struct(iterator_typename, \ + matrix_shape, "ElementA", "LayoutA", "1", iterator_map, "AlignmentA_") + + for i in range(self.b2b_num): + code_using = "using IteratorB" + str(i) + iterator_typename = "cutlass::transform::threadblock::PredicatedTileIterator" + MmaCore = "MmaCore" + str(i) + matrix_shape = "cutlass::MatrixShape<" + MmaCore + "::Shape::kK, " + MmaCore + "::Shape::kN>" + iterator_map = "typename " + MmaCore + "::IteratorThreadMapB" + + gen_code += code_using + " = " + gen_ir.gen_declare_template_struct(iterator_typename, \ + matrix_shape, 
helper.var_idx("ElementB", i), helper.var_idx("LayoutB", i), "0", iterator_map, "AlignmentB_") + + return gen_code + + def gen_fragment_iterator(self): + gen_code = "using AccumulatorLayout = cutlass::layout::ColumnMajor;\n" + + for i in range(1, self.b2b_num): + code_using = "using FragmentIteratorA" + str(i) + iterator_typename = "cutlass::gemm::warp::MmaTensorOpPureFragmentIterator" + curr_MmaCore = "MmaCore" + str(i) + prev_MmaCore = "MmaCore" + str(i - 1) + Matrix_shape_curr = "cutlass::MatrixShape<" + curr_MmaCore + "::WarpShape::kM, " + curr_MmaCore + "::InstructionShape::kK>" + Matrix_shape_prev = "cutlass::MatrixShape<" + prev_MmaCore + "::WarpShape::kM, " + prev_MmaCore + "::WarpShape::kN>" + Curr_shape_kK = curr_MmaCore + "::Shape::kK" + + gen_code += code_using + " = " + gen_ir.gen_declare_template_struct(iterator_typename, \ + Matrix_shape_curr, Matrix_shape_prev, Curr_shape_kK, \ + helper.var_idx("ElementAccumulator", i-1), "ElementA", \ + "AccumulatorLayout", "InstructionShape_", "true") + + return gen_code + + def gen_threadblockmma(self): + code_using = "using ThreadblockB2bMma" + iterator_typename = "cutlass::gemm::threadblock::B2bMmaPipelined" + + MmaPipelined_param_Mma0_shape = "typename MmaCore0::Shape" + MmaPipelined_param_Mma0_iteratorA = "IteratorA0" + MmaPipelined_param_Mma0_smemIteratorA = "typename MmaCore0::SmemIteratorA" + MmaPipelined_param_Mma0_iteratorB = "IteratorB0" + MmaPipelined_param_Mma0_smemIteratorB = "typename MmaCore0::SmemIteratorB" + + MmaPipelined_param_list = MmaPipelined_param_Mma0_shape + ", " + MmaPipelined_param_Mma0_iteratorA + ", " + MmaPipelined_param_Mma0_smemIteratorA + ", " + MmaPipelined_param_Mma0_iteratorB + ", " + MmaPipelined_param_Mma0_smemIteratorB + ", " + + for i in range(1, self.b2b_num): + MmaPipelined_param_Mma_shape = "typename MmaCore" + str(i) + "::Shape" + MmaPipelined_param_Mma_iteratorA = "FragmentIteratorA" + str(i) + MmaPipelined_param_Mma_iteratorB = "IteratorB" + str(i) + MmaPipelined_param_Mma_smemIteratorB = "typename MmaCore" + str(i) + "::SmemIteratorB" + + MmaPipelined_param_list += MmaPipelined_param_Mma_shape + ", " + MmaPipelined_param_Mma_iteratorA + ", " + MmaPipelined_param_Mma_iteratorB + ", " + MmaPipelined_param_Mma_smemIteratorB + ", " + + MmaPipelined_param_list += "ElementAccumulator0, layout::RowMajor, " + + for i in range(self.b2b_num - 1): + epilouge_name = "EpilogueOutputOp" + str(i) + MmaPipelined_param_list += epilouge_name + ", " + + for i in range(self.b2b_num - 1): + epilouge_name = "FusedAddBiasEpilouge" + str(i) + MmaPipelined_param_list += epilouge_name + ", " + + for i in range(self.b2b_num): + MmaPolicy = "typename MmaCore" + str(i) + "::MmaPolicy" + MmaPipelined_param_list += MmaPolicy + ", " + + + cnt = 0 + for i in range(self.b2b_num): + MmaStage = helper.var_idx("Stages", i) + final = ", " + if cnt == self.b2b_num - 1: + final = "" + MmaPipelined_param_list += MmaStage + final + cnt += 1 + + gen_code = code_using + " = " + gen_ir.gen_declare_template_struct(iterator_typename, MmaPipelined_param_list) + + return gen_code + + + + def gen_code(self): + gen_using = '' + # Generate default template struct + gen_code = gen_ir.gen_template_struct(self.gen_class_name, self.template_param, "", speicalized = None, set_default=False) + + # Generate specialized template struct + + mmacore_codebody = self.gen_using_MmaCore(2) + iterator_codebody = self.gen_using_Iterator() + fragment_iterator_codebody = self.gen_fragment_iterator() + epilogue_iterator_codebody = 
self.gen_using_FusedAddBiasEpilouge() + threadBlockMma = self.gen_threadblockmma() + specialized_code = mmacore_codebody + iterator_codebody + fragment_iterator_codebody + epilogue_iterator_codebody + threadBlockMma + + # Specialize layout C -> cutlass::layout::RowMajor + + rtn_template_args, speicalized_template_args = gen_ir.filtered_param(self.template_param, [ ('LayoutD', "cutlass::layout::RowMajor")], keep_= True) + + gen_speical_code = gen_ir.gen_template_struct(self.gen_class_name, rtn_template_args, specialized_code, speicalized = speicalized_template_args, set_default=False) + code = gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("threadblock", gen_code + gen_speical_code))) + + return self.gen_include_header() + code + + +class gen_b2b_mme_pipelined: + def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root): + self.gen_class_name = "B2bMmaPipelined" + self.template_param = template_param + self.b2b_num = b2b_num + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + + def gen_include_header(self): + code = ''' +#pragma once + +#include \"{cutlass_dir}cutlass/cutlass.h\" +#include \"{cutlass_dir}cutlass/array.h\" +#include \"{cutlass_dir}cutlass/aligned_buffer.h\" +#include \"{cutlass_dir}cutlass/numeric_conversion.h\" + +#include \"{cutlass_dir}cutlass/numeric_types.h\" +#include \"{cutlass_dir}cutlass/matrix_shape.h\" + +#include \"{cutlass_dir}cutlass/gemm/gemm.h\" +#include \"{cutlass_dir}cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h\" + +#include \"../threadblock/b2b_mma_base.h\"\n'''.format(cutlass_dir = self.cutlass_deps_root) + return code + + + def gen_using(self): + code_using = "using FragmentA0 = typename IteratorA0::Fragment;\n" + + code_using += "using Base = B2bMmaBase<" + for i in range(self.b2b_num): + code_using += helper.var_idx("Shape", i) + "_, " + for i in range(self.b2b_num): + code_using += helper.var_idx("Policy", i) + "_, " + for i in range(self.b2b_num): + code_using += helper.var_idx("Stage", i) + "_, " + code_using = code_using[: -2] + ">;\n" + + + for i in range(self.b2b_num): + code_using += helper.var_idx("using FragmentB", i) + helper.var_idx(" = typename IteratorB", i) + "::Fragment;\n" + code_using += helper.var_idx("using FragmentC", i) + helper.var_idx(" = typename Policy", i) + "::Operator::FragmentC;\n" + code_using += helper.var_idx("using Operator", i) + helper.var_idx(" = typename Policy", i) + "::Operator;\n" + + for i in range(self.b2b_num - 1): + code_using += helper.var_idx("using IteratorC", i) + helper.var_idx(" = typename FusedAddBiasEpilogue", i) + "::OutputTileIterator;\n" + + code_using += "using ArchTag = typename Policy0::Operator::ArchTag;\n" + code_using += "static ComplexTransform const kTransformA0 = Operator0::kTransformA;\n" + + for i in range(self.b2b_num): + code_using += helper.var_idx("static ComplexTransform const kTransformB", i) + helper.var_idx(" = Operator", i) + "::kTransformB;\n" + + code_using += "private:\n" + code_using += "using WarpFragmentA0 = typename Operator0::FragmentA;\n" + code_using += "using WarpFragmentB0 = typename Operator0::FragmentB;\n" + + for i in range(1, self.b2b_num): + code_using += helper.var_idx("using WarpFragmentA", i) + helper.var_idx(" = typename FragmentIteratorA", i) + "::Fragment;\n" + code_using += helper.var_idx("using WarpFragmentB", i) + helper.var_idx(" = typename Operator", i) + "::FragmentB;\n" + + code_using += "protected:\n" + + code_using += "SmemIteratorA0 
smem_iterator_A_;\n" + + for i in range(self.b2b_num): + code_using += helper.var_idx("SmemIteratorB", i) + helper.var_idx(" smem_iterator_B", i) + "_;\n" + + return code_using + + + def gen_operator(self, first_use_1stage = False): + code = "" + def gen_operator_param(b2b_num): + param_code = "" + param_code += "int gemm_k_iterations_0,\n" + param_code += helper.var_idx("FragmentC", b2b_num-1) + helper.var_idx(" &accum", b2b_num-1) + ",\n" + param_code += "IteratorA0 iterator_A,\n" + + for i in range(b2b_num): + param_code += helper.var_idx("IteratorB", i) + " " + helper.var_idx("iterator_B", i) + ",\n" + + param_code += "FragmentC0 const &src_accum, \n" + + for i in range(b2b_num - 1): + param_code += helper.var_idx("OutputOp", i) + " " + helper.var_idx("output_op_", i) + ",\n" + for i in range(b2b_num - 1): + param_code += helper.var_idx("FusedAddBiasEpilogue", i) + " " + helper.var_idx("epilogue_", i) + ",\n" + for i in range(b2b_num - 1): + param_code += helper.var_idx("IteratorC", i) + " " + helper.var_idx("iterator_C", i) + ",\n" + + + param_code += "TransformA0 transform_A0 = TransformA0(), \n" + + for i in range(b2b_num): + final = "(),\n" + if i == b2b_num - 1: + final = "()\n" + param_code += helper.var_idx("TransformB", i) + " " + helper.var_idx("transform_B", i) + " = " +helper.var_idx("TransformB", i) + final + + return param_code + + + + def gen_first_gemm_1stage(b2b_num): + accu_code = " FragmentC0 accum0 = src_accum;\n" + if b2b_num == 1: + accu_code = " accum0 = src_accum;\n" + + code ="\ +\n\ + FragmentA0 tb_frag_A;\n\ + FragmentB0 tb_frag_B0;\n\ +\n\ + int smem_write_stage_idx = 1;\n\ +\n\ + tb_frag_A.clear();\n\ + tb_frag_B0.clear();\n\ +\n\ + // The last kblock is loaded in the prolog\n\ + iterator_A.load(tb_frag_A);\n\ + iterator_B0.load(tb_frag_B0);\n\ +\n\ + ++iterator_A;\n\ + ++iterator_B0;\n\ +\n\ + WarpFragmentA0 warp_frag_A0;\n\ + WarpFragmentB0 warp_frag_B0;\n\ +\n\ + Operator0 warp_mma0;\n\ +\n\ + // Avoid reading out of bounds\n\ + if (gemm_k_iterations_0 <= 1) {\n\ + iterator_A.clear_mask();\n\ + iterator_B0.clear_mask();\n\ + }\n\ +\n\ + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing \n\ + // shared memory loads (which have the tighest latency requirement).\n\ +\n\ + //\n\ + // Mainloop\n\ + //\n\ +\n\ + // Note: The main loop does not support Base::WarpGemmIterations == 2.\n\ + CUTLASS_GEMM_LOOP\n\ + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {\n\ +\n\ + this->smem_iterator_A_.store(tb_frag_A);\n\ + this->smem_iterator_B0_.store(tb_frag_B0);\n\ +\n\ + __syncthreads();\n\ + //\n\ + // Loop over GEMM K dimension\n\ + //\n\ +\n\ + CUTLASS_PRAGMA_UNROLL\n\ + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {\n\ +\n\ + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group\n\ + // as the case may be.\n\ +\n\ + this->warp_tile_iterator_A0_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations0);\n\ + this->warp_tile_iterator_B0_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations0);\n\ +\n\ + this->warp_tile_iterator_A0_.load(warp_frag_A0);\n\ + this->warp_tile_iterator_B0_.load(warp_frag_B0);\n\ +\n\ + ++this->warp_tile_iterator_A0_;\n\ + ++this->warp_tile_iterator_B0_;\n\ +\n\ + warp_mma0(accum0, warp_frag_A0, warp_frag_B0, accum0);\n\ + }\n\ + this->warp_tile_iterator_A0_.add_tile_offset({0, -Policy0::kPartitionsK * Base::kWarpGemmIterations0});\n\ + this->warp_tile_iterator_B0_.add_tile_offset({-Policy0::kPartitionsK * 
Base::kWarpGemmIterations0, 0});\n\ +\n\ + __syncthreads();\n\ + iterator_A.load(tb_frag_A);\n\ + iterator_B0.load(tb_frag_B0);\n\ +\n\ + ++iterator_A;\n\ + ++iterator_B0;\n\ +\n\ + if(gemm_k_iterations_0 <= 2) {\n\ + iterator_A.clear_mask();\n\ + iterator_B0.clear_mask();\n\ + }\n\ + }\n" + + return accu_code + code + + + def gen_first_gemm_2stage(b2b_num): + + accu_code = " FragmentC0 accum0 = src_accum;\n" + if b2b_num == 1: + accu_code = " accum0 = src_accum;\n" + + code ="\ +\n\ + FragmentA0 tb_frag_A;\n\ + FragmentB0 tb_frag_B0;\n\ +\n\ + tb_frag_A.clear();\n\ + tb_frag_B0.clear();\n\ +\n\ + // The last kblock is loaded in the prolog\n\ + iterator_A.load(tb_frag_A);\n\ + iterator_B0.load(tb_frag_B0);\n\ +\n\ + ++iterator_A;\n\ + ++iterator_B0;\n\ +\n\ + this->smem_iterator_A_.store(tb_frag_A);\n\ + this->smem_iterator_B0_.store(tb_frag_B0);\n\ +\n\ + ++this->smem_iterator_A_;\n\ + ++this->smem_iterator_B0_;\n\ +\n\ + __syncthreads();\n\ +\n\ + // Pair of fragments used to overlap shared memory loads and math instructions\n\ + WarpFragmentA0 warp_frag_A0[2];\n\ + WarpFragmentB0 warp_frag_B0[2];\n\ +\n\ + this->warp_tile_iterator_A0_.set_kgroup_index(0);\n\ + this->warp_tile_iterator_B0_.set_kgroup_index(0);\n\ +\n\ + this->warp_tile_iterator_A0_.load(warp_frag_A0[0]);\n\ + this->warp_tile_iterator_B0_.load(warp_frag_B0[0]);\n\ +\n\ + ++this->warp_tile_iterator_A0_;\n\ + ++this->warp_tile_iterator_B0_;\n\ +\n\ + Operator0 warp_mma0;\n\ +\n\ + int smem_write_stage_idx = 1;\n\ +\n\ + // Avoid reading out of bounds\n\ + if (gemm_k_iterations_0 <= 1) {\n\ + iterator_A.clear_mask();\n\ + iterator_B0.clear_mask();\n\ + }\n\ +\n\ + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing \n\ + // shared memory loads (which have the tighest latency requirement).\n\ + iterator_A.load(tb_frag_A);\n\ +\n\ + //\n\ + // Mainloop\n\ + //\n\ +\n\ + // Note: The main loop does not support Base::WarpGemmIterations == 2.\n\ + CUTLASS_GEMM_LOOP\n\ + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {\n\ +\n\ + //\n\ + // Loop over GEMM K dimension\n\ + //\n\ +\n\ + CUTLASS_PRAGMA_UNROLL\n\ + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {\n\ +\n\ + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group\n\ + // as the case may be.\n\ +\n\ + if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {\n\ +\n\ + // Write fragments to shared memory\n\ + this->smem_iterator_A_.store(tb_frag_A);\n\ +\n\ + this->smem_iterator_B0_.store(tb_frag_B0);\n\ +\n\ + __syncthreads();\n\ +\n\ + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing \n\ + // shared memory loads (which have the tighest latency requirement).\n\ + iterator_A.load(tb_frag_A);\n\ + \n\ + ++this->smem_iterator_B0_;\n\ + ++this->smem_iterator_A_;\n\ + \n\ +\n\ + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory\n\ + if (smem_write_stage_idx == 1) {\n\ + this->smem_iterator_A_.add_tile_offset({0, -Base::Stage0});\n\ + this->smem_iterator_B0_.add_tile_offset({-Base::Stage0, 0});\n\ + }\n\ + else {\n\ + this->warp_tile_iterator_A0_.add_tile_offset(\n\ + {0, -Base::Stage0 * Policy0::kPartitionsK * Base::kWarpGemmIterations0});\n\ + this->warp_tile_iterator_B0_.add_tile_offset(\n\ + {-Base::Stage0 * Policy0::kPartitionsK * Base::kWarpGemmIterations0,\n\ + 0});\n\ + }\n\ +\n\ + smem_write_stage_idx ^= 1;\n\ + }\n\ +\n\ + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % 
Base::kWarpGemmIterations0);\n\ + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);\n\ + \n\ + this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]);\n\ + this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]);\n\ +\n\ + ++this->warp_tile_iterator_A0_;\n\ + ++this->warp_tile_iterator_B0_;\n\ +\n\ + if (warp_mma_k == 0) {\n\ +\n\ + iterator_B0.load(tb_frag_B0);\n\ +\n\ + ++iterator_A;\n\ + ++iterator_B0;\n\ +\n\ + // Avoid reading out of bounds if this was the last loop iteration\n\ + if (gemm_k_iterations_0 <= 2) {\n\ + iterator_A.clear_mask();\n\ + iterator_B0.clear_mask();\n\ + }\n\ + }\n\ +\n\ + warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], warp_frag_B0[warp_mma_k % 2], accum0);\n\ + }\n\ + }\n" + return accu_code + code + + def gen_other_gemms_2stage(b2b_num): + + code = "" + + def gemm_teamplate(id): + code = "// " + str(id + 1) + " Gemm" + code += " /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile\n" + + code += " " + helper.var_idx("FragmentC", id - 1) + helper.var_idx(" after_epilouge_accu", id - 1) + ";\n" + code += " " + helper.var_idx("epilogue_", id - 1) + helper.var_idx("(output_op_", id - 1) + helper.var_idx(", accum", id - 1) \ + + helper.var_idx(", after_epilouge_accu", id - 1) + helper.var_idx(", iterator_C", id - 1) +");\n" + + # FragmentIteratorA1 warp_tile_iterator_A1_(accum0); + code += " " + helper.var_idx("FragmentIteratorA", id) + helper.var_idx(" warp_tile_iterator_A", id) +"_(" + helper.var_idx("after_epilouge_accu", id - 1) + ");\n" + # FragmentB1 tb_frag_B1; + code += " " + helper.var_idx("FragmentB", id) + " " + helper.var_idx("tb_frag_B", id) + ";\n" + # tb_frag_B1.clear(); + code += " " + helper.var_idx("tb_frag_B", id) + ".clear();\n" + # iterator_B1.load(tb_frag_B1); + code += " " + helper.var_idx("iterator_B", id) + ".load(" + helper.var_idx("tb_frag_B", id) + ");\n" + # ++iterator_B1; + code += " " + "++" + helper.var_idx("iterator_B", id) + ";\n" + # this->smem_iterator_B1_.store(tb_frag_B1); + code += " " + helper.var_idx("this->smem_iterator_B", id) + "_.store(" + helper.var_idx("tb_frag_B", id) + ");\n" + # ++this->smem_iterator_B1_; + code += " " + helper.var_idx("++this->smem_iterator_B", id) + "_;\n" + # __syncthreads(); + code += " " + "__syncthreads();\n" + # WarpFragmentA1 warp_frag_A1[2]; + code += " " + helper.var_idx("WarpFragmentA", id) + helper.var_idx(" warp_frag_A", id) + "[2];\n" + # WarpFragmentB1 warp_frag_B1[2]; + code += " " + helper.var_idx("WarpFragmentB", id) + helper.var_idx(" warp_frag_B", id) + "[2];\n" + # this->warp_tile_iterator_B1_.set_kgroup_index(0); + code += " " + helper.var_idx("this->warp_tile_iterator_B", id) + "_.set_kgroup_index(0);\n" + # warp_tile_iterator_A1_.load(warp_frag_A1[0], output_op_0); + code += " " + helper.var_idx("warp_tile_iterator_A", id) + helper.var_idx("_.load(warp_frag_A", id) + "[0]);\n" + # this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); + code += " " + helper.var_idx("this->warp_tile_iterator_B", id) + helper.var_idx("_.load(warp_frag_B", id) + "[0]);\n" + # ++warp_tile_iterator_A1_; + code += " " + helper.var_idx("++warp_tile_iterator_A", id) + "_;\n" + # ++this->warp_tile_iterator_B1_; + code += " " + helper.var_idx("++this->warp_tile_iterator_B", id) + "_;\n" + # Operator1 warp_mma1; + code += " " + helper.var_idx("Operator", id) + " " + helper.var_idx("warp_mma", id) + ";\n" + # smem_write_stage_idx = 1; + code += " " + "smem_write_stage_idx = 1;\n" + # int 
gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; + code += " " + helper.var_idx("int gemm_k_iterations_", id) + " = " + helper.var_idx("FragmentIteratorA", id) + helper.var_idx("::Policy::kIterations / Base::kWarpGemmIterations", id) +";\n" + # if (gemm_k_iterations_1 <= 1) { + # iterator_B1.clear_mask(); + # } + code += " " + "if (" + helper.var_idx("gemm_k_iterations_", id) + " <= 1 ){\n" \ + + " " + " " + helper.var_idx("iterator_B", id) + ".clear_mask();\n" \ + + " " +"}\n" + # CUTLASS_PRAGMA_UNROLL + code += " " + "CUTLASS_PRAGMA_UNROLL\n" + # for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { + code += " " + helper.var_idx("for (; gemm_k_iterations_", id) + helper.var_idx(" > 0; --gemm_k_iterations_", id) + ") {\n" + # CUTLASS_PRAGMA_UNROLL + code += " " + " " + "CUTLASS_PRAGMA_UNROLL\n" + # for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { + code += " " + " " + helper.var_idx("for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations", id) + "; ++warp_mma_k) {\n" + # if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { + code += " " + " " + " " + helper.var_idx("if (warp_mma_k == Base::kWarpGemmIterations", id) + " - 1) {\n" + # this->smem_iterator_B1_.store(tb_frag_B1); + code += " " + " " + " " + " " + helper.var_idx(" this->smem_iterator_B", id) + helper.var_idx("_.store(tb_frag_B", id) + ");\n" + # __syncthreads(); + code += " " + " " + " " + " " + "__syncthreads();\n" + # ++smem_iterator_B1_; + code += " " + " " + " " + " " + helper.var_idx(" ++smem_iterator_B", id) + "_;\n" + # if (smem_write_stage_idx == 1) { + # smem_iterator_B1_.add_tile_offset({-Base::Stage, 0}); + # } + code += " " + " " + " " + " " + "if ( smem_write_stage_idx == 1 ) {\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("smem_iterator_B", id) + helper.var_idx("_.add_tile_offset({-Base::Stage", i) + ", 0});\n" \ + + " " + " " + " " + " " +"}\n" + # else { + # this->warp_tile_iterator_B1_.add_tile_offset( + # {-Base::Stage * Policy1::kPartitionsK * + # Base::kWarpGemmIterations1, + # 0}); + # } + code += " " + " " + " " + " " + "else {\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("this->warp_tile_iterator_B", id) + "_.add_tile_offset(\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("{-Base::Stage", id) + helper.var_idx(" * Policy", id) + "::kPartitionsK *\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("Base::kWarpGemmIterations", id) + ",\n" \ + + " " + " " + " " + " " + " " + "0});\n" \ + + " " + " " + " " + " " + "}\n" + + # smem_write_stage_idx ^= 1; + # } + code += " " + " " + " " + " " + "smem_write_stage_idx ^= 1;\n" \ + + " " + " " + " " + "}\n" + + # this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + code += " " + " " + " " + helper.var_idx("this->warp_tile_iterator_B", id) + helper.var_idx("_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations", id) + ");\n" + # warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], output_op_0); + code += " " + " " + " " + helper.var_idx("warp_tile_iterator_A", id) + helper.var_idx("_.load(warp_frag_A", id) + "[(warp_mma_k + 1) % 2]);\n" + # this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); + code += " " + " " + " " + helper.var_idx("this->warp_tile_iterator_B", id) + helper.var_idx("_.load(warp_frag_B", id) + "[(warp_mma_k + 1) % 2]);\n" + # ++warp_tile_iterator_A1_; + code += " " + " " + " " + helper.var_idx("++warp_tile_iterator_A", id) + "_;\n" + # 
++this->warp_tile_iterator_B1_; + code += " " + " " + " " + helper.var_idx("++this->warp_tile_iterator_B", id) + "_;\n" + # if (warp_mma_k == 0) { + # iterator_B1.load(tb_frag_B1); + # ++iterator_B1; + # if (gemm_k_iterations_1 <= 2) { + # iterator_B1.clear_mask(); + # } + # } + code += " " + " " + " " + " if (warp_mma_k == 0) {\n" \ + + " " + " " + " " + " " + helper.var_idx("iterator_B", id) + helper.var_idx(".load(tb_frag_B", id) + ");\n" \ + + " " + " " + " " + " " + helper.var_idx("++iterator_B", id) +";\n" \ + + " " + " " + " " + " " + helper.var_idx("if (gemm_k_iterations_", id) +" <= 2) {\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("iterator_B", id) + ".clear_mask();\n" \ + + " " + " " + " " + " " + "}\n" \ + + " " + " " + " " + "}\n" + # warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], warp_frag_B1[warp_mma_k % 2], accum); + # } + # } + code += " " + " " + " " + helper.var_idx("warp_mma", id) + helper.var_idx("(accum", id) + helper.var_idx(", warp_frag_A", id) + helper.var_idx("[warp_mma_k % 2], warp_frag_B", id) + helper.var_idx("[warp_mma_k % 2], accum", id) + ");\n" \ + + " " + " " + "}\n" \ + + " " + "}\n\n\n" + + return code + + for i in range (1, b2b_num): + clear_accu = "" + if i != b2b_num - 1: + clear_accu = " " + helper.var_idx("FragmentC", i) + helper.var_idx(" accum", i) +";\n" + clear_accu += " " + helper.var_idx("accum", i) +".clear();\n" + code += clear_accu + gemm_teamplate(i) + + return code + + operator_code = " CUTLASS_DEVICE\n\ + void operator()(\n " + gen_operator_param(self.b2b_num) + ") {\n" + if first_use_1stage: + operator_code += gen_first_gemm_1stage(self.b2b_num) + else: + operator_code += gen_first_gemm_2stage(self.b2b_num) + operator_code += gen_other_gemms_2stage(self.b2b_num) + "}\n" + return operator_code + + def gen_construct_func(self): + name = self.gen_class_name + func_code = "CUTLASS_DEVICE\n" + func_code += name + "(\n" \ + + " " + "typename Base::B2bMmaSharedStorage &shared_storage,\n" \ + + " " + "int thread_idx,\n" \ + + " " + "int warp_idx,\n" \ + + " " + "int lane_idx\n" \ + + "):\n" + func_code += " " + "Base(shared_storage, thread_idx, warp_idx, lane_idx),\n" \ + + " " + "smem_iterator_A_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx),\n" + + for i in range(self.b2b_num): + final = ",\n" + if i == self.b2b_num - 1: + final = " {\n" + func_code += helper.var_idx("smem_iterator_B", i) + helper.var_idx("_(shared_storage.sharedStorage", i) +".operand_B_ref(), thread_idx)" + final + + func_code += " " + "int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);\n" + func_code += " " + "int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);\n" + + func_code += " " + "int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM;\n" + func_code += " " + "int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM;\n" + + for i in range(self.b2b_num): + func_code += " " + helper.var_idx("int tile_offset_k", i) + helper.var_idx(" = Base::kWarpGemmIterations", i) + " * warp_idx_k;\n" + + func_code += " " + "this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k0});\n" + + for i in range(self.b2b_num): + func_code += " " + helper.var_idx("this->warp_tile_iterator_B", i) + helper.var_idx("_.add_tile_offset({tile_offset_k", i) + ", warp_idx_n});\n" + + func_code += "}\n" + + return func_code + + def gen_member_func(self, first_use_1stage): + code = "public:\n" + code += self.gen_operator(first_use_1stage) + code += self.gen_construct_func() + + return code + + def gen_code(self, 
first_use_1stage): + + def gen_template_args(b2b_num): + template_param = [] + template_param.append(("typename", "Shape0")) + template_param.append(("typename", "IteratorA0")) + template_param.append(("typename", "SmemIteratorA0")) + template_param.append(("typename", "IteratorB0")) + template_param.append(("typename", "SmemIteratorB0")) + + for i in range(1, b2b_num): + template_param.append(("typename", helper.var_idx("Shape", i))) + template_param.append(("typename", helper.var_idx("FragmentIteratorA", i))) + template_param.append(("typename", helper.var_idx("IteratorB", i))) + template_param.append(("typename", helper.var_idx("SmemIteratorB", i))) + + template_param.append(("typename", "ElementC")) + template_param.append(("typename", "LayoutC")) + + for i in range(0, b2b_num - 1): + template_param.append(("typename", helper.var_idx("OutputOp", i))) + + for i in range(0, b2b_num - 1): + template_param.append(("typename", helper.var_idx("FusedAddBiasEpilogue", i))) + + for i in range(0, b2b_num): + template_param.append(("typename", helper.var_idx("Policy", i))) + for i in range(0, b2b_num): + template_param.append((int, helper.var_idx("Stage", i))) + + template_param.append(("typename","TransformA0", "NumericArrayConverter")) + + for i in range(0, b2b_num): + cvtr = helper.var_idx("NumericArrayConverter" + template_param.append(("typename", helper.var_idx("TransformB", i), cvtr)) + + template_param.append(("typename", "Enable", "bool")) + + return template_param + + template_param = gen_template_args(self.b2b_num) + inheritance_code = "public B2bMmaBase<" + for i in range(self.b2b_num): + inheritance_code += helper.var_idx("Shape", i) + "_, " + for i in range(self.b2b_num): + inheritance_code += helper.var_idx("Policy", i) + "_, " + for i in range(self.b2b_num - 1): + inheritance_code += helper.var_idx("Stage", i) + "_, " + inheritance_code += helper.var_idx("Stage", self.b2b_num - 1) + "_" + inheritance_code += ">" + + code_body = "" + using_code= self.gen_using() + func_code = self.gen_member_func(first_use_1stage) + + code_body = using_code + func_code + + class_code = gen_ir.gen_template_class(self.gen_class_name, template_param, code_body, inheritance_code = inheritance_code) + + code = self.gen_include_header() + code += gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("threadblock", class_code))) + # print(code) + return code + + +class gen_b2b_mma_base: + def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root): + self.gen_class_name = gen_class_name + self.template_param = template_param + self.b2b_num = b2b_num + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + def gen_include_header(self): + code = ''' +#pragma once + +#include \"{cutlass_dirs}cutlass/aligned_buffer.h\" +#include \"{cutlass_dirs}cutlass/arch/memory.h\" +#include \"{cutlass_dirs}cutlass/array.h\" +#include \"{cutlass_dirs}cutlass/cutlass.h\" +#include \"{cutlass_dirs}cutlass/gemm/gemm.h\" +#include \"{cutlass_dirs}cutlass/matrix_shape.h\" +#include \"{cutlass_dirs}cutlass/numeric_types.h\"\n'''.format(cutlass_dirs=self.cutlass_deps_root) + return code + + def gen_shared_storage(self): + code = \ +" template< \n\ + typename Shape_,\n\ + typename Policy_,\n\ + int ThisStage_\n\ +>\n\ +class SharedStorage {\n\ +public:\n\ + using Shape = Shape_;\n\ + using Policy = Policy_;\n\ + static int const ThisStage = ThisStage_;\n\ + using Operator = typename Policy::Operator;\n\ + \ + using TensorRefA = TensorRef;\n\ + 
\ + /// Tensor reference to the B operand \n\ + using TensorRefB = TensorRef;\n\ +\n\ + /// Shape of the A matrix operand in shared memory \n\ + using ShapeA = MatrixShape;\n\ +\n\ + /// Shape of the B matrix operand in shared memory\n\ + using ShapeB =\n\ + MatrixShape;\n\ +\n\ + public:\n\ +\n\ + /// Buffer for A operand\n\ + AlignedBuffer operand_A;\n\ +\n\ + /// Buffer for B operand\n\ + AlignedBuffer operand_B;\n\ +\n\ + public:\n\ +\n\ + /// Returns a layout object for the A matrix\n\ + CUTLASS_DEVICE\n\ + static typename Operator::LayoutA LayoutA() {\n\ + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});\n\ + }\n\ +\n\ + /// Returns a layout object for the B matrix\n\ + CUTLASS_HOST_DEVICE\n\ + static typename Operator::LayoutB LayoutB() {\n\ + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});\n\ + }\n\ +\n\ + /// Returns a TensorRef to the A operand\n\ + CUTLASS_HOST_DEVICE\n\ + TensorRefA operand_A_ref() {\n\ + return TensorRefA{operand_A.data(), LayoutA()};\n\ + }\n\ +\n\ + /// Returns a TensorRef to the B operand\n\ + CUTLASS_HOST_DEVICE\n\ + TensorRefB operand_B_ref() {\n\ + return TensorRefB{operand_B.data(), LayoutB()};\n\ + }\n\ + CUTLASS_HOST_DEVICE\n\ + void * get_B_Shared_ptr() {\n\ + return operand_B.data();\n\ + }\n\ + };\n" + return code + + def gen_using_and_misc(self, b2b_num): + code_using = "" + for i in range(b2b_num): + code_using += "using Operator" +str(i) + " = typename Policy" + str(i) +"::Operator;\n" + + for i in range(b2b_num): + code_using += "using WarpGemm" +str(i) + " = typename Policy" + str(i) +"::Operator::Shape;\n" + + for i in range(b2b_num): + code_using += "using WarpCount" +str(i) + " = GemmShape<" + helper.var_idx("Shape", i) +"::kM / " + helper.var_idx("WarpGemm", i) +"::kM, "\ + + helper.var_idx("Shape", i) +"::kN / " + helper.var_idx("WarpGemm", i) +"::kN, "\ + + helper.var_idx("Shape", i) +"::kK / " + helper.var_idx("WarpGemm", i) +"::kK>;\n" + + code_misc = "" + for i in range(b2b_num): + code_misc += "static int const " + helper.var_idx("kWarpGemmIterations", i) + " = (" + helper.var_idx("WarpGemm", i) + "::kK / " + helper.var_idx("Operator", i) +"::Policy::MmaShape::kK);\n" + + code = code_using + code_misc + self.gen_shared_storage() + + for i in range(b2b_num): + code += "using " + helper.var_idx("SharedStorage", i) + " = SharedStorage<" + helper.var_idx("Shape", i) + ", " + helper.var_idx("Policy", i) +", " + helper.var_idx("Stage", i) + ">;\n" + + def gen_union_shared_storage(b2b_num): + code = "" + for i in range(b2b_num): + code += " " +helper.var_idx("SharedStorage", i) + " " + helper.var_idx("sharedStorage", i) +";\n" + return code + + code += "union B2bMmaSharedStorage {\n" + gen_union_shared_storage(self.b2b_num) + "};\n" + + for i in range(b2b_num - 1): + code += helper.var_idx("void * C", i) + "_smm_ptr;\n" + + return code + + def gen_protected(self): + code = "\nprotected:\n" + code += "typename Operator0::IteratorA warp_tile_iterator_A0_;\n" + for i in range(self.b2b_num): + code += "typename Operator" +str(i) + "::IteratorB" +" warp_tile_iterator_B" + str(i) + "_;\n" + return code + + def gen_public_member(self): + code = "\npublic:\n" + + code += "CUTLASS_DEVICE\n" + code += \ + "B2bMmaBase(\n" + \ + " B2bMmaSharedStorage & shared_storage,\n" + \ + " int thread_idx,\n" + \ + " int warp_idx,\n" + \ + " int lane_idx\n" + \ + "):\n" + \ + " warp_tile_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), lane_idx),\n" + for i in range(self.b2b_num): + final = ",\n" + if i == 
self.b2b_num-1: + final = "\n" + + iterator = " warp_tile_iterator_B" + str(i) + "_" + shared_storage = "shared_storage.sharedStorage" + str(i) + ".operand_B_ref()" + code += iterator + "(" + shared_storage + ", lane_idx)" + final + + + code += "{\n" + for i in range(self.b2b_num - 1): + code += helper.var_idx(" C", i) + helper.var_idx("_smm_ptr = shared_storage.sharedStorage", i) + ".get_B_Shared_ptr();\n" + code += "}\n" + + return code + + def gen_code(self): + + tempalte_arg = [] + for i in range(self.b2b_num): + tempalte_arg.append(("typename", helper.var_idx("Shape", i))) + for i in range(self.b2b_num): + tempalte_arg.append(("typename", helper.var_idx("Policy", i))) + for i in range(self.b2b_num): + tempalte_arg.append((int, helper.var_idx("Stage", i))) + + + + code_body = self.gen_using_and_misc(self.b2b_num) + code_body += self.gen_protected() + code_body += self.gen_public_member() + + class_code = gen_ir.gen_template_class("B2bMmaBase", tempalte_arg, code_body) + + code = self.gen_include_header() + gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("threadblock", class_code))) + + return code + + +class gen_threadblock: + def __init__(self, template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root): + self.gen_class_name = gen_class_name + self.template_param = template_param + self.b2b_num = b2b_num + self.file_dir = output_dir + "/threadblock/" + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + + self.gen_b2b_mma_base = gen_b2b_mma_base(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + self.gen_b2b_mma_piplined = gen_b2b_mme_pipelined(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + self.gen_default_b2b_mma = gen_default_b2b_mma(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + + + def gen_code(self, first_use_1stage): + + base_code = self.gen_b2b_mma_base.gen_code() + print("[INFO]: Gen kernel code [b2b_mma_base.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "b2b_mma_base.h", "w+") as f: + f.write(base_code) + pipeline_code = self.gen_b2b_mma_piplined.gen_code(first_use_1stage = first_use_1stage) + print("[INFO]: Gen kernel code [b2b_mma_pipelined.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "b2b_mma_pipelined.h", "w+") as f: + f.write(pipeline_code) + default_code = self.gen_default_b2b_mma.gen_code() + print("[INFO]: Gen kernel code [default_b2b_mma.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "default_b2b_mma.h", "w+") as f: + f.write(default_code) diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py new file mode 100644 index 00000000..06b73b07 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py @@ -0,0 +1,456 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import helper +import gen_ir as ir + +class gen_turing_impl: + def __init__(self,fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.class_name = gen_class_name + self.gen_class_name = gen_class_name + "_turing_impl" + self.user_header_file = "" + for header in user_header_file: + self.user_header_file += "#include \"" + header + "\"\n" + self.output_dir = output_dir + self.b2b_num = len(fuse_gemm_info) + + self.gen_turing_unfused = gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir) + + def gen_using(self): + code_using = "using b2b_gemm = typename cutlass::gemm::device::" + self.class_name + ";" + + return code_using + "\n" + + def gen_initialize(self): + code = "" + for i in range(self.b2b_num): + code_this = "" + + code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " alpha", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(1);\n" + beta = "(1)" + + if helper.get_epilogue_add_bias_or_not(self.fuse_gemm_info[i]) is False: + beta = "(0)" + code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " beta", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + beta + ";\n" + k_str = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + k_str = "K0" + code_this += helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(M, " + str(self.fuse_gemm_info[i]['mnk'][1]) + ", " + k_str + ");\n" + code += code_this + code += "typename b2b_gemm::Arguments arguments{\n" + + for i in range(self.b2b_num): + code += " " + helper.var_idx("problem_size_", i) + ",\n" + + + code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("A", 0) + "), " + helper.var_idx("problem_size_", 0) + ".k()},\n" + + for i in range(self.b2b_num): + + ldmB = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + ldmB = "K0" + + if self.fuse_gemm_info[i]['B_format'] is 'Row': + ldmB = str(self.fuse_gemm_info[i]['mnk'][1]) + + ldmC = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i])) + + code += " " + "{reinterpret_cast<" + 
helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("B", i) + "), " + ldmB + "},\n" + code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("C", i) + "), " + ldmC + "},\n" + code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("D", self.b2b_num -1) + "), " + helper.var_idx("problem_size_", self.b2b_num - 1) + ".n()},\n" + + + for i in range(self.b2b_num): + code += " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i) + for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]): + arg_name = helper.var_idx("Epilogue", i) + "_" + epilogue_arg[1] + code += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_name) + ")" + code += "},\n" + code += " " + "Batch};\n\n" + + code += " " "b2b_gemm gemm_op;\n" + code += " " + "gemm_op.initialize(arguments);\n" + return code + "\n" + + + + def gen_run(self): + code = " " + "gemm_op(stream);\n" + + return code + + def gen_wrapper(self): + code_body = "" + + arg_lists = [] + arg_lists.append(["int", "M"]) + arg_lists.append(["int", "K0"]) + arg_lists.append(["int", "Batch"]) + arg_lists.append(["void*", helper.var_idx("A", 0)]) + for i in range(self.b2b_num): + arg_lists.append(["void*", helper.var_idx("B", i)]) + arg_lists.append(["void*", helper.var_idx("C", i)]) + arg_lists.append(["void*", helper.var_idx("D", i)]) + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_tp = arg[0] + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + arg_lists.append([arg_tp, arg_name]) + + if self.b2b_num == 1: + code_body += self.gen_turing_unfused.gen_using(False) #False -> Turing, True -> Volta + code_body += self.gen_turing_unfused.gen_initialize() + code_body += self.gen_turing_unfused.gen_run() + else: + code_body += self.gen_using() + code_body += self.gen_initialize() + code_body += self.gen_run() + + code = ir.gen_func(self.gen_class_name, arg_lists, code_body) + + return code + + def gen_code(self): + + code = self.gen_wrapper() + helper.write_2_headfile("turing_impl.h", self.output_dir, self.user_header_file + "\n" + code) + +class gen_volta_turing_fuse_act_impl: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.gen_class_name = gen_class_name + "_volta_impl" + self.user_header_file = "" + for header in user_header_file: + self.user_header_file += "#include \"" + header + "\"\n" + self.output_dir = output_dir + self.b2b_num = len(fuse_gemm_info) + + def perf_tiling(self, layer_mnk): + mnk = layer_mnk[:] + block_tile = mnk[:] + block_tile[2] = 32 # force the K tile to be 32 + + # M tile gen + block_tile[0] = 32 + + # N tile gen + if mnk[1] > 128: + block_tile[1] = 256 + elif mnk[1] > 64: + block_tile[1] = 128 + elif mnk[1] > 32: + block_tile[1] = 64 + else : + block_tile[1] = 32 + + warp_tile = block_tile[:] + if block_tile[1] == 256: + warp_tile[1] = 64 + elif block_tile[1] == 128: + warp_tile[1] = 32 + elif block_tile[1] == 64: + warp_tile[1] = 32 + else : + warp_tile[1] = 32 + + warp_tile[0] = 32 + + return block_tile, warp_tile + + + def process_epilogue(self, epilogue_tp, n, C_tp, Acc_tp): + epilogue_setted_type = epilogue_tp + cutlass_epilogue_name = "LinearCombinationRelu" + if epilogue_setted_type.lower() == 
'leakyrelu': + cutlass_epilogue_name = "LinearCombinationLeakyRelu" + elif epilogue_setted_type.lower() == 'identity': + cutlass_epilogue_name = "LinearCombination" + + + n_mod_8 = n % 4 + N_align_elements = 1 + if n_mod_8 == 0: + N_align_elements = 8 + elif n_mod_8 == 4: + N_align_elements = 4 + elif n_mod_8 == 2 or n_mod_8 == 6: + N_align_elements = 2 + + epilogue_str = "cutlass::epilogue::thread::" + cutlass_epilogue_name+ "<" + C_tp + ", " + str(N_align_elements) + ", " + Acc_tp + ", " + Acc_tp + ">" + + return epilogue_str + + def gen_using(self, volta = True): + code_using = "" + volta_arch = "cutlass::arch::Sm70" + volta_tc = "cutlass::gemm::GemmShape<8, 8, 4>" + + turing_arch = "cutlass::arch::Sm75" + turing_tc = "cutlass::gemm::GemmShape<16, 8, 8>" + + arch = "" + tc = "" + if volta: + arch = volta_arch + tc = volta_tc + else: + arch = turing_arch + tc = turing_tc + + for i in range(self.b2b_num): + + k = self.fuse_gemm_info[i]['mnk'][2] + + k_mod_8 = k % 4 + ab_ldm = 1 + if k_mod_8 == 0: + ab_ldm = 8 + elif k_mod_8 == 4: + ab_ldm = 4 + elif k_mod_8 == 2 or k_mod_8 == 6: + ab_ldm = 2 + + block_tile, warp_tile = self.perf_tiling(self.fuse_gemm_info[i]['mnk']) + + this_gemm_config = helper.var_idx("using Gemm", i) + " = cutlass::gemm::device::GemmBatched<\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_format']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_format']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_format']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + ",\n" + this_gemm_config += " " + "cutlass::arch::OpClassTensorOp,\n" + this_gemm_config += " " + arch + ",\n" + this_gemm_config += " " + "cutlass::gemm::GemmShape<" + str(block_tile[0]) + ", " + str(block_tile[1]) + ", " + str(block_tile[2]) + ">,\n" + this_gemm_config += " " + "cutlass::gemm::GemmShape<" + str(warp_tile[0]) + ", " + str(warp_tile[1]) + ", " + str(warp_tile[2]) + ">,\n" + this_gemm_config += " " + tc + ",\n" + this_gemm_config += " " + self.process_epilogue(helper.get_epilogue_tp(self.fuse_gemm_info[i]), self.fuse_gemm_info[i]['mnk'][1], helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']), helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp'])) + ",\n" + this_gemm_config += " " + "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,\n" + this_gemm_config += " " + "2,\n" + this_gemm_config += " " + str(ab_ldm) + ",\n" + this_gemm_config += " " + str(ab_ldm) + ">;\n" + + code_using += this_gemm_config + "\n" + + return code_using + "\n" + + def gen_initialize(self): + code = "" + for i in range(self.b2b_num): + code_this = "" + + N_str = str(self.fuse_gemm_info[i]['mnk'][1]) + + code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " alpha", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(1);\n" + beta = "(1)" + if helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i]) is False: + beta = "(0)" + code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " beta", i) + " = " + 
helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + beta + ";\n" + + k_str = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + k_str = "K0" + code_this += helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(M, " + str(self.fuse_gemm_info[i]['mnk'][1]) + ", " + k_str + ");\n" + code_this += helper.var_idx("typename Gemm", i) + helper.var_idx("::Arguments arguments_", i) + "{\n" + code_this += " " + helper.var_idx("problem_size_", i) + ",\n" + ldmA = k_str + ldmB = k_str + ldmC = str(self.fuse_gemm_info[i]['mnk'][1]) + + ldmBias = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i])) + + if self.fuse_gemm_info[i]['A_format'] is 'Col': + ldmA = "M" + if self.fuse_gemm_info[i]['B_format'] is 'Row': + ldmB = str(self.fuse_gemm_info[i]['mnk'][1]) + if self.fuse_gemm_info[i]['C_format'] is 'Col': + ldmC = "M" + + if i == 0: + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("A", i) + "), " + ldmA + "}, " + "M * " + ldmA + ",\n" + else: + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("D", i - 1) + "), " + ldmA + "}, " + "M * " + ldmA + ",\n" + + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("B", i) + "), " + ldmB + "}, " + N_str + " * " + ldmB + ",\n" + + M_bias = str(helper.get_epilogue_bias_shape(self.fuse_gemm_info[i])[0]) + + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("C", i) + "), " + ldmBias + "}, " + M_bias + " * " + N_str + ",\n" + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("D", i) + "), " + ldmC + "}, " + "M * " + ldmC + ",\n" + code_this += " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i) + for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]): + arg_name = helper.var_idx("Epilogue", i) + "_" + epilogue_arg[1] + code_this += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_name) + ")" + code_this += " },\n" + code_this += " " + "Batch};\n" + + code_this += " " + helper.var_idx("Gemm", i) + helper.var_idx(" gemm_op_", i) + ";\n" + code_this += " " + helper.var_idx("gemm_op_", i) + helper.var_idx(".initialize(arguments_", i) + ", nullptr);\n" + + code += code_this + "\n" + return code + "\n" + + + def gen_run(self): + code = "" + for i in range(self.b2b_num): + code_this = "" + code_this += " " + helper.var_idx("gemm_op_", i) + "(stream);\n" + + code += code_this + return code + + def gen_wrapper(self): + code_body = "" + + arg_lists = [] + arg_lists.append(["int", "M"]) + arg_lists.append(["int", "K0"]) + arg_lists.append(["int", "Batch"]) + arg_lists.append(["void*", helper.var_idx("A", 0)]) + for i in range(self.b2b_num): + arg_lists.append(["void*", helper.var_idx("B", i)]) + arg_lists.append(["void*", helper.var_idx("C", i)]) + arg_lists.append(["void*", helper.var_idx("D", i)]) + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_tp = arg[0] + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + arg_lists.append([arg_tp, arg_name]) + code_body += self.gen_using() + code_body += self.gen_initialize() + code_body += self.gen_run() + + code = 
ir.gen_func(self.gen_class_name, arg_lists, code_body) + + return code + + def gen_code(self): + code = self.gen_wrapper() + helper.write_2_headfile("volta_impl.h", self.output_dir, self.user_header_file + "\n" + code) + +class gen_one_API: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.gen_class_name = gen_class_name + self.user_header_file = "" + for header in user_header_file: + self.user_header_file += "#include \"" + header + "\"\n" + self.output_dir = output_dir + self.b2b_num = len(fuse_gemm_info) + + self.gen_volta = gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir) + + self.gen_turing = gen_turing_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir) + + def gen_CUTLASS_irrelevant_API(self): + code = "" + code += "#include \n" + code += "#include \n" + + param_name = "Fused" + str(self.b2b_num) + "xGemm_" + for i in range(self.b2b_num): + param_name += str(self.fuse_gemm_info[i]['mnk'][1]) + "_" + param_name += "Params" + params = "" + params += " " + "int M;\n" + params += " " + "int K0;\n" + params += " " + "int Batch;\n" + params += " " + "const void* A0;\n" + for i in range(self.b2b_num): + params += " " + "const void* " + helper.var_idx("B", i) + ";\n" + params += " " + "const void* " + helper.var_idx("C", i) + ";\n" + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_tp = arg[0] + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + params += " " + arg_tp + " " + arg_name + ";\n" + params += " " + "void* " + helper.var_idx("D", i) + ";\n" + code += ir.gen_struct(param_name, params) + code += "using Param = " + param_name + ";\n" + code += "void one_api( const Param & param, int sm, cudaStream_t stream);\n" + + + return code + + def gen_one_api(self): + code = "" + code += "/* Auto Generated code - Do not edit.*/\n" + code += "#include \"cutlass_irrelevant.h\"\n" + code += "#include \"api.h\"\n" + code += "void one_api( const Param & param, int sm, cudaStream_t stream) {\n" + + code += " " + "if (sm == 70) \n" + code += " " + " " + self.gen_class_name + "_volta_impl(param.M, param.K0, param.Batch, const_cast<void*>(param.A0), " + for i in range(self.b2b_num): + code += helper.var_idx("const_cast<void*>(param.B", i) + "), " + code += helper.var_idx("const_cast<void*>(param.C", i) + "), " + code += helper.var_idx("param.D", i) + ", " + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + code += "param." + arg_name + ", " + code += "stream);\n" + code += " " + "else if(sm >= 75) \n" + code += " " + " " + self.gen_class_name + "_turing_impl(param.M, param.K0, param.Batch, const_cast<void*>(param.A0), " + for i in range(self.b2b_num): + code += helper.var_idx("const_cast<void*>(param.B", i) + "), " + code += helper.var_idx("const_cast<void*>(param.C", i) + "), " + code += helper.var_idx("param.D", i) + ", " + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + code += "param." 
+ arg_name + ", " + code += "stream);\n" + code += " " + "else assert(0);\n" + code += "}\n" + return code + + def gen_code(self): + + turing_code = self.gen_turing.gen_wrapper() + volta_code = self.gen_volta.gen_wrapper() + cutlass_irrelevant_code = self.gen_CUTLASS_irrelevant_API() + + one_api_code = self.gen_one_api() + with open(self.output_dir + "one_api.cu", "w+") as f: + f.write(one_api_code) + + helper.write_2_headfile("cutlass_irrelevant.h", self.output_dir, cutlass_irrelevant_code) + + helper.write_2_headfile("api.h", self.output_dir, self.user_header_file + "\n" + turing_code + volta_code) diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py new file mode 100644 index 00000000..645c6615 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py @@ -0,0 +1,92 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +import helper +import gen_ir as ir + +import gen_turing_and_volta as gen_basic + + +class gen_verify: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.name = gen_class_name + "_verify" + self.b2b_num = len(fuse_gemm_info) + self.params = [] + self.user_header_file = "" + for header in user_header_file: + self.user_header_file += "#include \"" + header + "\"\n" + self.seperate_cutlass = gen_basic.gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir) + self.gen_params() + self.output_dir = output_dir + + + def gen_code(self): + code = "" + code += self.user_header_file + code += self.seperate_cutlass.gen_using(False) #False -> Turing, True -> Volta + + code_body = "" + for i in range(self.b2b_num): + code_body += " " + helper.var_idx("Gemm", i) + helper.var_idx(" gemm_op_", i) + ";\n" + code_body += " " + helper.var_idx("gemm_op_", i) + helper.var_idx(".initialize(Arguments_", i) + ", nullptr);\n" + + code_body += self.seperate_cutlass.gen_run() + + code += ir.gen_func(self.name, self.params, code_body) + helper.write_2_headfile("cutlass_verify.h", self.output_dir, code) + + + def gen_params(self): + for i in range(self.b2b_num): + self.params.append( + ( + helper.var_idx("typename Gemm", i)+ "::Arguments", + helper.var_idx("Arguments_", i) + ) + ) + + + def get_params(self, declartion = True): + code = "" + if declartion: + for param in self.params: + code += param[0] + " " + param[1] + ";\n" + + return code + + + def gen_initialize(): + code = "" + initialize_code = self.seperate_cutlass.gen_initialize() + + code = ir.gen_func("initialize", [[]]) diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/generate.sh b/examples/44_multi_gemm_ir_and_codegen/ir_gen/generate.sh new file mode 100755 index 00000000..c7d40dba --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/generate.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +NUM_ARGS=3 +if [ $# -ne $NUM_ARGS ]; then + echo "Usage: $0 <config_file> <output_directory> <cutlass_directory>" + echo " config_file: JSON file containing configuration to run" + echo " output_directory: directory to store results" + echo " cutlass_directory: directory containing cutlass source" + exit 1 +fi + +config_file=$1 +output_dir=$2 +cutlass_dir=$3 + +python3 gen_all_code.py \ + --config-file $config_file \ + --gen-name FusedMultiGemmForward \ + --output-dir $output_dir \ + --cutlass-dir $cutlass_dir diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py new file mode 100644 index 00000000..9c9c3779 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py @@ -0,0 +1,135 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
+# +################################################################################################# + +def type_2_cutlass_type(input_type = "fp16"): + # float point type + if input_type == "fp32": + return "float" + if input_type == "bf16": + return "cutlass::bfloat16_t" + if input_type == "fp16": + return "cutlass::half_t" + + # integer type + if(input_type == "int32"): + return "int32_t" + if(input_type == "int8"): + return "int8_t" + + if input_type == 'Row': + return 'cutlass::layout::RowMajor' + if input_type == 'Col': + return 'cutlass::layout::ColumnMajor' + +def cvt_2_cutlass_shape(gemm_shape): + # gemm shape + if len(gemm_shape) == 3: + val = "cutlass::gemm::GemmShape<" \ + + str(gemm_shape[0]) + ", " \ + + str(gemm_shape[1]) + ", " \ + + str(gemm_shape[2]) + ">" + return val + + +def write_2_headfile(filename, file_dir, string): + with open(file_dir + filename, 'w') as f: + f.write("/* Auto Generated code - Do not edit.*/\n\n\n#pragma once\n" + string) + +def var_idx(varaiable, index): + return varaiable + str(index) + + +def list_2_string(input_list, ): + rtn_string = "" + + cnt = 0 + + for element in input_list: + final = ", \n" + if cnt == len(input_list) - 1: + final = "\n" + cnt += 1 + rtn_string += str(element) + final + + return rtn_string + + +def get_epilouge_info(layer_info): + return layer_info['epilogue'] + +def get_epilogue_tp(layer_info): + epilogue_info = get_epilouge_info(layer_info) + return epilogue_info['tp'] + +def get_epilogue_add_bias_or_not(layer_info): + epilogue_info = get_epilouge_info(layer_info) + return epilogue_info['bias']['addbias'] + +def get_epilogue_add_bias_tp(layer_info): + epilogue_info = get_epilouge_info(layer_info) + return epilogue_info['bias']['bias_tp'] + +def get_epilogue_args(layer_info): + epilogue_info = get_epilouge_info(layer_info) + return epilogue_info['args'] + +def get_epilogue_bias_shape(layer_info): + bias_tp = get_epilogue_add_bias_tp(layer_info).lower() + mn_shape = layer_info['mnk'][:-1] + + if bias_tp == 'mat': + mn_shape[0] = 'M' + return mn_shape + elif bias_tp == 'vec': + mn_shape[0] = 1 + return mn_shape + else: + assert(0) + +def get_epilogue_bias_ldm(layer_info): + bias_tp = get_epilogue_add_bias_tp(layer_info).lower() + mn_shape = layer_info['mnk'][:-1] + + c_layout = layer_info['C_format'].lower() + + if c_layout != 'row': + assert(0) + + if bias_tp == 'mat': + return mn_shape[1] + elif bias_tp == 'vec': + return 0 + else: + assert(0) + +def get_epilogue_compute_tp(layer_info): + return layer_info['Acc_tp'] diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py new file mode 100644 index 00000000..d6d12944 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py @@ -0,0 +1,67 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import os + +class replace_fix_impl: + def __init__(self, src_dir, dst_dir, cutlass_deps_root): + self.src_dir = src_dir + self.dst_dir = dst_dir + self.cutlass_deps_root = cutlass_deps_root + + + + def gen_code(self): + for sub_dir in os.walk(self.src_dir): + files_in_sub_dir = sub_dir[2] + + src_dirs = sub_dir[0] + output_dirs = self.dst_dir + sub_dir[0][len(self.src_dir):] + + if not os.path.exists(output_dirs): + os.mkdir(output_dirs) + + for f in files_in_sub_dir: + with open(src_dirs +"/" + f, 'r') as current_file: + output_lines = [] + lines = current_file.readlines() + + for line in lines: + if(len(line) >= len("#include \"cutlass") and line[:len("#include \"cutlass")] == "#include \"cutlass"): + new_line = "#include \"" + self.cutlass_deps_root + line[len("#include \""):] + # print(new_line) + output_lines.append(new_line) + else: + output_lines.append(line) + + with open(output_dirs + "/" + f, "w+") as dest_file: + dest_file.writelines(output_lines) diff --git a/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h b/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h new file mode 100644 index 00000000..4eb34fef --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h @@ -0,0 +1,292 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
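Editor's note: the accessors in helper.py above imply a particular per-layer schema (keys such as 'mnk', 'Acc_tp', 'C_format', and a nested 'epilogue' record). The sketch below builds one hypothetical layer entry and prints what the type and bias helpers return for it. The dictionary values are illustrative assumptions, not contents of the shipped config.json, and the script assumes it is run from the ir_gen directory so that helper.py is importable.

import helper  # examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py, shown above

# Hypothetical layer description, shaped only to satisfy the accessors in helper.py.
layer = {
    "mnk": [1024, 256, 512],
    "Acc_tp": "fp32",
    "C_format": "Row",
    "epilogue": {
        "tp": "LeakyRelu",
        "bias": {"addbias": True, "bias_tp": "Vec"},
        "args": [["fp32", "leaky_alpha", 1.3]],
    },
}

print(helper.type_2_cutlass_type("fp16"))           # -> cutlass::half_t
print(helper.cvt_2_cutlass_shape([128, 128, 32]))   # -> cutlass::gemm::GemmShape<128, 128, 32>
print(helper.get_epilogue_bias_shape(layer))        # -> [1, 256]  (vector bias broadcast over M)
print(helper.get_epilogue_bias_ldm(layer))          # -> 0         (row-major C, vector bias)
print(helper.get_epilogue_compute_tp(layer))        # -> fp32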
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +#include + +template +__device__ +T add(T const & a, T const &b){ + return (a + b); +} + +template <> +__device__ +half2 add(half2 const & a, half2 const &b){ + return (__hadd2(a,b)); +} + +template +struct RELU{ + __device__ + T operator()(T const & a){ + return a > T(0) ? a : T(0); + } + __device__ + half2 operator()(half2 const & a){ + float2 a_fp32x2 = __half22float2(a); + a_fp32x2.x = a_fp32x2.x > 0.f ? a_fp32x2.x : 0.f; + a_fp32x2.y = a_fp32x2.y > 0.f ? a_fp32x2.y : 0.f; + if(a_fp32x2.x < 0.f || a_fp32x2.y < 0.f) + printf(" %f %f\n", a_fp32x2.x ,a_fp32x2.y); + return __float22half2_rn(a_fp32x2); + } +}; + +template +struct LEAKY_RELU{ + __device__ + T operator()(T const & a, T const & scale = half(1)){ + return a > T(0) ? a : scale * a; + } + __device__ + half2 operator()(half2 const & a, half const & scale = half(1)){ + half2 zero = __half2half2(half(0)); + half2 gt_zero = __hge2(a, zero); + half2 le_zero = __hle2(a, zero); + + + half2 scale_f16x2 = __half2half2(scale); + half2 mask_scale_f16x2 = __hfma2(le_zero, scale_f16x2, gt_zero); + return __hmul2(a, mask_scale_f16x2); + } +}; + +template +__global__ void leaky_and_activation(half* inout, half* bias, half scale, bool mat_bias){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + LEAKY_RELU Act; + Access_tp src_v[iter]; + Access_tp bias_v[iter]; + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + if (mat_bias) + bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); + else + bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i]),scale); + } + + } +} + + + +template +__global__ void leaky_and_activation(half* inout, half scale){ + + constexpr bool N_MOD_2 = N & 1 ? 
false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + LEAKY_RELU Act; + Access_tp src_v[iter]; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i], scale); + } + + } +} + + + +template +void leaky_and_activation(half* inout, half* bias, int m, int b, half scale, bool mat_bias){ + + dim3 grid(m, b); + if (bias == nullptr) + leaky_and_activation<<>>(inout, scale); + else + leaky_and_activation<<>>(inout, bias, scale, mat_bias); +} + +template +__global__ void relu_and_activation(half* inout, half* bias, bool mat_bias){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + RELU Act; + Access_tp src_v[iter]; + Access_tp bias_v[iter]; + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + if (mat_bias) + bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); + else + bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i])); + } + + } +} + + + +template +__global__ void relu_and_activation(half* inout){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + RELU Act; + Access_tp src_v[iter]; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i]); + } + + } +} + + + +template +void relu_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){ + dim3 grid(m, b); + if (bias == nullptr) + relu_and_activation<<>>(inout); + else + relu_and_activation<<>>(inout, bias, mat_bias); +} + + +template +__global__ void identity_and_activation(half* inout, half* bias, bool mat_bias){ + + constexpr bool N_MOD_2 = N & 1 ? 
false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + Access_tp src_v[iter]; + Access_tp bias_v[iter]; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + if (mat_bias) + bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); + else + bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = (add(src_v[i],bias_v[i])); + } + + } +} + +template +__global__ void identity_and_activation(half* inout){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + Access_tp src_v[iter]; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = (src_v[i]); + } + + } +} + +template +void identity_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){ + dim3 grid(m, b); + if (bias == nullptr) + identity_and_activation<<>>(inout); + else + identity_and_activation<<>>(inout, bias, mat_bias); +} diff --git a/examples/44_multi_gemm_ir_and_codegen/utils.h b/examples/44_multi_gemm_ir_and_codegen/utils.h new file mode 100644 index 00000000..36b9bf21 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/utils.h @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
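Editor's note: as a host-side check of what the three epilogue kernels in leaky_bias.h above compute (identity, ReLU, or leaky ReLU applied after an optional vector or matrix bias), the NumPy sketch below reproduces the arithmetic at whole-tensor granularity. It does not model the half2 vectorized access path or the per-thread iteration scheme; names and shapes are illustrative.

import numpy as np

def epilogue_reference(gemm_out, bias=None, mat_bias=False, act="identity", scale=1.0):
    # gemm_out: (batch, M, N); bias: (batch, M, N) if mat_bias else (batch, N);
    # scale is the leaky-ReLU slope passed to leaky_and_activation above.
    x = gemm_out.astype(np.float32)
    if bias is not None:
        b = bias.astype(np.float32)
        if not mat_bias:
            b = b[:, None, :]          # broadcast a per-batch vector bias over the M rows
        x = x + b
    if act == "relu":
        x = np.maximum(x, 0.0)
    elif act == "leaky":
        x = np.where(x > 0.0, x, scale * x)
    return x.astype(gemm_out.dtype)

def iterations_per_thread(n, blockdim=128):
    # Mirrors the compile-time math in the kernels above: two halves per access when N
    # is even (half2), one otherwise, then a ceiling division of N over the elements
    # covered by one block per iteration.
    access_elements = 2 if n % 2 == 0 else 1
    return (n + blockdim * access_elements - 1) // (blockdim * access_elements)

rng = np.random.default_rng(0)
out = rng.standard_normal((2, 4, 8)).astype(np.float16)
vec_bias = rng.standard_normal((2, 8)).astype(np.float16)
ref = epilogue_reference(out, vec_bias, mat_bias=False, act="leaky", scale=0.1)
print(ref.shape, iterations_per_thread(8))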
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +#define TI(tag) \ + cudaEvent_t _event_start_ ##tag; \ + cudaEvent_t _event_end_ ##tag; \ + float _event_time_ ##tag; \ + cudaEventCreate(& _event_start_ ##tag); \ + cudaEventCreate(& _event_end_ ##tag); \ + cudaEventRecord(_event_start_ ##tag); + +#define TO(tag, str, times) \ + cudaEventRecord(_event_end_ ##tag); \ + cudaEventSynchronize(_event_end_ ##tag); \ + cudaEventElapsedTime(&_event_time_ ##tag, _event_start_ ##tag, _event_end_ ##tag); \ + float _event_time_once_ ##tag = _event_time_ ##tag / times; \ + printf("%20s:\t %10.3fus\t", str, _event_time_once_ ##tag * 1000); \ + cudaDeviceSynchronize(); \ + printf("%20s string: %s\n",str, cudaGetErrorString(cudaGetLastError())); + +template +struct memory_unit{ + T* host_ptr; + T* device_ptr; + int size_bytes; + int elements; + void h2d(){ + cudaMemcpy(device_ptr, host_ptr, size_bytes, cudaMemcpyHostToDevice); + } + void d2h(){ + cudaMemcpy(host_ptr, device_ptr, size_bytes, cudaMemcpyDeviceToHost); + } + void free_all(){ + free(host_ptr); + cudaFree(device_ptr); + } + memory_unit(int elements_): size_bytes(elements_ * sizeof(T)), elements(elements_){ + host_ptr = (T*) malloc(elements_ * sizeof(T)); + cudaMalloc((void**)&device_ptr, elements_ * sizeof(T)); + } + void init(int abs_range = 1){ + for(int i = 0; i < elements; i++){ + host_ptr[i] = T(rand() % 100 / float(100) * 2 * abs_range - abs_range); + } + h2d(); + } +}; + +template +int check_result(T * a, T * b, int N){ + int cnt = 0; + for(int i = 0; i < N; i ++){ + float std = float(a[i]); + float my = float(b[i]); + + if(abs(std - my) / abs(std) > 1e-2) + { + // printf("my: %f , std: %f\n", my, std); + cnt++; + } + + } + printf("total err: %d / %d\n", cnt, N); + return cnt; +} diff --git a/examples/43_dual_gemm/CMakeLists.txt b/examples/45_dual_gemm/CMakeLists.txt similarity index 99% rename from examples/43_dual_gemm/CMakeLists.txt rename to examples/45_dual_gemm/CMakeLists.txt index 8433b1af..0d550664 100644 --- a/examples/43_dual_gemm/CMakeLists.txt +++ b/examples/45_dual_gemm/CMakeLists.txt @@ -30,7 +30,7 @@ cutlass_example_add_executable( - 43_dual_gemm + 45_dual_gemm dual_gemm.cu ) diff --git a/examples/43_dual_gemm/device/dual_gemm.h b/examples/45_dual_gemm/device/dual_gemm.h similarity index 100% rename from examples/43_dual_gemm/device/dual_gemm.h rename to examples/45_dual_gemm/device/dual_gemm.h diff --git a/examples/43_dual_gemm/dual_gemm.cu b/examples/45_dual_gemm/dual_gemm.cu similarity index 100% rename from examples/43_dual_gemm/dual_gemm.cu rename to examples/45_dual_gemm/dual_gemm.cu diff --git a/examples/43_dual_gemm/dual_gemm_run.h b/examples/45_dual_gemm/dual_gemm_run.h similarity index 100% rename from examples/43_dual_gemm/dual_gemm_run.h rename to examples/45_dual_gemm/dual_gemm_run.h diff --git a/examples/43_dual_gemm/kernel/dual_gemm.h b/examples/45_dual_gemm/kernel/dual_gemm.h similarity index 100% rename 
from examples/43_dual_gemm/kernel/dual_gemm.h rename to examples/45_dual_gemm/kernel/dual_gemm.h diff --git a/examples/43_dual_gemm/test_run.h b/examples/45_dual_gemm/test_run.h similarity index 100% rename from examples/43_dual_gemm/test_run.h rename to examples/45_dual_gemm/test_run.h diff --git a/examples/43_dual_gemm/thread/left_silu_and_mul.h b/examples/45_dual_gemm/thread/left_silu_and_mul.h similarity index 100% rename from examples/43_dual_gemm/thread/left_silu_and_mul.h rename to examples/45_dual_gemm/thread/left_silu_and_mul.h diff --git a/examples/43_dual_gemm/threadblock/dual_epilogue.h b/examples/45_dual_gemm/threadblock/dual_epilogue.h similarity index 100% rename from examples/43_dual_gemm/threadblock/dual_epilogue.h rename to examples/45_dual_gemm/threadblock/dual_epilogue.h diff --git a/examples/43_dual_gemm/threadblock/dual_mma_base.h b/examples/45_dual_gemm/threadblock/dual_mma_base.h similarity index 100% rename from examples/43_dual_gemm/threadblock/dual_mma_base.h rename to examples/45_dual_gemm/threadblock/dual_mma_base.h diff --git a/examples/43_dual_gemm/threadblock/dual_mma_multistage.h b/examples/45_dual_gemm/threadblock/dual_mma_multistage.h similarity index 100% rename from examples/43_dual_gemm/threadblock/dual_mma_multistage.h rename to examples/45_dual_gemm/threadblock/dual_mma_multistage.h diff --git a/examples/46_depthwise_simt_conv2dfprop/CMakeLists.txt b/examples/46_depthwise_simt_conv2dfprop/CMakeLists.txt new file mode 100644 index 00000000..e43f7e77 --- /dev/null +++ b/examples/46_depthwise_simt_conv2dfprop/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
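Editor's note: check_result() in utils.h above flags any element whose relative error against the reference exceeds 1e-2. A NumPy version of the same test is sketched below; note that the C++ code divides by |reference| without guarding against zero, so a small epsilon is added here as an assumption.

import numpy as np

def check_result(reference, computed, rel_tol=1e-2, eps=1e-12):
    # Count elements whose relative error versus the reference exceeds rel_tol,
    # mirroring check_result() in utils.h above.
    ref = np.asarray(reference, dtype=np.float32).ravel()
    got = np.asarray(computed, dtype=np.float32).ravel()
    mismatches = int(np.count_nonzero(np.abs(ref - got) / (np.abs(ref) + eps) > rel_tol))
    print(f"total err: {mismatches} / {ref.size}")
    return mismatches

check_result([1.0, 2.0, 3.0], [1.0, 2.05, 3.0])   # reports 1 mismatching element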
+ + + +cutlass_example_add_executable( + 46_depthwise_simt_conv2dfprop + depthwise_simt_conv2dfprop.cu + ) + diff --git a/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu b/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu new file mode 100644 index 00000000..5e4164e2 --- /dev/null +++ b/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu @@ -0,0 +1,672 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** +This example shows how to run depthwise 2d convolution kernels using functions and data structures +provided by CUTLASS using SIMT instruction; + +There are 3 types of implementations of depthwise 2d convoltion + 1. kAnalytic + Implicit gemm 2d convoltion algorithm. + 2. kOptimized + An optimized algorithm and supports arbitrary stride and dilation. + 3. kFixedStrideDilation + An optimized algorithm with fixed stride and dilation to reduce the runtime computation and do +more optimizations. + +In general, the perf of kFixedStrideDilation would be better than kOptimized. However, if the filter +size, stride or dilation is large, it would encounter register spilling and may hurt the perf. If +in this case, please use kOptimized. + +For kOptimized and kFixedStrideDilation, in order to fully utilize GPU hardware resources and achieve +better perf, when the output tensor size is large, splitk should be enabled to achieve better perf. + +In this example, it demonstrates how to construct and run a FixedStrideDilation depthwise 2d +convolution kernel. 
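Editor's note: the guidance in the example's header comment below (prefer kFixedStrideDilation unless the filter, stride, or dilation is large enough to cause register spilling, and enable split-K for large outputs) can be expressed as a small selection helper. The thresholds in this sketch are assumptions for illustration only; they are not values defined by CUTLASS.

def pick_depthwise_variant(filter_rs, stride, dilation, output_elements,
                           max_fixed_extent=5, splitk_threshold=1 << 21):
    # filter_rs, stride, dilation: (h, w) pairs; stride/dilation are None when they are
    # not known at compile time. output_elements is N * P * Q * K of the output tensor.
    compile_time_known = stride is not None and dilation is not None
    small_window = max(filter_rs + (stride or (1, 1)) + (dilation or (1, 1))) <= max_fixed_extent
    algo = "kFixedStrideDilation" if (compile_time_known and small_window) else "kOptimized"
    splitk = 1 if output_elements <= splitk_threshold else 8   # assumed cap, tune per GPU
    return algo, splitk

print(pick_depthwise_variant((3, 3), (1, 1), (1, 1), 1 * 112 * 112 * 128))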
+*/ + +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/conv/kernel/default_depthwise_fprop.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/conv/device/direct_convolution.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/tensor_view_io.h" + +#include "helper.h" + +// The code section below describes datatype for input, output tensors and computation between +// elements +using ElementAccumulator = cutlass::half_t; // Data type of accumulator +using ElementComputeEpilogue = cutlass::half_t; // Data type of epilogue computation (alpha, beta) +using ElementInputA = cutlass::half_t; // Data type of elements in input tensor +using ElementInputB = cutlass::half_t; // Data type of elements in input tensor +using ElementOutput = cutlass::half_t; // Data type of elements in output tensor + +using LayoutInputA = cutlass::layout::TensorNHWC; +using LayoutInputB = cutlass::layout::TensorNHWC; +using LayoutOutput = cutlass::layout::TensorNHWC; + +// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM +using MMAOp = cutlass::arch::OpClassSimt; + +// This code section describes CUDA SM architecture number +using SmArch = cutlass::arch::Sm60; + +// This code section describes the groups a thread block will compute +constexpr int groups_per_cta = 64; + +// This code section describes the output tile a thread block will compute +using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; + +// This code section describes the filter shape +using FilterShape = cutlass::MatrixShape<3, 3>; + +// Threadblock tile shape +using ThreadblockShape = + cutlass::gemm::GemmShape; + +// This code section describes tile size a warp will computes +// WarpShape::kM = P * Q the warps would process +// WarpShape::kN = groups_per_cta that the warps would process +// WarpShape::kK = filter_size that the warps would process +using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>; + +// This code section describes the size of MMA op +using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + +// This code section describes how threadblocks are scheduled on GPU +using SwizzleThreadBlock = + cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< + 1, + ThreadBlockOutputShape::kN, + ThreadBlockOutputShape::kH, + ThreadBlockOutputShape::kW>; + +// Number of pipelines you want to use +constexpr int NumStages = 4; + +// This code section describe iterator algorithm selected is kFixedStrideDilation +static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = + cutlass::conv::IteratorAlgorithm::kFixedStrideDilation; +using StrideShape = cutlass::MatrixShape<1, 1>; +using DilationShape = cutlass::MatrixShape<1, 1>; + +constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; + +// This code section describes the epilogue part of the kernel, we use default value +using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // Data type of output matrix. 
+ kEpilogueElementsPerAccess, // The number of elements per vectorized. + // memory access. This becomes the vector width of + // math instructions in the epilogue too. + ElementAccumulator, // Data type of accumulator + ElementComputeEpilogue, // Data type for alpha/beta in linear combination + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>; // Epilogue scaling operation. + +using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + IteratorAlgorithm, + cutlass::conv::StrideSupport::kFixed, + StrideShape, + DilationShape>::Kernel; + +using Direct2dConv = cutlass::conv::device::DirectConvolution; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + bool help; + cutlass::Tensor4DCoord input_size; + cutlass::Tensor4DCoord filter_size; + cutlass::Tensor4DCoord padding; + cutlass::MatrixCoord conv_stride; + cutlass::MatrixCoord dilation; + int groups; + int splitk; + bool reference_check; + bool measure_performance; + int iterations; + bool save_workspace; + ElementComputeEpilogue alpha; + ElementComputeEpilogue beta; + std::string tag; + + Options() + : help(false), + input_size(1, 128, 128, 32), + filter_size(32, 3, 3, 1), + groups(32), + padding(1, 1, 1, 1), + conv_stride(1, 1), + dilation(1, 1), + reference_check(false), + measure_performance(true), + iterations(20), + save_workspace(false), + alpha(1), + beta(0), + splitk(1) {} + + // Verify the problem size is compatible with the CUTLASS Convolution implementation. + bool valid() { + // + // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, + // all pointers, strides, and tensor extents must be divisible by 8 elements. 
+ // + int const kAlignment = 8; + + if ((input_size.c() % kAlignment) || (filter_size.n() % kAlignment)) { + // misaligned tensors + return false; + } + + // depthwise conv + if (groups != input_size.c()) { + return false; + } + + if (filter_size.n() != groups) { + return false; + } + + // Invalid padding + if ((padding.h() != filter_size.h() / 2) || (padding.w() != filter_size.w() / 2)) { + return false; + } + + return true; + } + + /// Updates input and filter sizes + void update(cutlass::Tensor4DCoord input_size, cutlass::Tensor4DCoord filter_size) { + this->input_size = input_size; + this->filter_size = filter_size; + + padding.n() = filter_size.h() / 2; + padding.h() = filter_size.h() / 2; + padding.w() = filter_size.w() / 2; + padding.c() = filter_size.w() / 2; + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + if (cmd.check_cmd_line_flag("ref-check")) { + reference_check = true; + } + + if (cmd.check_cmd_line_flag("perf-check")) { + measure_performance = true; + } + + if (cmd.check_cmd_line_flag("save-workspace")) { + save_workspace = true; + } + + cmd.get_cmd_line_argument("n", input_size.n()); + cmd.get_cmd_line_argument("h", input_size.h()); + cmd.get_cmd_line_argument("w", input_size.w()); + cmd.get_cmd_line_argument("c", input_size.c()); + + cmd.get_cmd_line_argument("k", filter_size.n()); + cmd.get_cmd_line_argument("r", filter_size.h()); + cmd.get_cmd_line_argument("s", filter_size.w()); + + cmd.get_cmd_line_argument("g", groups); + + filter_size.c() = 1; + filter_size.n() = input_size.c(); + + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + cmd.get_cmd_line_argument("splitk", splitk); + + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("tag", tag); + + int32_t padding_h = filter_size.h() / 2; + int32_t padding_w = filter_size.w() / 2; + padding = {padding_h, padding_h, padding_w, padding_w}; + } + + /// Prints the usage statement. 
+ std::ostream &print_usage(std::ostream &out) const { + out << "41_depthwise_gemm_fprop example\n\n" + << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" + << " forward convolution on tensors of layout NHWC.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --n= Input tensor extent N\n" + << " --h= Input tensor extent H\n" + << " --w= Input tensor extent W\n" + << " --c= Input tensor extent C\n" + << " --k= Filter extent K\n" + << " --r= Filter extent R\n" + << " --s= Filter extent S\n\n" + << " --g= Groups\n\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --splitk= Enable splitK\n\n" + << " --ref-check If set (true), reference check on the host is computed\n" + << " --perf-check If set (true), performance is measured.\n" + << " --iterations= Number of profiling iterations to perform.\n" + << " --save-workspace If set, workspace is written to a text file.\n" + << " --tag= String to replicate across the first column in the results " + "table\n"; + + out << "\n\nExamples:\n\n" + << "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=32 " + "--h=224 --w=224 --c=128 --k=128 --g=128 --r=3 --s=3\n\n" + << "$ ./examples/45_depthwise_simt_conv2dfprop/45_depthwise_simt_conv2dfprop --n=1 " + "--h=224 --w=224 --c=32 --k=32 --g=32 --r=3 --s=3 --splitk=10 --ref-check\n\n"; + + return out; + } + + /// Computes the output tensor size (NPQK) + cutlass::Tensor4DCoord output_size() const { + return cutlass::Tensor4DCoord( + input_size.n(), + (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, + (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, + filter_size.n()); + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + // Number of multiply-adds = NPQK * CRS + int64_t fmas = + output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct Result { + double runtime_ms; + double gflops; + cutlass::Status status; + cutlass::Status reference_check; + cudaError_t error; + + Result() + : runtime_ms(0), + gflops(0), + status(cutlass::Status::kSuccess), + reference_check(cutlass::Status::kInvalid), + error(cudaSuccess) {} + + static std::ostream &print_header(std::ostream &out, Options const &options) { + if (!options.tag.empty()) { + out << "Name,"; + } + + out << "Layer,N,H,W,C,K,R,S,G,stride_h,stride_w,dilation_h,dilation_w,splitK,Runtime,GFLOPs"; + + return out; + } + + std::ostream &print(std::ostream &out, int idx, Options const &options) { + if (!options.tag.empty()) { + out << options.tag << ","; + } + + cutlass::Tensor4DCoord output_size = options.output_size(); + out << "conv_" << idx << "," << options.input_size.n() << "," << options.input_size.h() << "," + << options.input_size.w() << "," << options.input_size.c() << "," + + << options.filter_size.n() << "," << options.filter_size.h() << "," + << options.filter_size.w() << "," + + << options.groups << "," << options.conv_stride.row() << "," << options.conv_stride.column() + << "," + + << options.dilation.row() << "," << options.dilation.column() << "," + + << options.splitk << "," + + << runtime_ms << "," << gflops; + + return out; + } +}; + 
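Editor's note: the output-extent and throughput arithmetic used by Options above reduces to the usual NPQK formula plus 2 * NPQK * R * S * C flops. The sketch below mirrors those two member functions for the example's default problem size; note that output_size() above does not account for dilation, which this example fixes to 1. It is a host-side reference only, not part of the example.

def depthwise_output_size(input_nhwc, filter_krsc, padding, stride):
    # Mirrors Options::output_size() above; padding is ordered (n, h, w, c) like the
    # Tensor4DCoord it is stored in, with (n, h) padding H and (w, c) padding W.
    n, h, w, c = input_nhwc
    k, r, s, _ = filter_krsc
    p = (h + padding[0] + padding[1] - r) // stride[0] + 1
    q = (w + padding[2] + padding[3] - s) // stride[1] + 1
    return (n, p, q, k)

def depthwise_gflops(output_npqk, filter_krsc, runtime_s):
    # Mirrors Options::gflops(): two flops per multiply-add, NPQK * (R * S * C) MMAs.
    n, p, q, k = output_npqk
    _, r, s, c = filter_krsc
    return 2.0 * (n * p * q * k) * (r * s * c) / 1.0e9 / runtime_s

npqk = depthwise_output_size((1, 128, 128, 32), (32, 3, 3, 1), (1, 1, 1, 1), (1, 1))
print(npqk, depthwise_gflops(npqk, (32, 3, 3, 1), runtime_s=1.0e-3))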
+///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Runs one testcase +Result profile_convolution(Options const &options) { + Result result; + + // + // Allocate host-device tensors using the CUTLASS Utilities. + // + + cutlass::HostTensor tensor_a(options.input_size); + cutlass::HostTensor tensor_b(options.filter_size); + cutlass::HostTensor tensor_b_transpose(options.filter_size); + cutlass::HostTensor tensor_c(options.output_size()); + cutlass::HostTensor tensor_d(options.output_size()); + cutlass::HostTensor tensor_ref_d(options.output_size()); + + // + // Initialize tensors + // + + // Fill tensor A on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), 1, ElementInputA(5), ElementInputA(-6), 0); + + // Fill tensor B on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_b.host_view(), 1, ElementInputB(3), ElementInputB(-6), 0); + + // Fill tensor C on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), 1, ElementOutput(5), ElementOutput(-6), 0); + + // Fill tensor D on host with zeros + cutlass::reference::host::TensorFill(tensor_d.host_view()); + + // Fill tensor D for reference on host with zeros + cutlass::reference::host::TensorFill(tensor_ref_d.host_view()); + + // Copy data from host to GPU + tensor_a.sync_device(); + tensor_b.sync_device(); + tensor_b_transpose.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + tensor_ref_d.sync_device(); + + // + // Define arguments for CUTLASS Convolution + // + + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; + + // Split P*Q into multiple CTA + int split_k_slices = options.splitk; + + // Construct Conv2dProblemSize with user defined output size + cutlass::conv::Conv2dProblemSize problem_size(options.input_size, + options.filter_size, + options.padding, + options.conv_stride, + options.dilation, + options.output_size(), + mode, + split_k_slices, + options.groups); + + // Construct Direc2dConv::Argument structure with conv2d + // problem size, data pointers, and epilogue values + typename Direct2dConv::Arguments arguments{problem_size, + tensor_a.device_ref(), + tensor_b.device_ref(), + tensor_c.device_ref(), + tensor_d.device_ref(), + {options.alpha, options.beta}, + tensor_b_transpose.device_ref()}; + + // + // Initialize CUTLASS Convolution + // + + Direct2dConv implicit_gemm_op; + + size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + result.status = implicit_gemm_op.can_implement(arguments); + CUTLASS_CHECK(result.status); + + result.status = implicit_gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(result.status); + + // + // Launch initialized CUTLASS kernel + // + result.status = implicit_gemm_op(); + + CUTLASS_CHECK(result.status); + + // + // Optional reference check + // + + if (options.reference_check) { + std::cout << "Verification on host...\n"; + + // Compute with reference implementation + cutlass::reference::host::Conv2dFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementComputeEpilogue, + ElementAccumulator, + cutlass::NumericConverter >(problem_size, + tensor_a.host_ref(), + tensor_b.host_ref(), + tensor_c.host_ref(), + tensor_ref_d.host_ref(), + 
options.alpha, + options.beta); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + tensor_d.sync_host(); + + bool passed = + cutlass::reference::host::TensorEquals(tensor_d.host_view(), tensor_ref_d.host_view()); + + if (!passed) { + result.reference_check = cutlass::Status::kErrorInternal; + std::cout << "ERROR - results miscompared.\n"; + } else { + result.reference_check = cutlass::Status::kSuccess; + std::cout << "Passed.\n"; + } + } else { + result.reference_check = cutlass::Status::kInvalid; + } + + if (options.save_workspace) { + std::stringstream ss; + + ss << "45_depthwise_simt_conv2dfprop" << options.input_size.n() << "x" << options.input_size.h() + << "x" << options.input_size.w() << "x" << options.input_size.c() << "_" + << options.filter_size.n() << "x" << options.filter_size.h() << "x" + << options.filter_size.w() << "x" << options.filter_size.c() << ".dat"; + + std::ofstream output_workspace(ss.str()); + + output_workspace << "Input = \n" + << tensor_a.host_view() << "\n\n" + << "Filters = \n" + << tensor_b.host_view() << "\n\n"; + + if (options.reference_check) { + output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n"; + } + + output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl; + + std::cout << "Results written to '" << ss.str() << "'." << std::endl; + } + + // + // Performance measurement + // + + if (options.measure_performance) { + cudaEvent_t events[2]; + + for (auto &event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + } + + // Record an event at the start of a series of convolution operations. + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Launch a sequence of implicit GEMM operations on the device + for (int iteration = 0; iteration < options.iterations; ++iteration) { + result.status = implicit_gemm_op(); + CUTLASS_CHECK(result.status); + } + + // Record an event when the convolutions have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) + << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Print average runtime and GFLOPs. 
+ result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // Cleanup + for (auto event : events) { + (void)cudaEventDestroy(event); + } + } + + return result; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + bool notSupported = false; + + cudaDeviceProp props; + CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); + + if (!(props.major >= 6)) { + std::cerr << "Run on a machine with compute capability at least 60." << std::endl; + notSupported = true; + } + + if (notSupported) { + return 0; + } + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // Execute one problem size + if (!options.valid()) { + std::cerr << "Invalid problem." << std::endl; + return -1; + } + + Result result = profile_convolution(options); + + Result::print_header(std::cout, options) << std::endl; + result.print(std::cout, 1, options) << std::endl; + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index bd75a742..a4c132c1 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -119,9 +119,11 @@ foreach(EXAMPLE 37_gemm_layernorm_gemm_fusion 38_syr2k_grouped 39_gemm_permute - 41_multi_head_attention - 42_fused_multi_head_attention - 43_dual_gemm + 41_fused_multi_head_attention + 42_ampere_tensorop_group_conv + 43_ell_block_sparse_gemm + 45_dual_gemm + 46_depthwise_simt_conv2dfprop ) add_subdirectory(${EXAMPLE}) diff --git a/include/cutlass/aligned_buffer.h b/include/cutlass/aligned_buffer.h index 751e72a4..f869d388 100644 --- a/include/cutlass/aligned_buffer.h +++ b/include/cutlass/aligned_buffer.h @@ -80,9 +80,9 @@ public: typedef value_type *pointer; typedef value_type const * const_pointer; - using ArrayType = Array; - using reference = typename ArrayType::reference; - using const_reference = typename ArrayType::const_reference; + using Array = Array; + using reference = typename Array::reference; + using const_reference = typename Array::const_reference; public: diff --git a/include/cutlass/arch/arch.h b/include/cutlass/arch/arch.h index 578e6c14..c48ebcc6 100644 --- a/include/cutlass/arch/arch.h +++ b/include/cutlass/arch/arch.h @@ -85,6 +85,10 @@ struct Sm86 { static int const kMinComputeCapability = 86; }; +struct Sm90 { + static int const kMinComputeCapability = 90; +}; + /// Triggers a breakpoint on the device CUTLASS_DEVICE void device_breakpoint() { diff --git a/include/cutlass/arch/memory.h b/include/cutlass/arch/memory.h index a98114a1..b5e5c9b0 100644 --- a/include/cutlass/arch/memory.h +++ b/include/cutlass/arch/memory.h @@ -451,7 +451,7 @@ template <> CUTLASS_DEVICE void shared_store<16>(uint32_t ptr, void const *src) { uint4 const *dst_u128 = reinterpret_cast(src); - asm volatile("ld.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" : : "r"(ptr), "r"(dst_u128->x), diff --git a/include/cutlass/arch/mma.h b/include/cutlass/arch/mma.h index ce3e02f3..a050c7a3 100644 --- a/include/cutlass/arch/mma.h +++ b/include/cutlass/arch/mma.h @@ -223,4 +223,6 @@ struct SparseMma; #include "cutlass/arch/mma_sm75.h" #include "cutlass/arch/mma_sm80.h" #include "cutlass/arch/mma_sparse_sm80.h" +#include "cutlass/arch/mma_sm90.h" + 
///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/mma_sm75.h b/include/cutlass/arch/mma_sm75.h index 22b633fb..243ec7b1 100644 --- a/include/cutlass/arch/mma_sm75.h +++ b/include/cutlass/arch/mma_sm75.h @@ -1065,7 +1065,7 @@ struct Mma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); - asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + asm volatile("_mma.m8n8k32.row.col.u4.s4.sat {%0,%1}, %2, %3, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); @@ -1247,7 +1247,8 @@ struct Mma< ) const { #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -#if defined(CUTLASS_ARCH_WMMA_ENABLED) + +#if (__CUDA_ARCH__ >= 900) || (defined(CUTLASS_ARCH_WMMA_ENABLED)) using WmmaFragmentA = nvcuda::wmma::fragment< nvcuda::wmma::matrix_a, Shape::kM, @@ -1279,6 +1280,7 @@ struct Mma< nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR, nvcuda::wmma::experimental::bmmaAccumulateOpPOPC); + #else CUTLASS_UNUSED(a); @@ -1289,14 +1291,7 @@ struct Mma< #endif // defined(CUTLASS_ARCH_WMMA_ENABLED) -#else - CUTLASS_UNUSED(a); - CUTLASS_UNUSED(b); - CUTLASS_UNUSED(c); - CUTLASS_UNUSED(d); - assert(0); #endif - } }; diff --git a/include/cutlass/arch/mma_sm80.h b/include/cutlass/arch/mma_sm80.h index 36005174..d2ad93bd 100644 --- a/include/cutlass/arch/mma_sm80.h +++ b/include/cutlass/arch/mma_sm80.h @@ -2156,6 +2156,7 @@ struct Mma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); + asm volatile( "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc {%0,%1,%2,%3}, " "{%4,%5,%6,%7}, " diff --git a/include/cutlass/arch/mma_sm90.h b/include/cutlass/arch/mma_sm90.h new file mode 100644 index 00000000..45f334b7 --- /dev/null +++ b/include/cutlass/arch/mma_sm90.h @@ -0,0 +1,131 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Matrix multiply +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8)) +#define CUTLASS_ARCH_MMA_SM90_SUPPORTED 1 +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +#define CUTLASS_ARCH_MMA_SM90_ENABLED +#endif +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////// +/// Matrix Multiply-Add 16x8x4 fp64 +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F64 = F64 * F64 + F64 +template <> +struct Mma< + gemm::GemmShape<16,8,4>, + 32, + double, + layout::RowMajor, + double, + layout::ColumnMajor, + double, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,4>; + + using ElementA = double; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = double; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = double; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + + using ArchTag = arch::Sm90; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED) + + double const *A = reinterpret_cast(&a); + double const *B = reinterpret_cast(&b); + + double const *C = reinterpret_cast(&c); + double *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3]) + : "d"(A[0]), "d"(A[1]), + "d"(B[0]), + "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/array.h b/include/cutlass/array.h index 347002f0..4f9c9185 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -35,7 +35,9 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/functional.h" #include "cutlass/numeric_types.h" +#include "cutlass/half.h" namespace cutlass { @@ -493,6 +495,9 @@ public: }; + 
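Editor's note: for a quick numerical sanity check of the new SM90 16x8x4 FP64 MMA above, the tile-level result is an ordinary matrix multiply-accumulate with A row-major, B column-major, and C/D row-major. The NumPy sketch below checks only that tile-level arithmetic; it does not model the per-thread fragment layout the PTX instruction actually uses.

import numpy as np

# Tile-level reference for mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64: D = A @ B + C.
M, N, K = 16, 8, 4
rng = np.random.default_rng(0)
A = rng.standard_normal((M, K))   # A operand, 16x4
B = rng.standard_normal((K, N))   # B operand, 4x8 (column-major in the instruction;
                                  # the layout does not change this whole-tile check)
C = rng.standard_normal((M, N))
D = A @ B + C
assert D.shape == (M, N)
print(D[0, :3])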
+//////////////////////////////////////////////////////////////////////////////////////////////////// +// Factories //////////////////////////////////////////////////////////////////////////////////////////////////// template @@ -533,8 +538,1869 @@ Array make_Array(Element x, Element y, Element z, Element w) { return m; } + +///////////////////////////////////////////////////////////////////////////////////////////////// +// functional.h numeric specializations +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct absolute_value_op< Array > { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs) const { + + Array result; + absolute_value_op scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i]); + } + + return result; + } +}; + +template +struct plus> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + plus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + plus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + plus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; +template +struct minus> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + minus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + minus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + minus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct multiplies> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + multiplies scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + multiplies scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + multiplies scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct divides> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + divides scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const 
&lhs, T const &scalar) const { + + Array result; + divides scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + divides scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct maximum> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct minimum> { + + CUTLASS_HOST_DEVICE + static T scalar_op(T const &lhs, T const &rhs) { + return (rhs < lhs ? rhs : lhs); + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct negate> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs) const { + + Array result; + negate scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i]); + } + + return result; + } +}; + +/// Fused multiply-add +template +struct multiply_add, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], b[i], c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, T const &scalar, Array const &c) const { + + Array result; + multiply_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], scalar, c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, b[i], c[i]); + } + + return result; + } +}; + +/// Fused multiply-add-relu0 +template +struct multiply_add_relu0, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + maximum mx; + + 
CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(scalar_op(a[i], b[i], c[i]), T(0)); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, T const &scalar, Array const &c) const { + + Array result; + multiply_add scalar_op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(scalar_op(a[i], scalar, c[i]), T(0)); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(scalar_op(scalar, b[i], c[i]), T(0)); + } + + return result; + } +}; + + +template +struct conjugate > { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a) const { + + conjugate conj_op; + + Array ca; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + ca[i] = conj_op(a[i]); + } + return ca; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// functional.h numeric specializations targeting SIMD instructions in device code. +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct plus> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hadd2(lhs_ptr[i], rhs_ptr[i]); + } + + if (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hadd(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] + rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hadd2(lhs_pair, rhs_ptr[i]); + } + + if (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hadd(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs + rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hadd2(lhs_ptr[i], rhs_pair); + } + + if (N % 2) { + __half 
const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half d_residual = __hadd(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] + rhs; + } + #endif + + return result; + } +}; + +template +struct minus> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hsub2(lhs_ptr[i], rhs_ptr[i]); + } + + if (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hsub(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] - rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hsub2(lhs_pair, rhs_ptr[i]); + } + + if (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hsub(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs - rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hsub2(lhs_ptr[i], rhs_pair); + } + + if (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half d_residual = __hsub(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] - rhs; + } + #endif + + return result; + } +}; + +template +struct multiplies> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmul2(lhs_ptr[i], rhs_ptr[i]); + } + + if (N % 2) { + __half const *a_residual_ptr = 
reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hmul(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] * rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmul2(lhs_pair, rhs_ptr[i]); + } + + if (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hmul( + reinterpret_cast<__half const &>(lhs), + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs * rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmul2(lhs_ptr[i], rhs_pair); + } + + if (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + + __half d_residual = __hmul( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] * rhs; + } + #endif + + return result; + } +}; + +template +struct divides> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __h2div(lhs_ptr[i], rhs_ptr[i]); + } + + if (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hdiv( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] / rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __h2div(lhs_pair, rhs_ptr[i]); + } + + if (N % 2) { + __half const 
*b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hdiv( + reinterpret_cast<__half const &>(lhs), + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs / rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __h2div(lhs_ptr[i], rhs_pair); + } + + if (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + + __half d_residual = __hdiv( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] / rhs; + } + #endif + + return result; + } +}; + +template +struct negate> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *source_ptr = reinterpret_cast<__half2 const *>(&lhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hneg2(source_ptr[i]); + } + + if (N % 2) { + half_t x = lhs[N - 1]; + __half lhs_val = -reinterpret_cast<__half const &>(x); + result[N - 1] = reinterpret_cast(lhs_val); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = -lhs[i]; + } + #endif + + return result; + } +}; + +/// Fused multiply-add +template +struct multiply_add, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_ptr[i]); + } + + if (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + half_t const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a)); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for 
(int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2(a_pair, b_ptr[i], c_ptr[i]); + } + + if (N % 2) { + + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + __half d_residual = __hfma( + reinterpret_cast<__half const &>(a), + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a, b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + half_t const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_ptr[i]); + } + + if (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(b), + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b, c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + half_t const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_pair); + } + + if (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + + __half d_residual = __hfma( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + reinterpret_cast<__half const &>(c)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c); + } + #endif + + return result; + } +}; + +/// Fused multiply-add-relu0 +template +struct multiply_add_relu0, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_ptr[i]); + } + + if (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = 
reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma_relu( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a[i], b[i], c[i]), (half_t)0); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + half_t const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a)); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_pair, b_ptr[i], c_ptr[i]); + } + + if (N % 2) { + + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + __half d_residual = __hfma_relu( + reinterpret_cast<__half const &>(a), + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a, b[i], c[i]), half_t(0)); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + half_t const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_ptr[i], b_pair, c_ptr[i]); + } + + if (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma_relu( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(b), + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a[i], b, c[i]), half_t(0)); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + half_t const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_pair); + } + + if (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + + __half d_residual = __hfma_relu( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + reinterpret_cast<__half const &>(c)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + 
CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a[i], b[i], c), half_t(0)); + } + #endif + + return result; + } +}; + +template +struct minimum> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmin2(lhs_ptr[i], rhs_ptr[i]); + } + + if (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hmin( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (rhs[i] < lhs[i] ? rhs[i] : lhs[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmin2(lhs_pair, rhs_ptr[i]); + } + + if (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hmin( + reinterpret_cast<__half const &>(lhs), + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (rhs[i] < lhs ? rhs[i] : lhs); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmin2(lhs_ptr[i], rhs_pair); + } + + if (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + + __half d_residual = __hmin( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (rhs < lhs[i] ? 
rhs : lhs[i]); + } + #endif + + return result; + } +}; + +template +struct maximum> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]); + } + + if (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hmax( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (lhs[i] < rhs[i] ? rhs[i] : lhs[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmax2(lhs_pair, rhs_ptr[i]); + } + + if (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hmax( + reinterpret_cast<__half const &>(lhs), + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (lhs < rhs[i] ? rhs[i] : lhs); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair); + } + + if (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + + __half d_residual = __hmax( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = (lhs[i] < rhs ? 
rhs : lhs[i]); + } + #endif + + return result; + } +}; + +/// Fused multiply-add +template +struct multiply_add, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + unsigned const *a_ptr = reinterpret_cast(&a); + unsigned const *b_ptr = reinterpret_cast(&b); + unsigned const *c_ptr = reinterpret_cast(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_ptr[i]) + ); + } + + if (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + bfloat16_t const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + + unsigned const *b_ptr = reinterpret_cast(&b); + unsigned const *c_ptr = reinterpret_cast(&c); + + unsigned a_packed = static_cast(a.raw()); + a_packed = (a_packed | (a_packed << 16)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_packed), "r"(b_ptr[i]), "r"(c_ptr[i]) + ); + } + + if (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[0]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a, b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + bfloat16_t const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + + unsigned const *a_ptr = reinterpret_cast(&a); + unsigned const *c_ptr = reinterpret_cast(&c); + + unsigned b_packed = static_cast(b.raw()); + b_packed = (b_packed | (b_packed << 16)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_ptr[i]), "r"(b_packed), "r"(c_ptr[i]) + ); + } + + if (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[0]), "h"(c_residual_ptr[N - 1]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], 
b, c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + bfloat16_t const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + + unsigned const *a_ptr = reinterpret_cast(&a); + unsigned const *b_ptr = reinterpret_cast(&b); + + unsigned c_packed = static_cast(c.raw()); + c_packed = (c_packed | (c_packed << 16)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_packed) + ); + } + + if (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[0]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c); + } + #endif + + return result; + } +}; + + +/// bit_and +template +struct bit_and> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b) const { + using ArrayType = Array; + using Storage = typename ArrayType::Storage; + ArrayType result; + + Storage *result_data = result.raw_data(); + Storage const *a_data = a.raw_data(); + Storage const *b_data = b.raw_data(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ArrayType::kStorageElements; ++i) { + result_data[i] = (a_data[i] & b_data[i]); + } + + return result; + } +}; + + +/// bit_or +template +struct bit_or> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b) const { + using ArrayType = Array; + using Storage = typename ArrayType::Storage; + ArrayType result; + + Storage *result_data = result.raw_data(); + Storage const *a_data = a.raw_data(); + Storage const *b_data = b.raw_data(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ArrayType::kStorageElements; ++i) { + result_data[i] = (a_data[i] | b_data[i]); + } + + return result; + } +}; + + +/// bit_not +template +struct bit_not> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a) const { + using ArrayType = Array; + using Storage = typename ArrayType::Storage; + ArrayType result; + + Storage *result_data = result.raw_data(); + Storage const *a_data = a.raw_data(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ArrayType::kStorageElements; ++i) { + result_data[i] = (~a_data[i]); + } + + return result; + } +}; + + +/// bit_xor +template +struct bit_xor> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b) const { + using ArrayType = Array; + using Storage = typename ArrayType::Storage; + ArrayType result; + + Storage *result_data = result.raw_data(); + Storage const *a_data = a.raw_data(); + Storage const *b_data = b.raw_data(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ArrayType::kStorageElements; ++i) { + result_data[i] = (a_data[i] ^ b_data[i]); + } + + return result; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Operator overloads +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_HOST_DEVICE +Array operator+(Array const &lhs, Array const &rhs) { + plus> op; + return op(lhs, rhs); +} + +template 
+CUTLASS_HOST_DEVICE +Array operator-(Array const &lhs, Array const &rhs) { + minus> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator-(Array const &lhs) { + negate> op; + return op(lhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator*(Array const &lhs, Array const &rhs) { + multiplies> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator*(T lhs, Array const &rhs) { + multiplies> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator*(Array const &lhs, T rhs) { + multiplies> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator/(Array const &lhs, Array const &rhs) { + divides> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array fma(Array const &a, Array const &b, Array const &c) { + multiply_add> op; + return op(a, b, c); +} + +template +CUTLASS_HOST_DEVICE +Array fma(T a, Array const &b, Array const &c) { + multiply_add> op; + return op(a, b, c); +} + +template +CUTLASS_HOST_DEVICE +Array fma(Array const &a, T b, Array const &c) { + multiply_add> op; + return op(a, b, c); +} + +template +CUTLASS_HOST_DEVICE +Array fma(Array const &a, Array const &b, T c) { + multiply_add> op; + return op(a, b, c); +} + + //////////////////////////////////////////////////////////////////////////////////////////////////// + + } // namespace cutlass //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -545,6 +2411,8 @@ Array make_Array(Element x, Element y, Element z, Element w) { namespace cutlass { +//////////////////////////////////////////////////////////////////////////////////////////////////// +// AlignedArray //////////////////////////////////////////////////////////////////////////////////////////////////// /// Aligned array type diff --git a/include/cutlass/barrier.h b/include/cutlass/barrier.h new file mode 100644 index 00000000..dbdb9cbc --- /dev/null +++ b/include/cutlass/barrier.h @@ -0,0 +1,201 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implementation of a CTA-wide barrier for inter-CTA synchronization. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// CTA-wide semaphore for inter-CTA synchronization. +struct Barrier +{ + +public: + + /// Flag type + using T = int; + + /// Initial flag value + static const T INIT = 0; + + +protected: + + /// Load flag, as a strong operation (int specialization) + CUTLASS_DEVICE + static int ld_strong(int *ptr) + { + int state = 0; + +#if (__CUDA_ARCH__ >= 700) + /// SM70 and newer use memory consistency qualifiers + asm volatile ("ld.global.relaxed.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); +#else + asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); +#endif // (__CUDA_ARCH__ >= 700) + + return state; + } + + /// Store flag, as a strong operation (int specialization) + CUTLASS_DEVICE + static void st_strong(int *ptr, int val) + { +#if (__CUDA_ARCH__ >= 700) + /// SM70 and newer use memory consistency qualifiers + asm volatile ("st.global.relaxed.gpu.b32 [%0], %1;\n" : : "l"(ptr), "r"(val)); +#else + asm volatile ("st.cg.global.b32 [%0], %1;\n" : : "l"(ptr), "r"(val)); +#endif // (__CUDA_ARCH__ >= 700) + } + + + /// Reduce into flag, with release pattern (int specialization) + CUTLASS_DEVICE + static void red_release(int *ptr, int val) + { +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) +#if (__CUDA_ARCH__ >= 700) + /// SM70 and newer use memory consistency qualifiers + asm volatile ("red.release.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val)); +#else + __threadfence(); + atomicAdd(ptr, val); +#endif // (__CUDA_ARCH__ >= 700) +#endif + } + + +public: + + /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter + CUTLASS_DEVICE + static void wait_lt(void *lock_ptr, int thread_idx, int flag_idx, int count) + { +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) + T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + if (thread_idx == 0) + { + // Spin-loop + #pragma unroll 1 + while(ld_strong(flag_ptr) < count) {} + } + + __syncthreads(); +#endif + } + + /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter + CUTLASS_DEVICE + static void wait_eq(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) + { +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) + T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + if (thread_idx == 0) + { + // Spin-loop + #pragma unroll 1 + while(ld_strong(flag_ptr) != val) {} + } + + __syncthreads(); +#endif + } + + /// Uses 
thread[0] to wait for the specified count of signals on the given flag counter + CUTLASS_DEVICE + static void wait_eq_reset(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) + T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + if (thread_idx == 0) + { + // Spin-loop + #pragma unroll 1 + while(atomicCAS(flag_ptr, val, 0) != val) {} + } + + __syncthreads(); +#endif + } + + /// Increment the arrival count for a flag + CUTLASS_DEVICE + static void arrive_inc(void *lock_ptr, int thread_idx, int flag_idx) + { +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) + T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + __syncthreads(); + + if (thread_idx == 0) { + red_release(flag_ptr, 1); + } +#endif + } + + + /// Increment the arrival counts for a range of flags + CUTLASS_DEVICE + static void arrive_range_inc(void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1) + { +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) + int flag_idx = first_flag_idx + thread_idx; + T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + // Barrier to make sure all other threads in block have written their data + __syncthreads(); + + // Select threads increment their flags + if (thread_idx < count) { + red_release(flag_ptr, 1); + } +#endif + } +}; + + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/bfloat16.h b/include/cutlass/bfloat16.h index e4d20efc..50f1c236 100644 --- a/include/cutlass/bfloat16.h +++ b/include/cutlass/bfloat16.h @@ -35,7 +35,9 @@ */ #pragma once -#if !defined(__CUDACC_RTC__) +#if defined(__CUDACC_RTC__) +#include "cutlass/floating_point_nvrtc.h" +#else #include #include #include @@ -71,8 +73,7 @@ struct alignas(2) bfloat16_t { } /// Default constructor - CUTLASS_HOST_DEVICE - bfloat16_t() : storage(0) { } + bfloat16_t() = default; /// Floating-point conversion - round toward nearest CUTLASS_HOST_DEVICE diff --git a/include/cutlass/blas3.h b/include/cutlass/blas3.h index 3c2df6dd..7736cce8 100644 --- a/include/cutlass/blas3.h +++ b/include/cutlass/blas3.h @@ -40,6 +40,7 @@ #include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/coord.h" +#include "cutlass/complex.h" #include "cutlass/functional.h" #include "cutlass/numeric_types.h" diff --git a/include/cutlass/block_striped.h b/include/cutlass/block_striped.h new file mode 100644 index 00000000..2ffd59b1 --- /dev/null +++ b/include/cutlass/block_striped.h @@ -0,0 +1,259 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for performing block-striped access (load, store, reduce) of trivially-copyable, + statically-sized array types to global memory. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/wmma_array.h" +#include "cutlass/functional.h" +#include "cutlass/complex.h" + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// AccessWidth +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes the maximal power-of-two that evenly divides the size of T, capped at Limit +template < + typename T, + int Limit> +struct AccessWidth +{ + // Inductive case + template < + int ObjectBytes, /// Size of T in bytes + int AlignBytes, /// Template induction variable + bool IsAligned = /// Whether ObjectBytes is an even multiple of AlignBytes + ((AlignBytes <= Limit) && (ObjectBytes % AlignBytes == 0))> + struct Detail + { + static const int value = Detail::value; + }; + + // Base case (ObjectBytes is not an even multiple of AlignBytes) + template < + int ObjectBytes, /// Size of T in bytes + int AlignBytes> /// Template induction variable + struct Detail + { + static const int value = AlignBytes / 2; + }; + + /// The maximal power-of-two that evenly divides the size of T + static const int value = Detail< + (int) sizeof(T), + 1>::value; +}; + + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// StripedAccessType +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// ReinterpretCast type for striping a trivially-copyable type in global memory +/// (Default specialization. Striping granularity is type T.) +template < + typename T, /// Data type + int TransferBytes = /// Data access width (16 byte max for global memory access on current architectures) + AccessWidth::value> +struct alignas(TransferBytes) StripedAccessType : public T +{}; + + +/// ReinterpretCast type for striping a trivially-copyable type in global memory +/// (Specialization for cutlass::Array. Striping granularity is a multiple of T.) 
+template < + typename T, /// Array element type + int N, /// Number of elements in array + bool RegisterSized, /// T is register-sized + int TransferBytes> /// Data access width +struct StripedAccessType< + Array, + TransferBytes> +: public AlignedArray< + T, // Element type of StripedAccessType + __NV_STD_MAX(1, TransferBytes / (int) sizeof(T)), // Number of elements T in StripedAccessType + TransferBytes> // Alignment of StripedAccessType +{}; + + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) + +/// ReinterpretCast type for striping a trivially-copyable type in global memory +/// (Specialization for cutlass::WmmaFragmentArray. Striping granularity is a multiple of T.) +template< + typename Use, + int m, + int n, + int k, + typename ElementT, + typename Layout, + int kFragments, + int TransferBytes> +struct StripedAccessType< + WmmaFragmentArray, kFragments>, + TransferBytes> +: public AlignedArray< + ElementT, + __NV_STD_MAX(1, TransferBytes / (int) sizeof(ElementT)), + TransferBytes> +{}; + +#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// BlockStriped +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Utility for performing block-striped access (load, store) of trivially-copyable, +/// statically-sized array types to global memory +template < + int BlockThreads, + typename ArrayT, + typename T, + typename AccessT = StripedAccessType > +struct BlockStriped +{ + /// Number of striped accesses + static const int kStripes = int(sizeof(ArrayT) / sizeof(AccessT)); + static_assert(kStripes > 0, "AccessT type must be smaller than or equal to ArrayT type"); + + /// Load + CUTLASS_DEVICE + static void load(ArrayT &data, T *ptr, int thread_idx) + { + AccessT *access_input = reinterpret_cast(ptr); + AccessT *access_data = reinterpret_cast(&data); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStripes; ++i) { + access_data[i] = access_input[(BlockThreads * i) + thread_idx]; + } + } + + /// Load & Add + CUTLASS_DEVICE + static void load_add(ArrayT &data, T *ptr, int thread_idx) + { + AccessT *access_input = reinterpret_cast(ptr); + AccessT *access_data = reinterpret_cast(&data); + + plus add; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStripes; ++i) + { + access_data[i] = add(access_data[i], access_input[(BlockThreads * i) + thread_idx]); + } + } + + /// Store + CUTLASS_DEVICE + static void store(T *ptr, const ArrayT &data, int thread_idx) + { + AccessT *access_output = reinterpret_cast(ptr); + const AccessT *access_data = reinterpret_cast(&data); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStripes; ++i) { + access_output[(BlockThreads * i) + thread_idx] = access_data[i]; + } + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// BlockStripedReduce +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable, +/// statically-sized array types to global memory. 
+/// (Default specialization) +template < + int BlockThreads, + typename ArrayT, + typename T> +struct BlockStripedReduce : BlockStriped +{ + /// Reduce + CUTLASS_DEVICE + static void reduce(T *ptr, const ArrayT &data, int thread_idx) + { + cutlass::red reduce; + const T *access_data = reinterpret_cast(&data); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < BlockStripedReduce::kStripes; ++i) { + reduce(ptr + (BlockThreads * i) + thread_idx, access_data[i]); + } + } +}; + + +/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable, +/// statically-sized array types to global memory. +/// (Specialization for half_t. Uses half2 vectorized-reduction.) +template < + int BlockThreads, + typename ArrayT> +struct BlockStripedReduce : BlockStriped +{ + static_assert(BlockStripedReduce::kStripes % 2 == 0, "Array of half must be even number in length"); + + /// Reduce + CUTLASS_DEVICE + static void reduce(half_t *ptr, const ArrayT &data, int thread_idx) + { + cutlass::red reduce; + half2 *access_output = reinterpret_cast(ptr); + const half2 *access_data = reinterpret_cast(&data); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < BlockStripedReduce::kStripes; ++i) + { + reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]); + } + } +}; + + +} // namespace cutlass + diff --git a/include/cutlass/complex.h b/include/cutlass/complex.h index 10fe46a4..c101908d 100644 --- a/include/cutlass/complex.h +++ b/include/cutlass/complex.h @@ -32,6 +32,8 @@ #include +#include + #if defined(__CUDACC_RTC__) #include #else @@ -39,6 +41,7 @@ #endif #include "cutlass/cutlass.h" +#include "cutlass/functional.h" #include "cutlass/half.h" #include "cutlass/real.h" @@ -53,8 +56,10 @@ namespace cutlass { -////////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// /// Enumeraed type describing a transformation on a complex value. enum class ComplexTransform { kNone, @@ -147,15 +152,18 @@ class complex // Methods // -/// Constructor - CUTLASS_HOST_DEVICE - complex(T r = T(0)) : _real(r), _imag(T(0)) {} + /// Default constructor + complex() = default; -/// Constructor + /// Constructor + CUTLASS_HOST_DEVICE + complex(T r) : _real(r), _imag(T(0)) {} + + /// Constructor CUTLASS_HOST_DEVICE complex(T r, T i) : _real(r), _imag(i) {} - // -/// Constructor + + /// Constructor template CUTLASS_HOST_DEVICE complex(complex const &z) : _real(static_cast(z.real())), _imag(static_cast(z.imag())) {} @@ -197,6 +205,24 @@ class complex return complex(this->real() + rhs.real(), this->imag() + rhs.imag()); } + /// Reduction into memory address. Components may update out of order. + template + CUTLASS_DEVICE void red(complex *ptr) const { + static_assert(platform::is_same::value, "Component type must match"); + cutlass::red reduce; + reduce(&ptr->_real, _real); + reduce(&ptr->_imag, _imag); + } + + /// Reduction into memory address. Components may update out of order. 
(Half specialization) + CUTLASS_DEVICE void red(complex *ptr) const { + static_assert(platform::is_same::value, "Component type must match"); + half2 *h2_ptr = reinterpret_cast(ptr); + half2 h2_data = reinterpret_cast(*this); + cutlass::red reduce; + reduce(h2_ptr, h2_data); + } + /// Subtraction template CUTLASS_HOST_DEVICE complex operator-(complex const &rhs) const { @@ -506,13 +532,14 @@ CUTLASS_HOST_DEVICE bool operator<(complex const &lhs, complex const &rhs) /// Partial specialization for complex-valued type. template -struct RealType< complex > { +struct RealType< complex > +{ using Type = T; /// Number of elements static int const kExtent = 2; -CUTLASS_HOST_DEVICE + CUTLASS_HOST_DEVICE static complex from_real(double x) { return complex(static_cast(x)); } @@ -550,6 +577,127 @@ struct is_complex> { static bool const value = true; }; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// functional.h numeric specializations +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Squares with optional conversion +template +struct magnitude_squared, Output> { + CUTLASS_HOST_DEVICE + Output operator()(complex lhs) const { + multiplies mul_op; + + Output y_r = Output(lhs.real()); + Output y_i = Output(lhs.imag()); + + return mul_op(y_r, y_r) + mul_op(y_i, y_i); + } +}; + +/// Fused multiply-add +template +struct multiply_add, complex, complex> { + CUTLASS_HOST_DEVICE + complex operator()( + complex const &a, + complex const &b, + complex const &c) const { + + T real = c.real(); + T imag = c.imag(); + + real += a.real() * b.real(); + real += -a.imag() * b.imag(); + imag += a.real() * b.imag(); + imag += a.imag () * b.real(); + + return complex{ + real, + imag + }; + } +}; + +/// Fused multiply-add +template +struct multiply_add, T, complex> { + CUTLASS_HOST_DEVICE + complex operator()( + complex const &a, + T const &b, + complex const &c) const { + + T real = c.real(); + T imag = c.imag(); + + real += a.real() * b; + imag += a.imag () * b; + + return complex{ + real, + imag + }; + } +}; + +/// Fused multiply-add +template +struct multiply_add, complex> { + CUTLASS_HOST_DEVICE + complex operator()( + T const &a, + complex const &b, + complex const &c) const { + + T real = c.real(); + T imag = c.imag(); + + real += a * b.real(); + imag += a * b.imag(); + + return complex{ + real, + imag + }; + } +}; + +/// Conjugate +template +struct conjugate> { + CUTLASS_HOST_DEVICE + complex operator()(complex const &a) const { + return conj(a); + } +}; + +/// Computes the square of a difference with optional conversion +template +struct magnitude_squared_difference, Output> { + CUTLASS_HOST_DEVICE + Output operator()(complex lhs, complex rhs) const { + multiplies mul_op; + + Output y_r = Output(lhs.real()) - Output(rhs.real()); + Output y_i = Output(lhs.imag()) - Output(rhs.imag()); + + return mul_op(y_r, y_r) + mul_op(y_i, y_i); + } +}; + +/// Reduces value into the data pointed to by ptr (complex specialization) +template +struct red> { + CUTLASS_DEVICE + void operator()(complex *ptr, const complex &data) + { + data.red(ptr); + } +}; + + ////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/include/cutlass/conv/conv2d_problem_size.h b/include/cutlass/conv/conv2d_problem_size.h index d33de182..76e96112 100644 --- a/include/cutlass/conv/conv2d_problem_size.h +++ b/include/cutlass/conv/conv2d_problem_size.h @@ -247,7 +247,7 
@@ public: CUTLASS_HOST_DEVICE cutlass::Tensor4DCoord filter_extent() const { - return cutlass::Tensor4DCoord ({K, R, S, C}); + return cutlass::Tensor4DCoord ({K, R, S, C / groups}); } /// Returns output extent as Tensor4DCoord @@ -336,7 +336,7 @@ cutlass::gemm::GemmCoord implicit_gemm_problem_size( return gemm::GemmCoord( problem_size.N * problem_size.P * problem_size.Q, problem_size.K, - problem_size.R * problem_size.S * problem_size.C + problem_size.R * problem_size.S * problem_size.C / problem_size.groups ); case Operator::kDgrad: return gemm::GemmCoord( @@ -451,6 +451,18 @@ int implicit_gemm_k_iterations( default: break; } + } else if (algorithm == IteratorAlgorithm::kOptimized) { + // Current optimized iterator only support GroupMode::kSingleGroup + if (group_mode == GroupMode::kSingleGroup) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K); + break; + + default: + break; + } + } } } @@ -459,6 +471,25 @@ int implicit_gemm_k_iterations( } +template +CUTLASS_HOST_DEVICE +int depthwise_gemm_k_iterations( + Operator conv_operator, + int threadblock_K, + Conv2dProblemSize const &problem_size, + IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, + GroupMode group_mode = GroupMode::kNone, + int threadblock_N = 0) { + + int n = problem_size.N; + int p = (problem_size.P + Output_P - 1) / Output_P; + int q = (problem_size.Q + Output_Q - 1) / Output_Q; + + int iterations = (n * p * q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + return iterations; +} + + CUTLASS_HOST_DEVICE int implicit_gemm_k_iterations_per_channel( Operator conv_operator, diff --git a/include/cutlass/conv/convolution.h b/include/cutlass/conv/convolution.h index 372a60b9..bb6fd4df 100644 --- a/include/cutlass/conv/convolution.h +++ b/include/cutlass/conv/convolution.h @@ -100,14 +100,16 @@ enum class IteratorAlgorithm { kAnalytic, ///< functionally correct in all cases but lower performance kOptimized, ///< optimized for R <= 32, S <= 32 and unity-stride dgrad kFixedChannels, ///< Analytic algorithm optimized for fixed channel count (C == AccessSize) - kFewChannels ///< Analytic algorithm optimized for few channels (C divisible by AccessSize) + kFewChannels, ///< Analytic algorithm optimized for few channels (C divisible by AccessSize) + kFixedStrideDilation ///< Optimized for fixed stride and dilation }; /// Distinguishes among partial specializations that accelerate certain problems where convolution /// stride is unit. 
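/////////////////////////////////////////////////////////////////////////////////////////////////

// A minimal sketch (hypothetical helper name, not a library API) of the arithmetic performed by
// depthwise_gemm_k_iterations() introduced above in conv2d_problem_size.h: each CTA produces an
// Output_P x Output_Q spatial tile of one image, and the resulting tile count is divided across
// split_k_slices.
template <int Output_P, int Output_Q>
CUTLASS_HOST_DEVICE
int depthwise_tile_count_sketch(int N, int P, int Q, int split_k_slices) {
  int p_tiles = (P + Output_P - 1) / Output_P;   // ceil(P / Output_P)
  int q_tiles = (Q + Output_Q - 1) / Output_Q;   // ceil(Q / Output_Q)
  int tiles   = N * p_tiles * q_tiles;           // output tiles across the whole batch
  // e.g. N = 4, P = Q = 56, an 8x8 output tile and split_k_slices = 1 give 4 * 7 * 7 = 196.
  return (tiles + split_k_slices - 1) / split_k_slices;
}

/////////////////////////////////////////////////////////////////////////////////////////////////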
 enum class StrideSupport {
   kStrided,     ///< arbitrary convolution stride
-  kUnity        ///< unit convolution stride
+  kUnity,       ///< unit convolution stride
+  kFixed        ///< fixed convolution stride
 };
 
 /// Identifies split-K mode
@@ -125,6 +127,38 @@ enum class GroupMode {
   kDepthwise      ///< One CTA calculates cta_n groups (problem_size.C == problem_size.K == problem_size.groups)
 };
 
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Shape of a tensor
+template <
+  int N = 1,
+  int H = 1,
+  int W = 1,
+  int C = 1
+>
+struct TensorNHWCShape {
+  static int const kN = N;
+  static int const kH = H;
+  static int const kW = W;
+  static int const kC = C;
+
+  static int const kHW = H * W;
+  static int const kNHW = N * kHW;
+  static int const kNHWC = N * H * W * C;
+
+  static int const kCount = kNHWC;
+
+  //
+  // Static member functions
+  //
+
+  /// Returns a Coord object
+  CUTLASS_HOST_DEVICE
+  static Coord<4> toCoord() {
+    return make_Coord(kN, kH, kW, kC);
+  }
+};
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 } // namespace conv
diff --git a/include/cutlass/conv/device/direct_convolution.h b/include/cutlass/conv/device/direct_convolution.h
new file mode 100644
index 00000000..502290bf
--- /dev/null
+++ b/include/cutlass/conv/device/direct_convolution.h
@@ -0,0 +1,269 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ * + **************************************************************************************************/ +/* \file + \brief Template for device-level Depthwise Convolution +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/conv/convolution.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class DirectConvolution { +public: + + using UnderlyingKernel = DirectConvolutionKernel_; + + using ElementA = typename UnderlyingKernel::ElementA; + using LayoutA = typename UnderlyingKernel::LayoutA; + using ElementB = typename UnderlyingKernel::ElementB; + using LayoutB = typename UnderlyingKernel::LayoutB; + using ElementC = typename UnderlyingKernel::ElementC; + using LayoutC = typename UnderlyingKernel::LayoutC; + using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator; + using ElementCompute = typename UnderlyingKernel::ElementCompute; + using OperatorClass = typename UnderlyingKernel::OperatorClass; + using ArchTag = typename UnderlyingKernel::ArchTag; + using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape; + using WarpShape = typename UnderlyingKernel::WarpShape; + using InstructionShape = typename UnderlyingKernel::InstructionShape; + using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle; + using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp; + static int const kStages = UnderlyingKernel::kStages; + static int const kConvDim = UnderlyingKernel::kConvDim; + using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator; + using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator; + using MathOperator = typename UnderlyingKernel::MathOperator; + + static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm; + static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport; + static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode; + + static int const kWarpCount = + (ThreadblockShape::kM / WarpShape::kM) * + (ThreadblockShape::kN / WarpShape::kN) * + (ThreadblockShape::kK / WarpShape::kK); + + /// Argument structure + using Arguments = typename UnderlyingKernel::Arguments; + + using ReorderKernel = typename UnderlyingKernel::ReorderKernel; + + private: + + /// Kernel parameters object + typename UnderlyingKernel::Params params_; + +public: + + /// Constructs Implicit GEMM + DirectConvolution() { } + + /// Determines whether the Implicit GEMM can execute the given problem. 
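  //
  // Typical host-side usage of this device-level handle looks like the following sketch; the
  // kernel alias "DepthwiseKernel" and all argument values are hypothetical placeholders:
  //
  //   using DepthwiseConv = cutlass::conv::device::DirectConvolution<DepthwiseKernel>;
  //
  //   DepthwiseConv conv_op;
  //   typename DepthwiseConv::Arguments args(
  //       problem_size, ref_A, ref_B, ref_C, ref_D, {alpha, beta}, ref_reordered_B);
  //
  //   if (DepthwiseConv::can_implement(args) == cutlass::Status::kSuccess) {
  //     conv_op.initialize(args, workspace_ptr);
  //     conv_op();   // launches the filter-reorder kernel, then the direct convolution kernel
  //   }
  //
  // can_implement() below performs the corresponding validity checks up front: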
+ static Status can_implement(Arguments const &args) { + + // dispatch to iterators + Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size); + if (Status::kSuccess != status) { + return status; + } + + status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size); + if (Status::kSuccess != status) { + return status; + } + + if (kGroupMode != conv::GroupMode::kDepthwise) { + return Status::kErrorInvalidProblem; + } + + // C and K should be multiple of groups + if (args.problem_size.K != args.problem_size.groups && + args.problem_size.C != args.problem_size.groups) { + return Status::kErrorInvalidProblem; + } + + + static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess; + if (kConvolutionalOperator == conv::Operator::kFprop) { + if (args.problem_size.K % kAlignmentC) + return Status::kErrorMisalignedOperand; + } else if (kConvolutionalOperator == conv::Operator::kDgrad) { + if (args.problem_size.C % kAlignmentC) + return Status::kErrorMisalignedOperand; + } else if (kConvolutionalOperator == conv::Operator::kWgrad) { + if (args.problem_size.C % kAlignmentC) + return Status::kErrorMisalignedOperand; + } + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape( + threadblock_swizzle.get_tiled_shape( + kConvolutionalOperator, + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices)); + + if (!(grid.y <= std::numeric_limits::max() && + grid.z <= std::numeric_limits::max())) { + + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + return 0; + } + + /// Initializes GEMM state from arguments. + Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + // initialize the params structure from the arguments + params_ = typename UnderlyingKernel::Params( + args, + static_cast(workspace) + ); + + int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Initializes GEMM state from arguments. + Status update(Arguments const &args, void *workspace = nullptr) { + + // update the params structure from the arguments + params_.ptr_A = args.ref_A.data(); + params_.ptr_B = args.ref_B.data(); + params_.ptr_C = args.ref_C.data(); + params_.ptr_D = args.ref_D.data(); + params_.output_op = args.output_op; + params_.ptr_reordered_B = args.ref_reordered_B.data();; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + // Launch reorder kernel + if (params_.ptr_reordered_B != nullptr) { + dim3 grid = ReorderKernel::get_grid_shape(params_); + dim3 block = ReorderKernel::get_block_shape(); + + cutlass::Kernel<<>>(params_); + } + + // Launch main kernel + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(32 * kWarpCount, 1, 1); + + // Dynamic SMEM size based on input params. 
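    // Unlike ImplicitGemmConvolution, which sizes shared memory from sizeof(SharedStorage), the
    // footprint here depends on the runtime stride/dilation and the staged activation tile, so it
    // is taken from the params object. For reference, the standalone pattern for opting a kernel
    // into more than 48 KB of dynamic shared memory is (sketch; "kernel_entry" is a placeholder):
    //
    //   if (smem_bytes >= (48 << 10)) {
    //     cudaFuncSetAttribute(kernel_entry,
    //                          cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    //   }
    //   kernel_entry<<<grid, block, smem_bytes, stream>>>(params);
    //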
+ int smem_size = int(params_.get_smem_size()); + + // Make sure we can use that much shared memory. + cudaError_t status = + cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (status != cudaSuccess) + return Status::kErrorInternal; + + + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } + + int get_smem_size() { return int(params_.get_smem_size()); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/device/implicit_gemm_convolution.h b/include/cutlass/conv/device/implicit_gemm_convolution.h index bac90f15..84d3e888 100644 --- a/include/cutlass/conv/device/implicit_gemm_convolution.h +++ b/include/cutlass/conv/device/implicit_gemm_convolution.h @@ -52,33 +52,33 @@ template class ImplicitGemmConvolution { public: - using ImplicitGemmKernel = ImplicitGemmKernel_; + using UnderlyingKernel = ImplicitGemmKernel_; - using ElementA = typename ImplicitGemmKernel::ElementA; - using LayoutA = typename ImplicitGemmKernel::LayoutA; - using ElementB = typename ImplicitGemmKernel::ElementB; - using LayoutB = typename ImplicitGemmKernel::LayoutB; - using ElementC = typename ImplicitGemmKernel::ElementC; - using LayoutC = typename ImplicitGemmKernel::LayoutC; - using ElementAccumulator = typename ImplicitGemmKernel::ElementAccumulator; - using ElementCompute = typename ImplicitGemmKernel::ElementCompute; - using OperatorClass = typename ImplicitGemmKernel::OperatorClass; - using ArchTag = typename ImplicitGemmKernel::ArchTag; - using ThreadblockShape = typename ImplicitGemmKernel::ThreadblockShape; - using WarpShape = typename ImplicitGemmKernel::WarpShape; - using InstructionShape = typename ImplicitGemmKernel::InstructionShape; - using ThreadblockSwizzle = typename ImplicitGemmKernel::ThreadblockSwizzle; - using EpilogueOutputOp = typename ImplicitGemmKernel::EpilogueOutputOp; - static int const kStages = ImplicitGemmKernel::kStages; - static int const kConvDim = ImplicitGemmKernel::kConvDim; - using WarpMmaOperator = typename ImplicitGemmKernel::WarpMmaOperator; - using ArchMmaOperator = typename ImplicitGemmKernel::ArchMmaOperator; - using MathOperator = typename ImplicitGemmKernel::MathOperator; + using ElementA = typename UnderlyingKernel::ElementA; + using LayoutA = typename UnderlyingKernel::LayoutA; + using ElementB = typename UnderlyingKernel::ElementB; + using LayoutB = typename UnderlyingKernel::LayoutB; + using ElementC = typename UnderlyingKernel::ElementC; + using LayoutC = typename UnderlyingKernel::LayoutC; + using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator; + using ElementCompute = typename UnderlyingKernel::ElementCompute; + using OperatorClass = typename UnderlyingKernel::OperatorClass; + using ArchTag = typename UnderlyingKernel::ArchTag; + using ThreadblockShape = typename 
UnderlyingKernel::ThreadblockShape; + using WarpShape = typename UnderlyingKernel::WarpShape; + using InstructionShape = typename UnderlyingKernel::InstructionShape; + using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle; + using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp; + static int const kStages = UnderlyingKernel::kStages; + static int const kConvDim = UnderlyingKernel::kConvDim; + using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator; + using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator; + using MathOperator = typename UnderlyingKernel::MathOperator; - static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmKernel::kConvolutionalOperator; - static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmKernel::kIteratorAlgorithm; - static cutlass::conv::StrideSupport const kStrideSupport = ImplicitGemmKernel::kStrideSupport; - static cutlass::conv::GroupMode const kGroupMode = ImplicitGemmKernel::kGroupMode; + static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm; + static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport; + static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode; static int const kWarpCount = (ThreadblockShape::kM / WarpShape::kM) * @@ -86,12 +86,12 @@ public: (ThreadblockShape::kK / WarpShape::kK); /// Argument structure - using Arguments = typename ImplicitGemmKernel::Arguments; + using Arguments = typename UnderlyingKernel::Arguments; private: /// Kernel parameters object - typename ImplicitGemmKernel::Params params_; + typename UnderlyingKernel::Params params_; public: @@ -102,12 +102,12 @@ public: static Status can_implement(Arguments const &args) { // dispatch to iterators - Status status = ImplicitGemmKernel::Mma::IteratorA::can_implement(args.problem_size); + Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size); if (Status::kSuccess != status) { return status; } - status = ImplicitGemmKernel::Mma::IteratorB::can_implement(args.problem_size); + status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size); if (Status::kSuccess != status) { return status; } @@ -138,9 +138,15 @@ public: if (kGroupMode == conv::GroupMode::kMultipleGroup && ThreadblockShape::kN % k_per_group) { return Status::kErrorInvalidProblem; } + + // current optimized iterator algo only supports SingleGroup mode + if (kIteratorAlgorithm == IteratorAlgorithm::kOptimized && + kGroupMode != conv::GroupMode::kSingleGroup) { + return Status::kErrorInvalidProblem; + } } - static int const kAlignmentC = ImplicitGemmKernel::Epilogue::OutputTileIterator::kElementsPerAccess; + static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess; if (kConvolutionalOperator == conv::Operator::kFprop) { if (args.problem_size.K % kAlignmentC) return Status::kErrorMisalignedOperand; @@ -249,15 +255,15 @@ public: } // initialize the params structure from the arguments - params_ = typename ImplicitGemmKernel::Params( + params_ = typename UnderlyingKernel::Params( args, static_cast(workspace) ); - int smem_size = int(sizeof(typename ImplicitGemmKernel::SharedStorage)); + int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); if (smem_size >= (48 << 10)) { - cudaError_t result = 
cudaFuncSetAttribute(cutlass::Kernel, + cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); @@ -292,9 +298,9 @@ public: dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); dim3 block(32 * kWarpCount, 1, 1); - int smem_size = int(sizeof(typename ImplicitGemmKernel::SharedStorage)); + int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); - cutlass::Kernel<<>>(params_); + cutlass::Kernel<<>>(params_); cudaError_t result = cudaGetLastError(); diff --git a/include/cutlass/conv/kernel/default_conv2d_group_fprop.h b/include/cutlass/conv/kernel/default_conv2d_group_fprop.h index b17e7c59..5a4548f7 100644 --- a/include/cutlass/conv/kernel/default_conv2d_group_fprop.h +++ b/include/cutlass/conv/kernel/default_conv2d_group_fprop.h @@ -89,7 +89,7 @@ template < ///////////////////////////////////////////////////////////////////////////////////////////////// /// Defines a kernel for Conv2dGroupFprop specialization for Analytic IteratorAlgorithm and multistage -/// pipeline. +/// pipeline that supports all GroupMode. template < typename ElementA, typename LayoutA, @@ -135,6 +135,13 @@ struct DefaultConv2dGroupFprop < AlignmentB > { + static_assert(std::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(std::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(std::is_same::value, + "Current group conv only support NHWC layout"); + // Define the core components from GEMM using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, @@ -215,6 +222,267 @@ struct DefaultConv2dGroupFprop < ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and multistage +/// pipeline that supports GroupMode::kSingleGroup. 
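//
// In kSingleGroup mode every threadblock tile lies within a single group, so the mainloop only
// has to cover that group's slice of the GEMM K extent. With the optimized iterator, the count
// returned by implicit_gemm_k_iterations() above works out to, for example:
//
//   C = 128, K = 128, groups = 4          ->  channels_per_group = C / groups = 32
//   R = S = 3, ThreadblockShape::kK = 32  ->  iterations = R * S * ceil(32 / 32) = 9
//
// The specialization below assembles that iterator with the multistage (cp.async-based)
// mainloop; a separate two-stage variant follows it.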
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dGroupFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + GroupMode::kSingleGroup, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + static_assert(std::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(std::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(std::is_same::value, + "Current group conv only support NHWC layout"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + GroupMode::kSingleGroup + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and +/// 2 stage pipeline that supports GroupMode::kSingleGroup. 
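//
// This variant differs from the one above only in the pieces it assembles: the mainloop is
// threadblock::ImplicitGemmPipelined (two-stage double buffering, no cp.async) and the epilogue
// comes from detail::DefaultConvEpilogue, which makes it the fallback when only Stages == 2 is
// available. Selecting between the two is just a matter of the Stages argument in the
// kernel-level alias, e.g. (sketch; the other template parameters are elided):
//
//   using TwoStageKernel = typename DefaultConv2dGroupFprop<
//       /* element, layout, tile and epilogue parameters... */, /*Stages=*/2,
//       /*MathOperatorTag=*/..., GroupMode::kSingleGroup,
//       IteratorAlgorithm::kOptimized, ...>::Kernel;
//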
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dGroupFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + GroupMode::kSingleGroup, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + static_assert(std::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(std::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(std::is_same::value, + "Current group conv only support NHWC layout"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + GroupMode::kSingleGroup + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace kernel } // namespace conv } // namespace cutlass diff --git a/include/cutlass/conv/kernel/default_depthwise_fprop.h b/include/cutlass/conv/kernel/default_depthwise_fprop.h index b4005a4a..3fbbcea2 100644 --- a/include/cutlass/conv/kernel/default_depthwise_fprop.h +++ b/include/cutlass/conv/kernel/default_depthwise_fprop.h @@ -39,14 +39,21 @@ #include 
"cutlass/cutlass.h" #include "cutlass/conv/kernel/default_conv2d.h" +#include "cutlass/conv/kernel/direct_convolution.h" #include "cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h" #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" - #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" #include "cutlass/conv/threadblock/depthwise_fprop_pipelined.h" +// Direct Conv Related Header files +#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h" +#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h" + +#include "cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h" +#include "cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -54,7 +61,7 @@ namespace conv { namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dFprop +/// Defines a kernel for DepthwiseFprop template < typename ElementA, typename LayoutA, @@ -80,12 +87,43 @@ template < int AlignmentB = cutlass::sizeof_bits::value / cutlass::sizeof_bits::value > struct DefaultDepthwiseFprop; +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for DepthwiseFprop with direct convolution algorithm +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename ThreadBlockOutputShape, + typename FilterShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + // MatrixShape + typename StrideShape = cutlass::MatrixShape<-1, -1>, + // MatrixShape< Height, Width> + typename DilationShape = cutlass::MatrixShape<-1, -1>, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> struct DefaultDepthwiseDirect2dConvFprop; + ///////////////////////////////////////////////////////////////////////////////////////////////// // OpClassSimt convolutions ///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Defines a kernel for Depthwise specialization for Analytic IteratorAlgorithm, -/// 2 stage pipeline, and FFMA-based mainloop for SM50 +/// Defines a kernel for Depthwise specialization for Analytic IteratorAlgorithm template < typename ElementA, typename LayoutA, @@ -210,6 +248,338 @@ struct DefaultDepthwiseFprop < }; ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Depthwise specialization for direct 2d conv implementation, +/// multiple stage pipeline, and SIMT-based mainloop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename 
LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename ThreadBlockOutputShape, + typename FilterShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + typename StrideShape, + typename DilationShape, + int AlignmentA, + int AlignmentB +> +struct DefaultDepthwiseDirect2dConvFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + StrideShape, + DilationShape, + AlignmentA, + AlignmentB +> { + // One warp handles the entrie groups per cta. + static_assert(ThreadblockShape::kN == WarpShape::kN, + "ThreadblockShape::kN should be same as WarpShape::kN "); + static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount, + "ThreadblockShape::kK and WarpShape::kK should be same as filter size"); + static_assert(ThreadblockShape::kM % WarpShape::kM == 0, + "ThreadblockShape::kM must be divisible by WarpShape shape::kM"); + static_assert(ThreadBlockOutputShape::kN, "ThreadBlockOutputShape::kN should be 1"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize< + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + ElementA, + layout::RowMajor, + ElementB, + layout::ColumnMajor, + ElementAccumulator, + layout::RowMajor, + arch::OpClassSimt, + 128, + 128, + Stages, + MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized< + cutlass::MatrixShape, // < outputShape:KMNK, groups per cta> + ThreadBlockOutputShape, + ElementA, LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + using ThreadOutputShape = typename MmaCore::ThreadOutputShape; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * AlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt< + ThreadblockShape, // < outputShape:KMNK, groups per cta> + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + ThreadOutputShape, + ThreadBlockOutputShape + >::Epilogue; + + // Define the Mma + using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages, + Epilogue + >; + + // Define the kernel + using Kernel = cutlass::conv::kernel::DirectConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + cutlass::conv::GroupMode::kDepthwise, + ThreadBlockOutputShape + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Depthwise specialization for direct 2d conv implementation, +/// multiple stage pipeline, and SIMT-based mainloop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename ThreadBlockOutputShape, + typename FilterShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + typename StrideShape, + typename DilationShape, + int AlignmentA, + int AlignmentB +> +struct DefaultDepthwiseDirect2dConvFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kFixedStrideDilation, + StrideSupport, + StrideShape, + DilationShape, + AlignmentA, + AlignmentB, +> { + + + + // One warp handles the entrie groups per cta. 
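  //
  // Because stride and dilation are compile-time MatrixShapes in this specialization, the
  // per-CTA activation footprint can also be fixed at compile time (ActivationShape below).
  // For a hypothetical 8x8 ThreadBlockOutputShape, 3x3 FilterShape, stride {2, 2} and
  // dilation {1, 1}:
  //
  //   ActivationShapeH = (8 - 1) * 2 + (3 - 1) * 1 + 1 = 17
  //   ActivationShapeW = (8 - 1) * 2 + (3 - 1) * 1 + 1 = 17
  //
  // i.e. each threadblock stages a 17x17 spatial activation tile in shared memory.
  // The static_asserts below pin down the tiling assumptions this relies on: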
+ static_assert(ThreadblockShape::kN == WarpShape::kN, + "ThreadblockShape::kN should be same as WarpShape::kN "); + static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount, + "ThreadblockShape::kK and WarpShape::kK should be same as filter size"); + static_assert(ThreadblockShape::kM % WarpShape::kM == 0, + "ThreadblockShape::kM must be divisible by WarpShape shape::kM"); + static_assert(ThreadBlockOutputShape::kN, "ThreadBlockOutputShape::kN should be 1"); + + static_assert(StrideShape::kRow >= 0 && StrideShape::kColumn >= 0, "Stride should be fixed"); + static_assert(DilationShape::kRow >= 0 && DilationShape::kColumn >= 0, "Stride should be fixed"); + + // Activations loaded by threadblock + static int const ActivationShapeH = (ThreadBlockOutputShape::kH - 1) * StrideShape::kRow + + (FilterShape::kRow - 1) * DilationShape::kRow + 1; + + static int const ActivationShapeW = (ThreadBlockOutputShape::kW - 1) * StrideShape::kColumn + + (FilterShape::kColumn - 1) * DilationShape::kColumn + 1; + + using ActivationShape = + cutlass::conv::TensorNHWCShape<1, ActivationShapeH, ActivationShapeW, ThreadblockShape::kN >; + + // Define the core components from GEMM + using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize< + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + ElementA, + layout::RowMajor, + ElementB, + layout::ColumnMajor, + ElementAccumulator, + layout::RowMajor, + arch::OpClassSimt, + 128, + 128, + Stages, + MathOperatorTag, + IteratorAlgorithm::kFixedStrideDilation, + StrideShape, + DilationShape, + ActivationShape>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation< + cutlass::MatrixShape, // < outputShape:KMNK, groups per cta> + ThreadBlockOutputShape, + StrideShape, + DilationShape, + ActivationShape, + ElementA, LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + using ThreadOutputShape = typename MmaCore::ThreadOutputShape; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * AlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt< + ThreadblockShape, // < outputShape:KMNK, groups per cta> + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + ThreadOutputShape, + ThreadBlockOutputShape + >::Epilogue; + + // Define the Mma + using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages, + Epilogue, + IteratorAlgorithm::kFixedStrideDilation + >; + + // Define the kernel + using Kernel = cutlass::conv::kernel::DirectConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + cutlass::conv::GroupMode::kDepthwise, + ThreadBlockOutputShape + >; +}; } // namespace kernel } // namespace conv diff --git a/include/cutlass/conv/kernel/direct_convolution.h b/include/cutlass/conv/kernel/direct_convolution.h new file mode 100644 index 00000000..f7451295 --- /dev/null +++ b/include/cutlass/conv/kernel/direct_convolution.h @@ -0,0 +1,505 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multi-staged Depthwise Convolution kernel. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure +template > ///! OutputShape per ThreadBlock +struct DirectConvolutionParams { + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + static Operator const kConvolutionalOperator = ConvOperator; + using ConvProblemSize = ConvProblemSize_; + using Arguments = Arguments_; + using ConvOutputIteratorParameter = ConvOutputIteratorParameter_; + + using ThreadblockShape = typename Mma::Shape; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + static conv::GroupMode const kGroupMode = GroupMode_; + static int const kStages = Mma::kStages; + + ConvProblemSize problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + gemm::GemmCoord implicit_gemm_problem_size; + int swizzle_log_tile; + int smem_size_; + + int gemm_k_iterations; + int gemm_k_iterations_per_channel; + typename Mma::IteratorA::Params iterator_A; + typename Mma::IteratorA::Element const *ptr_A; + typename Mma::IteratorB::Params iterator_B; + typename Mma::IteratorB::Element const *ptr_B; + typename Mma::IteratorB::Element *ptr_reordered_B; + typename Epilogue::OutputTileIterator::Params iterator_C; + typename Epilogue::OutputTileIterator::Element *ptr_C; + typename Epilogue::OutputTileIterator::Params iterator_D; + typename Epilogue::OutputTileIterator::Element *ptr_D; + typename EpilogueOutputOp::Params output_op; + int *semaphore; + SplitKMode split_k_mode; + int split_k_slices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + DirectConvolutionParams() : swizzle_log_tile(0), gemm_k_iterations(0) {} + + /// + CUTLASS_HOST_DEVICE + DirectConvolutionParams(Arguments const &args, int *semaphore = nullptr) + : problem_size(args.problem_size), + implicit_gemm_problem_size( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), + iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), + ptr_A(args.ref_A.data()), + iterator_B(Mma::IteratorB::getParams(args.problem_size, args.ref_B.layout())), + ptr_B(args.ref_B.data()), + ptr_reordered_B(args.ref_reordered_B.data()), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size), + ptr_C(args.ref_C.data()), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size), + ptr_D(args.ref_D.data()), + output_op(args.output_op), + semaphore(semaphore), + split_k_mode(args.split_k_mode), + split_k_slices(args.problem_size.split_k_slices) { + gemm_k_iterations = + depthwise_gemm_k_iterations(kConvolutionalOperator, + ThreadblockShape::kK, + args.problem_size, + kIteratorAlgorithm, + kGroupMode, + 
ThreadblockShape::kN); + + gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel( + kConvolutionalOperator, ThreadblockShape::kK, args.problem_size, kIteratorAlgorithm); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + kConvolutionalOperator, + problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + + // Dynamic SMEM usage because stride and dilation are runtime params. + smem_size_ = (iterator_A.activation_size * kStages + iterator_B.filter_size); + } + + CUTLASS_HOST_DEVICE + int get_smem_size() { + // Dynamic Smem Size + return smem_size_; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct ReorderKernel { + using Params = Params_; + using ElementB = ElementB_; + + union SharedStorage {}; + + static unsigned int const kReorderKernelThreadPerCTA = 128; + + CUTLASS_HOST_DEVICE + ReorderKernel() {} + + CUTLASS_HOST_DEVICE + static dim3 get_grid_shape(Params const ¶ms) { + return dim3{static_cast( + (params.problem_size.filter_size() + kReorderKernelThreadPerCTA - 1) / + kReorderKernelThreadPerCTA), + 1, + 1}; + } + + CUTLASS_HOST_DEVICE + static dim3 get_block_shape() { return dim3{kReorderKernelThreadPerCTA, 1, 1}; } + + CUTLASS_HOST_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + int64_t m = static_cast(params.problem_size.groups); + int64_t n = static_cast(params.problem_size.filter_size() / params.problem_size.K); + const ElementB *src_with_type = static_cast(params.ptr_B); + ElementB *dst_with_type = static_cast(params.ptr_reordered_B); + + int64_t linear_index = blockIdx.x * kReorderKernelThreadPerCTA + threadIdx.x; + int64_t index_m = linear_index / n; + int64_t index_n = linear_index % n; + int64_t new_linear_index = index_m + index_n * m; + + if (linear_index < m * n) { + dst_with_type[new_linear_index] = src_with_type[linear_index]; + } + return; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem + conv::GroupMode GroupMode_ = conv::GroupMode::kNone, ///! 
Group mode + typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> +> +struct DirectConvolution { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + static Operator const kConvolutionalOperator = ConvOperator; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename EpilogueOutputOp::ElementOutput; + + /// Set output tensor C layout + using LayoutC = LayoutA; + + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using WarpMmaOperator = typename Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename ArchMmaOperator::Operator; + + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename WarpMmaOperator::Shape; + using InstructionShape = typename cutlass::gemm::GemmShape<1, 1, 1>; + + static int const kStages = Mma::kStages; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using TensorRefA = typename Mma::IteratorA::TensorRef; + using TensorRefB = typename Mma::IteratorB::TensorRef; + using TensorRefC = cutlass::TensorRef; + + /// Check iterator A and B convolution dimension are the same and + // set device::ImplicitGemmConvolution::kConvDim + static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, + "Convolution on different different dimensions is not supported"); + static int const kConvDim = Mma::IteratorA::kConvDim; + + /// Conv dimension and problem size structure (Conv2d or Conv3d) + using ConvProblemSize = ConvProblemSize_; + + static conv::GroupMode const kGroupMode = GroupMode_; + + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size; + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefB ref_reordered_B; + TensorRefC ref_C; + TensorRefC ref_D; + typename EpilogueOutputOp::Params output_op; + SplitKMode split_k_mode; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size, + TensorRefA const & ref_A, + TensorRefB const & ref_B, + TensorRefC const & ref_C, + TensorRefC const & ref_D, + typename EpilogueOutputOp::Params const & output_op, + TensorRefB const & ref_reordered_B = nullptr, + SplitKMode const & split_k_mode = SplitKMode::kSerial + ): + problem_size(problem_size), + ref_A(ref_A), + ref_B(ref_B), + ref_C(ref_C), + ref_D(ref_D), + output_op(output_op), 
+ ref_reordered_B(ref_reordered_B), + split_k_mode(split_k_mode) + { + + } + + }; + + using Params = + typename cutlass::conv::kernel::DirectConvolutionParams; + + using ReorderKernel = typename cutlass::conv::kernel::ReorderKernel; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + DirectConvolution() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if threadblock is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + int iterator_column_offset = 0; + int filter_row_offset = 0; + if (kGroupMode != GroupMode::kNone) { + if (kGroupMode == GroupMode::kDepthwise) { + iterator_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN; + } + } + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.iterator_A, + params.problem_size, + params.ptr_A, + thread_idx, + MatrixCoord( + threadblock_tile_idx.m() + threadblock_tile_idx.k(), + iterator_column_offset + ) + ); + + typename Mma::IteratorB iterator_B( + params.iterator_B, + params.problem_size, + params.ptr_reordered_B, + thread_idx, + MatrixCoord( + filter_row_offset, + iterator_column_offset + ) + ); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
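    // __shfl_sync(mask, value, srcLane) returns `value` as computed by lane `srcLane`, so every
    // thread in the warp receives lane 0's copy of threadIdx.x / 32. Reading the warp index
    // through the shuffle lets the compiler prove it is identical across the warp, so code that
    // depends on warp_idx is not compiled as potentially divergent.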
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() + threadblock_tile_idx.k(), + threadblock_tile_idx.n() * Mma::Shape::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D( + params.iterator_D, + params.ptr_D, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C( + params.iterator_C, + params.ptr_C, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + + // Compute threadblock-scoped matrix multiply-add + // Epilogue is fused in the mainloop + mma(params.gemm_k_iterations, + accumulators, + iterator_A, + params.iterator_A, + iterator_B, + params.iterator_B, + accumulators, + epilogue, + output_op, + iterator_D, + iterator_C, + params.split_k_slices); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/thread/depthwise_mma.h b/include/cutlass/conv/thread/depthwise_mma.h new file mode 100644 index 00000000..905f0ec5 --- /dev/null +++ b/include/cutlass/conv/thread/depthwise_mma.h @@ -0,0 +1,325 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates exposing architecture support for depthwise convolution +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/thread/mma.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// MMA operation +template < + /// Size of the matrix product (concept: GemmShape) + typename Shape_, + /// Number of threads participating + int kThreads_, + /// Data type of A elements + typename ElementA, + /// Data type of B elements + typename ElementB, + /// Element type of C matrix + typename ElementC, + /// Inner product operator + typename Operator +> +struct ElementwiseInnerProduct; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// General implementation +template < + /// Size of the matrix product (concept: GemmShape) + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Data type of B elements + typename ElementB_, + /// Element type of C matrix + typename ElementC_> +struct ElementwiseInnerProduct { + using Shape = Shape_; + using Operator = arch::OpMultiplyAdd; + using ElementC = ElementC_; + + CUTLASS_HOST_DEVICE + void operator()(Array &d, + Array const &a, + Array const &b, + Array const &c) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Shape::kN; ++i) { + d[i] = a[i] * b[i] + c[i]; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Specialization of half_t +template <> +struct ElementwiseInnerProduct< + gemm::GemmShape<2, 2, 1>, + 1, + half_t, + half_t, + half_t, + arch::OpMultiplyAdd> { + + using Shape = gemm::GemmShape<2, 2, 1>; + using Operator = arch::OpMultiplyAdd; + using ElementC = half_t; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) + + __half2 const & A = reinterpret_cast<__half2 const &>(a); + __half2 const & B = reinterpret_cast<__half2 const &>(b); + __half2 const & C = reinterpret_cast<__half2 const &>(c); + + __half2 tmp_D = __hfma2(A, B, C); + + d = reinterpret_cast const &>(tmp_D); + +#else + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + d[i] = a[i] * b[i] + c[i]; + } +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape, + /// Data type of A elements + typename ElementA, + /// Data 
type of B elements + typename ElementB, + /// Element type of C matrix + typename ElementC, + /// Concept: arch::OpMultiplyAdd or arch::Mma<> + typename Operator = arch::OpMultiplyAdd, + /// Used for partial specialization + typename Enable = bool +> +struct DepthwiseDirectConvElementwiseInnerProduct; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Gemplate that handles all packed matrix layouts +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Data type of B elements + typename ElementB_, + /// Element type of C matrix + typename ElementC_, + /// Operator used to compute GEMM + typename Operator_ +> +struct DepthwiseDirectConvElementwiseInnerProductGeneric { + + /// Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + /// Data type of operand A + using ElementA = ElementA_; + + /// Data type of operand B + using ElementB = ElementB_; + + /// Element type of operand C + using ElementC = ElementC_; + + /// Underlying mathematical operator + using Operator = Operator_; + + /// A operand storage + using FragmentA = Array; + + /// B operand storage + using FragmentB = Array; + + /// C operand storage + using FragmentC = Array; + + /// Instruction + using MmaOp = cutlass::conv::thread::ElementwiseInnerProduct< + gemm::GemmShape, + 1, + ElementA, + ElementB, + ElementC, + Operator>; + + + // + // Methods + // + + /// Computes a matrix product D = A * B + C + CUTLASS_HOST_DEVICE + void operator()( + FragmentC & D, + FragmentA const & A, + FragmentB const & B, + FragmentC const & C) { + Array *ptr_D = reinterpret_cast *>(&D); + Array const *ptr_A = + reinterpret_cast const *>(&A); + Array const *ptr_B = + reinterpret_cast const *>(&B); + + MmaOp mma_op; + + // Copy accumulators + D = C; + + // Compute matrix product + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Shape::kN / MmaOp::Shape::kN; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Shape::kM; ++m) { + + Array tmpD = ptr_D[m * Shape::kN / MmaOp::Shape::kN + n]; + Array tmpA = ptr_A[m * Shape::kN / MmaOp::Shape::kN + n]; + Array tmpB = ptr_B[n]; + + mma_op(tmpD, tmpA, tmpB, tmpD); + + ptr_D[m * Shape::kN / MmaOp::Shape::kN + n] = tmpD; + + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Data type of B elements + typename ElementB_, + /// Element type of C matrix + typename ElementC_ +> +struct DepthwiseDirectConvElementwiseInnerProduct< + Shape_, + ElementA_, + ElementB_, + ElementC_, + arch::OpMultiplyAdd + > { + /// Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + /// Data type of operand A + using ElementA = ElementA_; + + /// Data type of operand B + using ElementB = ElementB_; + + /// Element type of operand C + using ElementC = ElementC_; + + /// Underlying mathematical operator + using Operator = arch::OpMultiplyAdd; + + /// A operand storage + using FragmentA = + Array; // output_tile_size per thread * groups_per_thread + + /// B operand storage + using FragmentB = Array; // 1 * groups_per_thread + + /// C operand storage + using FragmentC = + Array; // output_tile_size per thread * groups_per_thread + + static bool const use_optimized = 0; + + using 
ArchMmaOperator = DepthwiseDirectConvElementwiseInnerProductGeneric; + + // + // Methods + // + + /// Computes a matrix product D = A * B + C + CUTLASS_HOST_DEVICE + void operator()( + FragmentC & D, + FragmentA const & A, + FragmentB const & B, + FragmentC const & C) { + + ArchMmaOperator mma; + + mma(D, A, B, C); + + } +}; + +} // namespace thread +} // namespace conv +} // namespace cutlass diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h index 2f6a1243..1f1082f5 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h @@ -145,6 +145,7 @@ private: uint32_t predicates_[kAccessesPerVector]; int filter_rs_; int filter_c_; + int channels_per_group_; // // Assertions @@ -175,6 +176,7 @@ public: filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); Index column = threadblock_offset.column() + thread_coord.strided(); + channels_per_group_ = problem_size_.C / problem_size_.groups; CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { @@ -188,7 +190,7 @@ public: CUTLASS_PRAGMA_UNROLL for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); + clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_); } pointer_ += ( @@ -229,7 +231,7 @@ public: CUTLASS_PRAGMA_UNROLL for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); + clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_); } pointer_ += next; diff --git a/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h b/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h new file mode 100644 index 00000000..1c244923 --- /dev/null +++ b/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h @@ -0,0 +1,230 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Extracts the host-params objects into non-template code. +*/ + +#pragma once + +#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0 + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED +#include +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized +template +struct Depthwise2dFpropDirectConvParams; + +/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation +template +struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams; + +/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized +template +struct Depthwise2dFpropDirectConvFilterIteratorParams; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized +template<> +struct Depthwise2dFpropDirectConvParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + + int32_t activation_tile_h; + int32_t activation_tile_w; + int32_t activation_tile_hw; + FastDivmod activation_tile_w_divmod; + + int filter[2]; + int stride[2]; + int dilation[2]; + int inc_next[2]; + FastDivmod pq_divmod; + FastDivmod q_divmod; + + int activation_load_count; + int activation_storage_elements; + int activation_size; + // + // Methods + // + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvParams() { } + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< layout object + MatrixCoord threadblock_shape, ///< CTA threadblock Shape + Layout::TensorCoord threadblock_output_shape, ///< Output tile Shape per threadblock + const int element_size_bits, ///< bits of activation element + const int thread_count, ///< threads per threadblock + const int thread_count_contiguous, ///< number of threads for continuous dimension + const int element_per_load) ///< element per each load + : layout(layout) { + + filter[0] = problem_size.S; + filter[1] = problem_size.R; + + stride[0] = problem_size.stride_w; + stride[1] = problem_size.stride_h; + + dilation[0] = problem_size.dilation_w; + dilation[1] = problem_size.dilation_h; + + // Compute activation_tile size per threadblock because stride and dilation are runtime params. 
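+    // For example, with an 8x8 output tile per threadblock, a 3x3 filter, stride 1
+    // and dilation 1, the formulas below give
+    //   activation_tile_h = (8 - 1) * 1 + (3 - 1) * 1 + 1 = 10
+    //   activation_tile_w = (8 - 1) * 1 + (3 - 1) * 1 + 1 = 10
+    // i.e. a 10x10 input window is staged to produce the 8x8 output tile.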
+ activation_tile_h = (threadblock_output_shape.h() - 1) * problem_size.stride_h + + (problem_size.R - 1) * problem_size.dilation_h + 1; + activation_tile_w = (threadblock_output_shape.w() - 1) * problem_size.stride_w + + (problem_size.S - 1) * problem_size.dilation_w + 1; + activation_tile_hw = activation_tile_h * activation_tile_w; + + activation_tile_w_divmod = FastDivmod(activation_tile_w); + + /// Below two values could not be templatized because the stride and dilation are runtime params + activation_load_count = (thread_count_contiguous * activation_tile_hw + (thread_count - 1)) / thread_count; + activation_storage_elements = activation_load_count * element_per_load * thread_count; + activation_size = activation_storage_elements * element_size_bits / 8; + + // Fastdivmod for output P, Q + int tiles_p = + (problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h()); + int tiles_q = (problem_size.Q + (threadblock_output_shape.w() - 1)) / + (threadblock_output_shape.w()); + + pq_divmod = FastDivmod(tiles_p * tiles_q); + q_divmod = FastDivmod(tiles_q); + + // next S + inc_next[0] = problem_size.dilation_w; + // next R + inc_next[1] = (activation_tile_w * problem_size.dilation_h - (problem_size.S - 1) * problem_size.dilation_w); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation +template <> +struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams { + using Layout = layout::TensorNHWC; + + Layout layout; + + FastDivmod pq_divmod; + FastDivmod q_divmod; + + int activation_size; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams() {} + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< Layout object + MatrixCoord threadblock_shape, ///< Threadblock Shape + Layout::TensorCoord threadblock_output_shape, ///< Output tile Shape per threadblock + const int activation_size_ ///< Activation size loaded by iterator + ) + : layout(layout), + activation_size(activation_size_) { + // Fastdivmod for output P, Q + int tiles_p = + (problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h()); + int tiles_q = + (problem_size.Q + (threadblock_output_shape.w() - 1)) / (threadblock_output_shape.w()); + + pq_divmod = FastDivmod(tiles_p * tiles_q); + q_divmod = FastDivmod(tiles_q); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized +template <> +struct Depthwise2dFpropDirectConvFilterIteratorParams { + using Layout = layout::TensorNHWC; + + Layout layout; + + int filter_size; + + bool is_convolution; + // + // Methods + // + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvFilterIteratorParams() {} + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvFilterIteratorParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< Layout object + MatrixCoord threadblock_shape, ///< Threadblock Shape + const int filter_size_) ///< Filter size loaded by iterator + : layout(layout), + filter_size(filter_size_), + is_convolution(problem_size.mode == Mode::kConvolution){} +}; + +} // namespace threadblock +} // namespace conv +} // 
namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h b/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h new file mode 100644 index 00000000..7735b667 --- /dev/null +++ b/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h @@ -0,0 +1,314 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. 
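+
+    In this variant the stride, dilation and resulting activation tile extent
+    (StrideShape, DilationShape, ActivationShape) are compile-time template
+    parameters, so can_implement() rejects problem sizes whose runtime stride or
+    dilation differ; the companion iterator in
+    depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h
+    handles runtime stride and dilation instead.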
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template > +class DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation { + public: + // + // Types + // + + using Shape = Shape_; + using OutputTileShape = OutputTileShape_; + using Element = Element_; + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + // Compilation value of stride , dialtion and activation shape + using StrideShape = StrideShape_; + using DilationShape = DilationShape_; + using ActivationShape = ActivationShape_; + + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + static int const kActivationSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess * ThreadMap::kThreads * + sizeof_bits::value / 8; + + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); + + static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 1"); + static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock"); + + // + // Parameters structure + // + + using Params = Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams; + + private: + Conv2dProblemSize const &problem_size_; + Params const ¶ms_; + char const *pointer_; + + // Base channels for current threadblock + int base_c_; + // Base activation index for current threadblock + int offset_intial_npq_; + // Base activation coord for current threadblock + TensorCoord activatioin_base_; + // Intial thread positioin + int offset_initial_hwc_; + // Overall load instruction per thread. + int iterator_load_; + // thread loading position. 
+ int iterator_hwc_; + // activation N is inside the Tensor or not + bool valid_n_; + + public: + + + CUTLASS_HOST_DEVICE + DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = + MatrixCoord() + ) + : params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + offset_intial_npq_(threadblock_offset.row()), + offset_initial_hwc_(thread_idx), + iterator_load_(0) { + + base_c_ = threadblock_offset.column(); + + set_iteration_index(0); + + set_activation_coord(offset_intial_npq_); + + } + + CUTLASS_HOST_DEVICE + void set_activation_coord(int offset_npq) { + int offset_inital_n, offset_inital_p, offset_inital_q; + int residual; + + params_.pq_divmod(offset_inital_n, residual, offset_npq); + params_.q_divmod(offset_inital_p, offset_inital_q, residual); + + int base_n = offset_inital_n; + + int base_h = + offset_inital_p * OutputTileShape::kH * StrideShape::kRow - problem_size_.pad_h; + + int base_w = + offset_inital_q * OutputTileShape::kW * StrideShape::kColumn - problem_size_.pad_w; + + activatioin_base_ = TensorCoord(base_n, base_h, base_w, base_c_); + + valid_n_ = activatioin_base_.n() < problem_size_.N; + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params( + problem_size, + layout, + {Shape::kRow, Shape::kColumn}, + {OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC}, + kActivationSize); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads; + iterator_load_ = index; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // Go to next threadblock + offset_intial_npq_ += problem_size_.split_k_slices; + + set_iteration_index(0); + + set_activation_coord(offset_intial_npq_); + } + + /// Returns the coordinate in the activations tensor X that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous ; + int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous ; + int h = next / ActivationShape::kW; + int w = next % ActivationShape::kW; + + c = c * AccessType::kElements; + + return activatioin_base_ + TensorCoord(0, h, w, c); + } + + /// Returns true if the current coordinate is within the activations tensor X + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + bool valid_c = coord.c() < problem_size_.C; + bool valid_h = coord.h() >= 0 && coord.h() < problem_size_.H; + bool valid_w = coord.w() >= 0 && coord.w() < problem_size_.W; + return valid_n_ ? 
valid_c & valid_h & valid_w : 0; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + AccessType const *ptr = + reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + return ptr; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation &operator++() { + + ++iterator_load_; + iterator_hwc_ += ThreadMap::kThreads; + + if (iterator_load_ < ThreadMap::Iterations::kCount) { + return *this; + } + + iterator_load_ = 0; + iterator_hwc_ = offset_initial_hwc_; + + return *this; + } + + /// Determines the activation size loaded by iterator + CUTLASS_HOST_DEVICE + int get_load_size() { + return kActivationSize; + } + + /// Determines the iterations needed + CUTLASS_HOST_DEVICE + int get_iteration_num() { + return ThreadMap::Iterations::kCount; + } + + /// Determines whether the Depthwise fprop can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check stride and dilation constraint + if (problem_size.stride_h != StrideShape::kRow || problem_size.stride_w != StrideShape::kColumn) { + return Status::kErrorInvalidProblem; + } + + if (problem_size.dilation_h != DilationShape::kRow || problem_size.dilation_w != DilationShape::kColumn) { + return Status::kErrorInvalidProblem; + } + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h b/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h new file mode 100644 index 00000000..e35b7416 --- /dev/null +++ b/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h @@ -0,0 +1,291 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template > +class DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized { + public: + // + // Types + // + + using Shape = Shape_; + using OutputTileShape = OutputTileShape_; + using Element = Element_; + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); + + static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 1"); + static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock"); + + // + // Parameters structure + // + + using Params = Depthwise2dFpropDirectConvParams; + + private: + Conv2dProblemSize const &problem_size_; + Params const ¶ms_; + char const *pointer_; + + // Base channels for current threadblock + int base_c_; + // Base activation index for current threadblock + int offset_intial_npq_; + 
// Base activation coord for current threadblock + TensorCoord activatioin_base_; + // Intial thread positioin + int offset_initial_hwc_; + // Overall load instruction per thread. + int iterator_load_; + // thread loading position. + int iterator_hwc_; + // Number of loads for activations tensor X. + const int number_of_loads_; + + public: + + + CUTLASS_HOST_DEVICE + DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = + MatrixCoord() + ) + : params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + offset_intial_npq_(threadblock_offset.row()), + offset_initial_hwc_(thread_idx), + iterator_load_(0), + number_of_loads_(params.activation_load_count) { + + base_c_ = threadblock_offset.column(); + + set_activation_coord(offset_intial_npq_); + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + void set_activation_coord(int offset_npq) { + int offset_inital_n, offset_inital_p, offset_inital_q; + int residual; + + params_.pq_divmod(offset_inital_n, residual, offset_npq); + params_.q_divmod(offset_inital_p, offset_inital_q, residual); + + int base_n = offset_inital_n; + + int base_h = + offset_inital_p * OutputTileShape::kH * problem_size_.stride_h - problem_size_.pad_h; + + int base_w = + offset_inital_q * OutputTileShape::kW * problem_size_.stride_w - problem_size_.pad_w; + + activatioin_base_ = TensorCoord(base_n, base_h, base_w, base_c_); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params( + problem_size, + layout, + {Shape::kRow, Shape::kColumn}, + {OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC}, + sizeof_bits::value, + ThreadMap::kThreads, + ThreadMap::Detail::ShapeVec::kContiguous, + ThreadMap::kElementsPerAccess); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads; + iterator_load_ = index; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // Go to next threadblock + offset_intial_npq_ += problem_size_.split_k_slices; + + set_activation_coord(offset_intial_npq_); + } + + /// Returns the coordinate in the activations tensor X that is currently pointed to + /// by the iterator. 
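+  //
+  // For example, assuming ThreadMap::Detail::ShapeVec::kContiguous == 16,
+  // AccessType::kElements == 8 and a 10-wide activation tile, iterator_hwc_ == 35
+  // decomposes below as
+  //   c    = 35 % 16 = 3   ->  channel offset 3 * 8 = 24
+  //   next = 35 / 16 = 2   ->  h = 2 / 10 = 0, w = 2 % 10 = 2
+  // giving the element at activatioin_base_ + (0, 0, 2, 24).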
+ CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous ; + int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous ; + int h, w; + params_.activation_tile_w_divmod(h, w, next) ; + + c = c * AccessType::kElements; + + return activatioin_base_ + TensorCoord(0, h, w, c); + } + + /// Returns true if the current coordinate is within the activations tensor X + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && coord.h() >= 0 && coord.h() < problem_size_.H && + coord.w() >= 0 && coord.w() < problem_size_.W && coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + AccessType const *ptr = + reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + return ptr; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized &operator++() { + + ++iterator_load_; + iterator_hwc_ += ThreadMap::kThreads; + + if (iterator_load_ < number_of_loads_) { + return *this; + } + + iterator_load_ = 0; + iterator_hwc_ = offset_initial_hwc_; + + return *this; + } + + /// Determines the activation size loaded by iterator + CUTLASS_HOST_DEVICE + int get_load_size() { + return params_.activation_size; + } + + /// Determines the iterations needed + CUTLASS_HOST_DEVICE + int get_iteration_num() { + return number_of_loads_; + } + + /// Determines whether the Depthwise fprop can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h b/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h new file mode 100644 index 00000000..1d8fcb3f --- /dev/null +++ b/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h @@ -0,0 +1,551 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/conv/threadblock/depthwise_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Epilogue stores the data into global memory + typename Epilogue_, + /// iterator implementation variants + conv::IteratorAlgorithm IteratorAlgorithm_ = conv::IteratorAlgorithm::kOptimized, + /// Used for partial specialization + typename Enable = bool> +class DepthwiseFpropDirectConvMultipleStage : + public DepthwiseDirectConvMmaBase { +public: + ///< Base class + using Base = DepthwiseDirectConvMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Policy describing tuning details + using Policy = Policy_; + + using Epilogue = Epilogue_; + + using SmemIteratorA = SmemIteratorA_; + 
using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + static conv::IteratorAlgorithm const kItertorAlgorithm = IteratorAlgorithm_; + + // + // Dependent types + // + + /// Fragment of accumulator tile + + using ElementC = typename Policy::Operator::ElementC; + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Internal structure exposed for introspection. + struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DepthwiseFpropDirectConvMultipleStage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, + IteratorB &iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + if (kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) { + // Number of iterators is a static value. 
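+      // Each cp_async_zfill below is conceptually
+      //   if (iterator_A.valid()) smem_dst = *iterator_A.get();  // async global->shared copy
+      //   else                    smem_dst = 0;                  // zero-fill out-of-bounds
+      // issued as a non-blocking cp.async that is committed by a later
+      // cp_async_fence() and completed by cp_async_wait().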
+ iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + ++this->smem_iterator_A_; + } + } else { + // Number of iterators is a runtime value. + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < iterator_A.get_iteration_num(); ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + ++this->smem_iterator_A_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA &iterator_A, + ///< Params of global memory iterator + typename IteratorA::Params const &iterator_a_params, + ///< iterator over B operand in global memory + IteratorB &iterator_B, + ///< Params of global memory iterator + typename IteratorB::Params const &iterator_b_params, + ///< initial value of accumulator + FragmentC const &src_accum, + /// Epilogue + Epilogue &epilogue, + ///< Output operator + typename Epilogue::OutputOp const &output_op, + ///< Tile iterator for destination + typename Epilogue::OutputTileIterator &destination_iterator, + ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + typename Epilogue::OutputTileIterator &source_iterator, + + int split_k_slices = 1 + ) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + + if (stage == 0) { + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + } + + if(kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation){ + // Number of iterators is compilation static. 
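+        // Here the loop bound Detail::AsyncCopyIterationsPerStageA is a compile-time
+        // constant, so the copy loop can be fully unrolled; the else-branch below must
+        // query iterator_A.get_iteration_num() because stride and dilation (and hence
+        // the activation tile size) are only known at run time.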
+ iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + } else { + // Number of iterators is a runtime value. + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_num(iterator_A.get_iteration_num()); + this->smem_iterator_A_.set_iteration_index(0); + + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < iterator_A.get_iteration_num(); ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + // Move to the next stage + iterator_A.advance(); + + this->smem_iterator_A_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + } + + ///////////////////////////////////////////////////////////////////////////// + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params); + + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], + warp_loaded_frag_A[0], warp_loaded_frag_B[0]); + + // + // Mainloop + // + + unsigned int iterations = 0; + constexpr int inner_loop_iterations = round_up(Base::kWarpGemmIterations, 2); + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { // Each iteration is a cta tile. + + accum.clear(); + + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < inner_loop_iterations; ++warp_mma_k) { + if (Base::kWarpGemmIterations % 2 == 0 || warp_mma_k + 1 != Base::kWarpGemmIterations) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
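+
+          // The warp fragments are double buffered: index (warp_mma_k + 1) % 2 prefetches
+          // the next k-group from shared memory while warp_mma below consumes index
+          // warp_mma_k % 2, overlapping shared-memory loads with the multiply-adds.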
+ + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Shape::kK); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Shape::kK); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k > 0) + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + // Issue global->shared copies for the next stage + int group_start_iteration_A, group_start_iteration_B; + + if (warp_mma_k == 0) { + group_start_iteration_A = 0; + group_start_iteration_B = 0; + copy_tiles_and_advance( + iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + } + + if (warp_mma_k < Base::kWarpGemmIterations) { + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum + ); + } + + if (warp_mma_k + 1 == inner_loop_iterations) + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + if (warp_mma_k + 2 == inner_loop_iterations) { + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next cta + iterator_A.advance(); + + this->smem_iterator_A_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({-Base::kStages, 0}); + + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.advance(- (Base::kStages-1) * iterator_A.get_load_size()); + smem_read_stage_idx = 0; + } else { + this->warp_tile_iterator_A_.advance(iterator_A.get_load_size()); + ++smem_read_stage_idx; + } + + if (kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) { + this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params); + } + + // goback to start position. B has no multiple stage + this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Shape::kK, 0}); + + --gemm_k_iterations; + } + } + + // + // Epilogue + // + int32_t smem_base_offset = iterator_B.get_load_size() + (iterations % Base::kStages) * iterator_A.get_load_size(); + + destination_iterator.set_tile_index(iterations * split_k_slices); + + source_iterator.set_tile_index(iterations * split_k_slices); + + epilogue(output_op, destination_iterator, accum, source_iterator, smem_base_offset); + + ++iterations; + } + + // Insert fence and wait for all outstanding cp.async operations to commit. 
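+    // cp_async_fence() closes the current group of cp.async instructions,
+    // cp_async_wait<0>() blocks until every committed group has completed, and the
+    // __syncthreads() below synchronizes the CTA before the kernel exits.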
+      cutlass::arch::cp_async_fence();
+      cutlass::arch::cp_async_wait<0>();
+      __syncthreads();
+
+  }
+
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace conv
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h b/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h
new file mode 100644
index 00000000..72107e5e
--- /dev/null
+++ b/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h
@@ -0,0 +1,261 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile)
+    matrix from memory.
+
+    This iterator assumes TensorNHWC layout of tensors in Global Memory.
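+
+    The filter tile is addressed as a 2-D (k, rs) matrix, advance() is a no-op because
+    the entire filter stays resident in shared memory, and for Mode::kConvolution the
+    rs index is mirrored in get() so that convolution rather than cross-correlation is
+    computed.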
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +template > +class DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized { +public: + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static int const kFilterSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess * ThreadMap::kThreads * + sizeof_bits::value / 8; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + using Params = Depthwise2dFpropDirectConvFilterIteratorParams; + + protected: + + Conv2dProblemSize const &problem_size_; + Params const ¶ms_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int filter_k_; + int offset_trs_[ThreadMap::Iterations::kStrided]; + +public: + + + + CUTLASS_HOST_DEVICE + DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_k_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_trs_[s] = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + } + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout, {Shape::kRow, Shape::kColumn}, kFilterSize); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % 
ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * 8 / sizeof_bits::value; + } + + CUTLASS_HOST_DEVICE + void advance() { + // Do nothing because the filter is persistent in the SMEM + } + + /// Returns the coordinate in the filter tensor W that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int k = filter_k_ + iteration_vector_ * AccessType::kElements; + int trs = offset_trs_[iteration_strided_]; + + return TensorCoord(k, trs, 0 , 0); // As a 2D-matrix + } + + /// Returns true if the current coordinate is within the activations tensor W + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.K && + coord.h() < Shape::kColumn; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + TensorCoord coord = at(); + int64_t offset = coord.n(); + if (params_.is_convolution) { + offset += (Shape::kColumn - coord.h() - 1)* problem_size_.K; + } else { + offset += coord.h() * problem_size_.K; + } + + return reinterpret_cast(pointer_ + + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines the filter size loaded by iterator + CUTLASS_HOST_DEVICE + int get_load_size() { + return kFilterSize; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + // check whether runtime filter size is same as templated filter size. + if ((problem_size.R * problem_size.S) != Shape::kColumn) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/threadblock/depthwise_mma_base.h b/include/cutlass/conv/threadblock/depthwise_mma_base.h new file mode 100644 index 00000000..96c76bee --- /dev/null +++ b/include/cutlass/conv/threadblock/depthwise_mma_base.h @@ -0,0 +1,229 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a directconv threadblock-scoped Depthwise kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Policy object describing MmaTensorOp +template < + /// Warp-level GEMM operator (concept: gemm::warp::Mma) + typename Operator_, + /// Padding used for A operand in shared memory (concept: MatrixShape) + typename SmemPaddingA_, + /// Padding used for B operand in shared memory (concept: MatrixShape) + typename SmemPaddingB_, + /// + typename ThreadMapA_, + /// + typename ThreadMapB_, + /// Number of partitions of K dimension of GEMM + int PartitionsK = 1> +struct DepthwiseDirectConvMmaPolicy { + /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) + using Operator = Operator_; + + /// Padding used for A operand in shared memory + using SmemPaddingA = SmemPaddingA_; + + /// Padding used for B operand in shared memory + using SmemPaddingB = SmemPaddingB_; + + using ThreadMapA = ThreadMapA_; + using ThreadMapB = ThreadMapB_; + + /// Number of partitions of K dimension + static int const kPartitionsK = PartitionsK; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
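// Editorial note (not part of the patch): DepthwiseDirectConvMmaBase below only owns the
// shared-memory buffers and the two warp-scoped tile iterators; DepthwiseDirectConvMmaPolicy
// above is what binds the warp-level operator, the shared-memory padding and the
// global->shared thread maps together. A minimal composition sketch, where WarpMma,
// ThreadMapA and ThreadMapB stand in for types normally produced by a *MmaCore definition:
//
//   using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy<
//       WarpMma,                      // warp-level operator, e.g. MmaDepthwiseDirectConvSimt
//       cutlass::MatrixShape<0, 0>,   // smem padding for A
//       cutlass::MatrixShape<0, 0>,   // smem padding for B
//       ThreadMapA,                   // global->shared thread map for the activations
//       ThreadMapB,                   // global->shared thread map for the filter
//       1>;                           // partitions along K
//
//   using ThreadblockMma = cutlass::conv::threadblock::DepthwiseDirectConvMmaBase<
//       cutlass::gemm::GemmShape<64, 64, 8>,  // illustrative threadblock tile
//       MmaPolicy,
//       2>;                                   // number of stages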
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class DepthwiseDirectConvMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = cutlass::gemm:: + GemmShape; + + /// Number of warp-level GEMM oeprations + /// kWarpGemmIterations could be even and odd. + static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape<1, // Not determined at compile-time :( + Shape::kN + Policy::SmemPaddingA::kRow>; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; // Tile N = 64? + + public: + // + // Data members + // + + // Let persistent B matrix in front of dynamic matrix A + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer for A operand + /// Not be determined at compile-time -- Just to get a Smem start address. 
+ AlignedBuffer operand_A; + public: + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { return TensorRefA{operand_A.data(), LayoutA()}; } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; } + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DepthwiseDirectConvMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h b/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h index f13bdc32..64d443ca 100644 --- a/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h +++ b/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h @@ -44,11 +44,17 @@ #include "cutlass/matrix_shape.h" #include "cutlass/gemm/warp/mma.h" + +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/warp/mma_depthwise_simt.h" + #include "cutlass/gemm/threadblock/mma_pipelined.h" #include "cutlass/gemm/threadblock/mma_singlestage.h" #include "cutlass/gemm/threadblock/mma_base.h" -#include "cutlass/conv/warp/mma_depthwise_simt.h" +#include "cutlass/conv/threadblock/depthwise_mma_base.h" + +#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h" #include "cutlass/arch/cache_operation.h" @@ -58,6 +64,95 @@ namespace cutlass { namespace conv { namespace threadblock { +namespace detail { +// +// Convert a WarpShapeM, which is the whole tile of elements, into the number of elements (2D) held by +// each partition within a warp. +// The goal is for each thread's tile of elements to be as square as +// possible for performance (4x4 will be faster than 2x8). +template <int WarpShapeM, int WarpNumThreadsM> // WarpNumThreadsM = the number of partitions within the warp +struct SimtWarpShape { + // kP * kQ * WarpNumThreadsM = WarpShapeM + // If needed, enable more specializations. 
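  // Editorial worked example (not part of the patch): with WarpShapeM = 16 and
  // WarpNumThreadsM = 2, the specialization below yields kP = 2 and kQ = 4, i.e. each
  // partition of the warp covers a 2x4 block of output pixels, and
  // kP * kQ * WarpNumThreadsM = 2 * 4 * 2 = 16 = WarpShapeM.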
+}; +template <> +struct SimtWarpShape<4, 4> { + static constexpr int kP = 1; + static constexpr int kQ = 1; +}; + +template <> +struct SimtWarpShape<4, 2> { + static constexpr int kP = 2; + static constexpr int kQ = 1; +}; + +template <> +struct SimtWarpShape<4, 1> { + static constexpr int kP = 2; + static constexpr int kQ = 2; +}; + +template <> +struct SimtWarpShape<8, 1> { + static constexpr int kP = 2; + static constexpr int kQ = 4; +}; +template <> +struct SimtWarpShape<8, 2> { + static constexpr int kP = 2; + static constexpr int kQ = 2; +}; +template <> +struct SimtWarpShape<8, 4> { + static constexpr int kP = 1; + static constexpr int kQ = 2; +}; + +template <> +struct SimtWarpShape<16, 1> { + static constexpr int kP = 4; + static constexpr int kQ = 4; +}; +template <> +struct SimtWarpShape<16, 2> { + static constexpr int kP = 2; + static constexpr int kQ = 4; +}; +template <> +struct SimtWarpShape<16, 4> { + static constexpr int kP = 2; + static constexpr int kQ = 2; +}; + +template +struct SimtWarpShape<25, WarpNumThreadsM> { + static_assert(WarpNumThreadsM == 1, "WarpShapeM could not be evenly splited by threads"); + static constexpr int kP = 5; + static constexpr int kQ = 5; +}; + +template <> +struct SimtWarpShape<32, 1> { + static constexpr int kP = 4; + static constexpr int kQ = 8; +}; + +template <> +struct SimtWarpShape<32, 2> { + static constexpr int kP = 4; + static constexpr int kQ = 4; +}; + +template <> +struct SimtWarpShape<32, 4> { + static constexpr int kP = 2; + static constexpr int kQ = 4; +}; + +} // namespace detail + template < /// Shape of threadblock-scoped matrix multiply operator typename Shape, @@ -114,6 +209,74 @@ struct DepthwiseMmaCoreWithLaneAccessSize; ///////////////////////////////////////////////////////////////////////////////////////////////// +template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of threadblock-scoped output tile + typename ThreadBlockOutputShape, + /// Shape of filter shape per threadblock + typename FilterShape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeA_ = 0, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeB_ = 0, + /// Number of stages + int Stages = 2, + /// Operation performed by MMA + typename Operator = typename platform::conditional< + (platform::is_same::value) && + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type, + /// Iterator algo type + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + /// Stride ( MatrixShape ) + typename StrideShape = cutlass::MatrixShape<-1, -1>, + /// Dilation ( MatrixShape ) + typename DilationShape = cutlass::MatrixShape<-1, -1>, + /// Activation Shape loaded by threadblock + typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, + /// 
Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global, + /// per-element transformation for elements of A + ComplexTransform TransformA = ComplexTransform::kNone, + /// per-element transformation for elements of B + ComplexTransform TransformB = ComplexTransform::kNone, + bool IsComplex = false // (is_complex::value || is_complex::value) +> +struct DepthwiseDirectConvMmaCoreWithLaneAccessSize; + +///////////////////////////////////////////////////////////////////////////////////////////////// + template < /// Shape of threadblock-scoped matrix multiply operator typename Shape, @@ -332,6 +495,458 @@ struct DepthwiseMmaCoreWithLaneAccessSize; }; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: row-major +/// Operator: simt class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of threadblock-scoped output tile (concept: TensorNHWCShape) + typename ThreadBlockOutputShape_, + /// Shape of filter shape per threadblock + typename FilterShape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeA_, + /// Number of stages + int Stages_, + /// Operation performed by GEMM + typename Operator_> +struct DepthwiseDirectConvMmaCoreWithLaneAccessSize, + ElementA_, + layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + kLaneAccessSizeA_, + 128, + Stages_, + Operator_> { + using Shape = Shape_; + using FilterShape = FilterShape_; + using WarpShape = WarpShape_; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using OperatorClass = arch::OpClassSimt; + + static int const kLaneAccessSizeB = 128; + + // Divisility requirements + static_assert( kLaneAccessSizeB > 0, + "Size of a warp-scoped per thread access should be larger then ZERO" ); + + /// Default Operator + using Operator = Operator_; + + /// Number of warps present + using WarpCount = cutlass::gemm::GemmShape< + Shape::kM / WarpShape::kM, + Shape::kN / WarpShape::kN, + 1 + >; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && + !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 
+ ); + + /// Number of threads per warp + static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + // For Gmem load + static int const kElementsPerAccessA = 128 / sizeof_bits::value; + static int const kElementsPerAccessB = 128 / sizeof_bits::value; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajor; + using SmemLayoutB = layout::RowMajor; + + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, // Set kStrided = 1 because activation shape is runtime value. + kThreads, + kElementsPerAccessA + >; + + /// ThreadMap of iterator A + using SmemThreadMapA = IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv< + MatrixShape<1, Shape::kN>, // set kRow is 1 because it is a runtime value + ElementA, + SmemLayoutA, + 0, + SmemThreadMapA, // was IteratorThreadMapA + true // Dynamic iterations. + >; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccessB + >; + + /// Transpose the ThreadMap of iterator B + using SmemThreadMapB = IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv< + MatrixShape, + ElementB, + SmemLayoutB, + 0, + SmemThreadMapB, // was IteratorThreadMapB + false // static iterations. + >; + + // + // Warp-level matrix multiply operator + // + // Groups per threads + // Fp32: 2 groups + // Fp16: 2 groups + static const int GroupsPerThread = sizeof(ElementB) > 1 ? 
2 : 4; + // Define the warp-level op + static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize); + static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN; + + static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), + "WarpShape must be divisible by ThreadTile shape."); + + // Get output P, Q per thread + static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape::kP; + static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape::kQ; + + static const int LaneLayout = 1; + static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; + static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN); + + // Define the output tile computed by each thread + using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>; + + // Fetch the channel with same access size + static const int LaneM = LaneN; + + // No paddings + static int const kPaddingM = 0; + static int const kPaddingN = 0; + + static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), + "Padding must be divisible by Lane"); + + // these should have max of thread tile also + using LaneMmaShape = cutlass::gemm::GemmShape< + LaneM, + LaneN, + 1>; + + using Policy = cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape, // WarpShape + cutlass::layout::RowMajorInterleaved, // LaneLayout + LaneMmaShape + >; + + using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt< + WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> + FilterShape, /// Shape of filter shape per threadblock - concept: gemm::GemmShape + ThreadOutputShape, /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<> + ThreadBlockOutputShape_, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> + ElementA, /// Data type of A elements + SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) + ElementB, /// Data type of B elements + SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) + ElementC, /// Element type of C matrix + LayoutC, /// Layout of C matrix (concept: MatrixLayout) + Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) + >; + + /// Policy used to define MmaPipelined + using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy< + MmaWarpSimt, + MatrixShape, // skew for A matrix to avoid SMEM bank conflicts + MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts + IteratorThreadMapA, + IteratorThreadMapB, + WarpCount::kK + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: row-major +/// Operator: simt class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of threadblock-scoped output tile (concept: TensorNHWCShape) + typename ThreadBlockOutputShape_, + /// Shape of filter shape per threadblock + typename FilterShape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Size of a warp-scoped per thread access + int 
kLaneAccessSizeA_, + /// Number of stages + int Stages_, + /// Operation performed by GEMM + typename Operator_, + /// Stride ( MatrixShape ) + typename StrideShape_, + /// Dilation ( MatrixShape ) + typename DilationShape_, + /// Activation Shape loaded by threadblock + typename ActivationShape_> +struct DepthwiseDirectConvMmaCoreWithLaneAccessSize, + ElementA_, + layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + kLaneAccessSizeA_, + 128, + Stages_, + Operator_, + IteratorAlgorithm::kFixedStrideDilation, + StrideShape_, + DilationShape_, + ActivationShape_> { + using Shape = Shape_; + using FilterShape = FilterShape_; + using WarpShape = WarpShape_; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using OperatorClass = arch::OpClassSimt; + using StrideShape = StrideShape_; + using DilationShape = DilationShape_; + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + using ActivationShape = ActivationShape_; + + static int const kLaneAccessSizeB = 128; + + // Divisility requirements + static_assert( kLaneAccessSizeB > 0, + "Size of a warp-scoped per thread access should be larger then ZERO" ); + + /// Default Operator + using Operator = Operator_; + + /// Number of warps present + using WarpCount = cutlass::gemm::GemmShape< + Shape::kM / WarpShape::kM, + Shape::kN / WarpShape::kN, + 1 + >; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && + !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." + ); + + /// Number of threads per warp + static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + // For Gmem load + static int const kElementsPerAccessA = 128 / sizeof_bits::value; + static int const kElementsPerAccessB = 128 / sizeof_bits::value; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajor; + using SmemLayoutB = layout::RowMajor; + + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccessA + >; + + /// ThreadMap of iterator A + using SmemThreadMapA = IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv< + MatrixShape, + ElementA, + SmemLayoutA, + 0, + SmemThreadMapA, // was IteratorThreadMapA + false // static iterations. + >; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccessB + >; + + /// Transpose the ThreadMap of iterator B + using SmemThreadMapB = IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv< + MatrixShape, + ElementB, + SmemLayoutB, + 0, + SmemThreadMapB, // was IteratorThreadMapB + false // static iterations. + >; + + // + // Warp-level matrix multiply operator + // + // Groups per threads + // Fp32: 2 groups + // Fp16: 2 groups + static const int GroupsPerThread = sizeof(ElementB) > 1 ? 
2 : 4; + // Define the warp-level op + static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize); + static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN; + + static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape::kP; + static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape::kQ; + + static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), + "WarpShape must be divisible by ThreadTile shape."); + + static const int LaneLayout = 1; + static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; + static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN); + + // Define the output tile computed by each thread + using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>; + + // Fetch the channel with same access size + static const int LaneM = LaneN; + + // No paddings + static int const kPaddingM = 0; + static int const kPaddingN = 0; + + static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), + "Padding must be divisible by Lane"); + + // these should have max of thread tile also + using LaneMmaShape = cutlass::gemm::GemmShape< + LaneM, + LaneN, + 1>; + + using Policy = cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape, // WarpShape + cutlass::layout::RowMajorInterleaved, // LaneLayout + LaneMmaShape + >; + + using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt< + WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> + FilterShape, /// Shape of filter shape per threadblock - concept: gemm::GemmShape + ThreadOutputShape, /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<> + ThreadBlockOutputShape, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> + ElementA, /// Data type of A elements + SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) + ElementB, /// Data type of B elements + SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) + ElementC, /// Element type of C matrix + LayoutC, /// Layout of C matrix (concept: MatrixLayout) + Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) + IteratorAlgorithm::kFixedStrideDilation, /// Iterator algo type + StrideShape, /// Stride ( MatrixShape ) + DilationShape, /// Dilation ( MatrixShape ) + ActivationShape /// Activation Shape loaded by threadblock + >; + + /// Policy used to define MmaPipelined + using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy< + MmaWarpSimt, + MatrixShape, // skew for A matrix to avoid SMEM bank conflicts + MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts + IteratorThreadMapA, + IteratorThreadMapB, + WarpCount::kK + >; +}; } // namespace threadblock } // namespace conv } // namespace cutlass diff --git a/include/cutlass/conv/threadblock/threadblock_swizzle.h b/include/cutlass/conv/threadblock/threadblock_swizzle.h index f1551a0a..6e3eae22 100644 --- a/include/cutlass/conv/threadblock/threadblock_swizzle.h +++ b/include/cutlass/conv/threadblock/threadblock_swizzle.h @@ -165,7 +165,29 @@ struct StridedDgradIdentityThreadblockSwizzle : ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Threadblock swizzling function for GEMMs +template +struct DepthwiseDirect2dConvIdentityThreadblockSwizzle + : public gemm::threadblock::GemmIdentityThreadblockSwizzle { + CUTLASS_HOST_DEVICE + 
DepthwiseDirect2dConvIdentityThreadblockSwizzle() {} + + /// Returns the shape of the problem in units of logical tiles + CUTLASS_HOST_DEVICE + gemm::GemmCoord get_tiled_shape(cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem_size, + gemm::GemmCoord tile_size, + int split_k_slices) const { + + gemm::GemmCoord implicit_gemm_problem_size = + cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); + + return gemm::GemmCoord(1, + (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(), + split_k_slices); + } +}; } // namespace threadblock -} // namespace gemm +} // namespace conv } // namespace cutlass diff --git a/include/cutlass/conv/warp/mma_depthwise_simt.h b/include/cutlass/conv/warp/mma_depthwise_simt.h index bd0b827d..f4ddf441 100644 --- a/include/cutlass/conv/warp/mma_depthwise_simt.h +++ b/include/cutlass/conv/warp/mma_depthwise_simt.h @@ -42,6 +42,9 @@ #include "cutlass/gemm/warp/mma.h" #include "cutlass/gemm/thread/mma.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/thread/depthwise_mma.h" + #include "cutlass/gemm/warp/mma_simt_tile_iterator.h" #include "cutlass/gemm/warp/mma_simt_policy.h" @@ -91,7 +94,7 @@ class MmaDepthwiseSimt public: /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; // < 64, 16 , 8> + using Shape = Shape_; /// Data type of multiplicand A using ElementA = ElementA_; @@ -156,8 +159,223 @@ public: MmaDepthwiseSimt():Base() {} }; +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Shape of filter shape per threadblock - concept: gemm::GemmShape + typename FilterShape_, + /// Shape of the output tile computed by thread- concept: conv::TensorNHWCShape<> + typename ThreadOutputShape_, + /// Shape of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> + typename ThreadBlockOutputShape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Iterator algo type + conv::IteratorAlgorithm IteratorAlgorithm_ = IteratorAlgorithm::kAnalytic, + /// Stride ( MatrixShape ) + typename StrideShape_ = cutlass::MatrixShape<-1, -1>, + /// Dilation ( MatrixShape ) + typename DilationShape_ = cutlass::MatrixShape<-1, -1>, + /// Activation Shape loaded by threadblock + typename ActivationShape_ = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, + /// Number of partitions along K dimension + int PartitionsK = 1, + /// Complex transformation on operand A + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex transformation on operand B + ComplexTransform TransformB = ComplexTransform::kNone, + /// Used for partial specialization + typename Enable = bool> +class MmaDepthwiseDirectConvSimt { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Shape of filter shape per threadblock - concept: gemm::GemmShape + using FilterShape = FilterShape_; + + /// Shape of the output tile computed by thread- concept: conv::TensorNHWCShape<> + using 
ThreadOutputShape = ThreadOutputShape_; + + /// Shape of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Iterator algo type + static conv::IteratorAlgorithm const IteratorAlgorithm = IteratorAlgorithm_; + + /// Stride ( MatrixShape ) + using StrideShape = StrideShape_; + + /// Dilation ( MatrixShape ) + using DilationShape = DilationShape_; + + /// Activation Shape loaded by threadblock + using ActivationShape = ActivationShape_; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassSimt; + + /// Hard-coded for now + using ArchTag = arch::Sm50; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = TransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = TransformB; + + static constexpr bool use_dp4a = (platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA>::value || + platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value) && + platform::is_same< ElementA, int8_t >::value && + platform::is_same< ElementB, int8_t >::value; + + using dp4a_type = typename platform::conditional< use_dp4a , int8_t, bool >::type; + + /// Thread-level matrix multiply accumulate operator + using ThreadMma = cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct< + cutlass::gemm::GemmShape< + Shape::kM / Policy::WarpShape::kRow, // number of output pixels proccessed per thread + Shape::kN / Policy::WarpShape::kColumn, // number of channels proccessed per thread + 1>, + ElementA, + ElementB, + ElementC, + arch::OpMultiplyAdd, + dp4a_type + >; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename ThreadMma::ArchMmaOperator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Shape of the underlying instruction + using InstructionShape = cutlass::gemm::GemmShape<1,1,use_dp4a ? 
4 : 1>; + +public: + + /// Iterates over the A operand in memory + using IteratorA = cutlass::conv::warp::DepthwiseDirect2dConvSimtTileIterator< + MatrixShape, // per warp + FilterShape, + ThreadOutputShape, + ThreadBlockOutputShape, + cutlass::gemm::Operand::kA, + ElementA, + Policy, + IteratorAlgorithm, + StrideShape, + DilationShape, + ActivationShape, + PartitionsK, + Shape::kK + >; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = FragmentA; + + /// Iterates over the B operand in memory + using IteratorB = cutlass::gemm::warp::MmaSimtTileIterator< + MatrixShape<1, Shape::kN>, + cutlass::gemm::Operand::kB, + ElementB, + LayoutB, + Policy, + PartitionsK, + Shape::kK + >; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentB = FragmentB; + + /// Iterates over the C operand in memory + using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator< + MatrixShape, + cutlass::gemm::Operand::kC, + ElementC, + LayoutC, + Policy + >; + + /// Storage for C tile + using FragmentC = typename ThreadMma::FragmentC; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaDepthwiseDirectConvSimt() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &d, + FragmentA a, + FragmentB b, + FragmentC const &c, int group_idx = 0) const { + + ThreadMma mma; + + mma(d, a, b, c); + } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const { + //TODO: Implement this + dst_A = A; + dst_B = B; + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace warp -} // namespace gemm +} // namespace conv } // namespace cutlass diff --git a/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h b/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h index 9fec53f1..e2524d13 100644 --- a/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h +++ b/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h @@ -40,6 +40,8 @@ #include "cutlass/tensor_ref.h" #include "cutlass/matrix_shape.h" +#include "cutlass/conv/convolution.h" + #include "cutlass/arch/memory_sm75.h" #include "cutlass/layout/matrix.h" @@ -250,6 +252,611 @@ private: /////////////////////////////////////////////////////////////////////////////////////////////////// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Size of filter (concept: gemm::GemmShape) + typename FilterShape_, + /// Size of the matrix to load (concept: MatrixShape) + typename ThreadOutputShape_, + /// Size of the matrix to load (concept: MatrixShape) + typename ThreadBlockOutputShape_, + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Iterator algo type + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + /// Stride ( MatrixShape ) + typename StrideShape = cutlass::MatrixShape<-1, -1>, + /// Dilation ( MatrixShape ) + typename DilationShape = cutlass::MatrixShape<-1, -1>, + /// Activation Shape loaded by threadblock + typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, + 
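  // Editorial note (not part of the patch): the -1 defaults above mark the stride, dilation
  // and threadblock activation extents as runtime quantities for the default
  // IteratorAlgorithm::kAnalytic path; the IteratorAlgorithm::kFixedStrideDilation
  // specialization further below binds them at compile time so the per-thread activation
  // footprint can be computed and unrolled statically.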
/// Number of partitions along K dimension - used in sliced-K + int PartitionsK = 1, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize = 1> +class DepthwiseDirect2dConvSimtTileIterator; + + +/// Specialization for A operands of row-major layouts +/// +/// Concept: MutableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Size of filter (concept: gemm::GemmShape) + typename FilterShape_, + /// Size of the matrix to load (concept: TensorNHWC) + typename ThreadOutputShape_, + /// Size of the matrix to load (concept: TensorNHWC) + typename ThreadBlockOutputShape_, + /// Data type of A elements + typename Element_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Iterator algo type + conv::IteratorAlgorithm IteratorAlgorithm, + /// Stride ( MatrixShape ) + typename StrideShape, + /// Dilation ( MatrixShape ) + typename DilationShape, + /// Activation Shape loaded by threadblock + typename ActivationShape, + /// Number of partitions along K dimension - used in sliced-K + int PartitionsK, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize> +class DepthwiseDirect2dConvSimtTileIterator { + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of filter (concept: gemm::GemmShape) + using FilterShape = FilterShape_; + + /// Shape of tile to load (concept: TensorNHWC) + using ThreadOutputShape = ThreadOutputShape_; + + /// Shape of tile to load (concept: TensorNHWC) + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + + /// Operand tag + static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA; + + /// Element type + using Element = Element_; + + /// Layout of policy + using Layout = layout::RowMajor; + + /// Decomposition of elements among threads + using Policy = Policy_; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + // + // Derived quantities + // + + static_assert(!(Shape::kRow % Policy::WarpShape::kRow), + "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension."); + + static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); + static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); + static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); + static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); + +// Thread-level shape of a fragment + using ThreadShape = MatrixShape< + ThreadOutputShape::kNHW, // Output tile shape Computed by current threads + ThreadOutputShape::kC + >; + + static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), + "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); + + /// Number of individual loads + using Iterations = MatrixShape< + ThreadShape::kRow, + ThreadShape::kColumn / Policy::LaneMmaShape::kN + >; + + using ThreadTileCount = MatrixShape< + ThreadBlockOutputShape::kH / ThreadOutputShape::kH, + ThreadBlockOutputShape::kW / ThreadOutputShape::kW + >; + + /// Fragment object holding a thread's part of a tile + using 
Fragment = Array; + +protected: + + /// Internal reference + cutlass::TensorRef, layout::RowMajor> ref_; + + int activation_offset[ThreadOutputShape::kH][ThreadOutputShape::kW][Iterations::kColumn]; + int iterator_r_; + int iterator_s_; + int iterator_offset_; + + int inc_next_s_ ; + int inc_next_r_ ; + + MatrixCoord lane_offset_; +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator( + TensorRef ref, + int lane_id + ) { + + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + // Set channel offset + lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); + + ref.add_coord_offset(lane_offset_); + + ref_.reset(reinterpret_cast *>(ref.data()), + ref.stride(0) / Policy::LaneMmaShape::kN); + + iterator_r_ = 0; + iterator_s_ = 0; + iterator_offset_ = 0; + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) { + ref_.add_pointer_offset(offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + template + CUTLASS_HOST_DEVICE + void setup_initial_status(Params const& params) { + + inc_next_s_ = params.inc_next[0]; + inc_next_r_ = params.inc_next[1]; + + // Get base HW offset of current threads + int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); + int base_p_ = + (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH; + int base_q_ = + (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW; + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < ThreadOutputShape::kH; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int q = 0; q < ThreadOutputShape::kW; ++q) { + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < Iterations::kColumn; ++col) { + int base_w = (base_q_ + q) * params.stride[0]; + int base_h = (base_p_ + p) * params.stride[1]; + + int offset = base_h * params.activation_tile_w + base_w; + activation_offset[p][q][col] = offset; + } + } + } + } + + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) { + // Set warp row and col start + lane_offset_ = MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()}); + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + void advance(int32_t pointer_offset) { + ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN); + iterator_s_ = 0; + iterator_r_ = 0; + iterator_offset_ = 0; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &operator++() { + ++iterator_s_; + if (iterator_s_ < FilterShape::kColumn) { + iterator_offset_ += inc_next_s_; + + return *this; + } + + iterator_s_ = 0; + + ++iterator_r_; + if (iterator_r_ < FilterShape::kRow) { + iterator_offset_ += inc_next_r_; + return *this; + } + + iterator_r_ = 0; + iterator_offset_ = 0; + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator & operator--() { + // Do nothing + return *this; + } + 
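  // Editorial note (not part of the patch): operator++ above walks the filter taps with s
  // (filter column) as the fastest-moving index and r (filter row) next, accumulating the
  // precomputed inc_next_s_ / inc_next_r_ deltas into iterator_offset_; load() below then
  // reads one Policy::LaneMmaShape::kN-wide vector per output pixel owned by this thread,
  // at shared-memory offset activation_offset[p][q][n] + iterator_offset_.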
+ /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + Array *dst_ptr = + reinterpret_cast *>(&frag); + + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < ThreadOutputShape::kH; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int q = 0; q < ThreadOutputShape::kW; ++q) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Iterations::kColumn; ++n) { + void const *ptr = ref_.data() + + ref_.offset({activation_offset[p][q][n] + (iterator_offset_), + n * Policy::WarpShape::kColumn}) + + pointer_offset / Policy::LaneMmaShape::kN; + arch::shared_load(dst_ptr[n + q + p * ThreadOutputShape::kW], ptr); + } + } + } + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + /// Stores a fragment to memory at the location pointed to by the iterator + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { + // Do nothing at present. + } + + /// Stores a fragment to memory at the location pointed to by the iterator + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, Index pointer_offset) const { + store_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + // no operation here + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/// Specialization for A operands of row-major layouts +/// +/// Concept: MutableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Size of filter (concept: gemm::GemmShape) + typename FilterShape_, + /// Size of the matrix to load (concept: TensorNHWC) + typename ThreadOutputShape_, + /// Size of the matrix to load (concept: TensorNHWC) + typename ThreadBlockOutputShape_, + /// Data type of A elements + typename Element_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Stride ( MatrixShape ) + typename StrideShape_, + /// Dilation ( MatrixShape ) + typename DilationShape_, + /// Activation Shape loaded by threadblock + typename ActivationShape_, + /// Number of partitions along K dimension - used in sliced-K + int PartitionsK, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize> +class DepthwiseDirect2dConvSimtTileIterator { + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of filter (concept: gemm::GemmShape) + using FilterShape = FilterShape_; + + /// Shape of tile to load (concept: TensorNHWC) + using ThreadOutputShape = ThreadOutputShape_; + + /// Shape of tile to load (concept: TensorNHWC) + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + + /// Stride ( MatrixShape ) + using StrideShape = StrideShape_; + + /// Dilation ( MatrixShape ) + using DilationShape = DilationShape_; + + /// Activation Shape loaded by threadblock + using ActivationShape = ActivationShape_; + + /// Operand tag + static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA; + + /// Element type + using Element = Element_; + + /// Layout of policy + using Layout = layout::RowMajor; + + /// Decomposition of elements among threads + using Policy = Policy_; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = 
typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + // + // Derived quantities + // + + static_assert(!(Shape::kRow % Policy::WarpShape::kRow), + "The warp-level GEMM M size must be divisible by the number of threads arranged " + "along the M dimension."); + + static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); + static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); + static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); + static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, + "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); + + // Activations loaded by threadblock + static int const ThreadActivationShapeH = (ThreadOutputShape::kH - 1) * StrideShape::kRow + + (FilterShape::kRow - 1) * DilationShape::kRow + 1; + + static int const ThreadActivationShapeW = (ThreadOutputShape::kW - 1) * StrideShape::kColumn + + (FilterShape::kColumn - 1) * DilationShape::kColumn + 1; + + using ThreadActivationShape = cutlass::conv:: + TensorNHWCShape<1, ThreadActivationShapeH, ThreadActivationShapeW, ThreadOutputShape::kC>; + + // Thread-level shape of a fragment + using ThreadShape = + MatrixShape; + + static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), + "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); + + /// Number of individual loads + using Iterations = + MatrixShape; + + using ThreadTileCount = MatrixShape; + + /// Fragment object holding a thread's part of a tile + using Fragment = Array; + + protected: + /// Internal reference + cutlass::TensorRef, layout::RowMajor> ref_; + + Array + activation[ThreadActivationShape::kH][ThreadActivationShape::kW][Iterations::kColumn]; + int iterator_r_; + int iterator_s_; + + + MatrixCoord lane_offset_; + + public: + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator() {} + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator(TensorRef ref, int lane_id) { + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + // Set channel offset + lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); + + ref.add_coord_offset(lane_offset_); + + ref_.reset(reinterpret_cast *>(ref.data()), + ref.stride(0) / Policy::LaneMmaShape::kN); + + iterator_r_ = 0; + iterator_s_ = 0; + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) { + ref_.add_pointer_offset(offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. 
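  /// (Editorial note, not part of the patch: the method that follows is setup_initial_status(),
  /// not a fragment load. For the kFixedStrideDilation specialization it prefetches the whole
  /// per-thread activation footprint, ThreadActivationShape::kH x ThreadActivationShape::kW
  /// vectors per column iteration, from shared memory into the 'activation' register array,
  /// so subsequent load() calls only re-index registers.)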
+ template + CUTLASS_HOST_DEVICE void setup_initial_status( + Params const ¶ms) { + + // Get base HW offset of current threads + int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); + int base_h = + (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH * StrideShape::kRow; + int base_w = + (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW * StrideShape::kColumn; + + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < ThreadActivationShape::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < ThreadActivationShape::kW; ++w) { + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < Iterations::kColumn; ++col) { + int offset = (base_h + h) * ActivationShape::kW + (base_w + w); + + void const *ptr = ref_.data() + ref_.offset({offset, col * Policy::WarpShape::kColumn}); + arch::shared_load(activation[h][w][col], ptr); + } + } + } + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) { + // Set warp row and col start + lane_offset_ = + MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()}); + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + void advance(int32_t pointer_offset) { + ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN); + iterator_s_ = 0; + iterator_r_ = 0; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &operator++() { + ++iterator_s_; + if (iterator_s_ < FilterShape::kColumn) { + return *this; + } + + iterator_s_ = 0; + + ++iterator_r_; + if (iterator_r_ < FilterShape::kRow) { + return *this; + } + + iterator_r_ = 0; + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &operator--() { + // Do nothing + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + Array *dst_ptr = + reinterpret_cast *>(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < ThreadOutputShape::kH; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int q = 0; q < ThreadOutputShape::kW; ++q) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Iterations::kColumn; ++n) { + const int h = p * StrideShape::kRow + iterator_r_ * DilationShape::kRow; + const int w = q * StrideShape::kColumn + iterator_s_ * DilationShape::kColumn; + + dst_ptr[n + q + p * ThreadOutputShape::kW] = activation[h][w][n]; + } + } + } + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); } + + /// Stores a fragment to memory at the location pointed to by the iterator + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { + // Do nothing at present. 
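A small host-side loop, with a 3x3 filter, unit stride and dilation, and the output position (p, q) = (1, 0) assumed only for illustration, reproduces the traversal implemented by operator++ and load_with_pointer_offset above: the filter column index s advances fastest, then the row index r, and each tap reads activation[p * stride + r * dilation][q * stride + s * dilation]:

#include <cstdio>

int main() {
  int const kFilterR = 3, kFilterS = 3;   // assumed 3x3 depthwise filter
  int const kStride = 1, kDilation = 1;   // assumed unit stride and dilation
  int const p = 1, q = 0;                 // assumed per-thread output position

  for (int r = 0; r < kFilterR; ++r) {    // operator++ wraps s before advancing r
    for (int s = 0; s < kFilterS; ++s) {
      int h = p * kStride + r * kDilation;
      int w = q * kStride + s * kDilation;
      std::printf("tap (r=%d, s=%d) reads activation[%d][%d]\n", r, s, h, w);
    }
  }
  return 0;
}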
+ } + + /// Stores a fragment to memory at the location pointed to by the iterator + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, Index pointer_offset) const { + store_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + // no operation here + } +}; + } // namespace warp -} // namespace gemm +} // namespace conv } // namespace cutlass diff --git a/include/cutlass/coord.h b/include/cutlass/coord.h index 1fe8ec0f..e5c0b9dd 100644 --- a/include/cutlass/coord.h +++ b/include/cutlass/coord.h @@ -100,12 +100,21 @@ public: } } + /// Constructs from some other Coord + template + CUTLASS_HOST_DEVICE + Coord(Coord other) { + for (int i = 0; i < kRank; ++i) { + idx[i] = other[i]; + } + } + /// Returns a slice of the Coord which may be larger or smaller in rank /// than this. template CUTLASS_HOST_DEVICE - Coord slice(int start = 0, Index identity = 0) const { - Coord result; + Coord slice(int start = 0, Index identity = 0) const { + Coord result; for (int i = 0; i < Slice; ++i) { if (i + start < kRank) { result[i] = idx[i + start]; diff --git a/include/cutlass/core_io.h b/include/cutlass/core_io.h index eef43602..d61c9d05 100644 --- a/include/cutlass/core_io.h +++ b/include/cutlass/core_io.h @@ -59,7 +59,9 @@ inline std::ostream &operator<<(std::ostream &out, dim3 d) { /// Output operator for CUDA built-in error type inline std::ostream &operator<<(std::ostream &out, cudaError_t error) { +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) return out << cudaGetErrorString(error); +#endif } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -252,8 +254,9 @@ namespace conv { inline std::ostream& operator<<(std::ostream& out, Conv2dProblemSize const& problem) { out << "NHWC: (" << problem.N << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl - << "KRSC: (" << problem.K << ", " << problem.R << ", " << problem.S << ", " << problem.C << ")" << std::endl + << "KRSC: (" << problem.K << ", " << problem.R << ", " << problem.S << ", " << problem.C / problem.groups << ")" << std::endl << "NPQK: (" << problem.N << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl + << "groups: (" << problem.groups << ")" << std::endl << "Pad_h, Pad_w: (" << problem.pad_h << ", " << problem.pad_w << ")" << std::endl << "Stride_h, Stride_w: (" << problem.stride_h << ", " << problem.stride_w << ")" << std::endl << "Dilation_h, Dilation_w: (" << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl diff --git a/include/cutlass/device_kernel.h b/include/cutlass/device_kernel.h index 1de33024..7a7ebf8a 100644 --- a/include/cutlass/device_kernel.h +++ b/include/cutlass/device_kernel.h @@ -57,6 +57,23 @@ void Kernel(typename Operator::Params params) { op(params, *shared_storage); } + +/// Generic CUTLASS kernel template. +template +__global__ +void Kernel2(typename Operator::Params params) { + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + + // Declare pointer to dynamic shared memory. 
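Because Kernel2 sources its SharedStorage from dynamically sized shared memory, a host-side launcher has to opt in to large dynamic allocations and pass the byte count at launch time. A hedged sketch of such a launcher; launch_kernel2, GemmOperator, grid, block and stream are placeholder names, not part of this patch:

template <typename GemmOperator>
cudaError_t launch_kernel2(typename GemmOperator::Params const &params,
                           dim3 grid, dim3 block, cudaStream_t stream) {
  int smem_size = int(sizeof(typename GemmOperator::SharedStorage));

  // Dynamic shared memory beyond 48 KB requires an explicit opt-in.
  if (smem_size >= (48 << 10)) {
    cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel2<GemmOperator>,
                                              cudaFuncAttributeMaxDynamicSharedMemorySize,
                                              smem_size);
    if (result != cudaSuccess) {
      return result;
    }
  }

  cutlass::Kernel2<GemmOperator><<<grid, block, smem_size, stream>>>(params);
  return cudaGetLastError();
}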
+ typename Operator::SharedStorage *shared_storage = + reinterpret_cast(SharedStorageBase); + + Operator::invoke(params, *shared_storage); + +} + + //////////////////////////////////////////////////////////////////////////////// } /// namespace cutlass diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index d694ea8f..4abbc34d 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -104,15 +104,15 @@ struct Identity { template struct Identity > { CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs) const { - return rhs; + Array operator()(Array const &value) const { + return value; } using Params = LinearCombinationGenericParams; CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs, Params const ¶ms_) const { - return this->operator()(rhs); + Array operator()(Array const &value, Params const ¶ms_) const { + return this->operator()(value); } }; @@ -183,7 +183,7 @@ struct LeakyReLU { Params(): LinearCombinationGenericParams(), leaky_alpha(T(1)) {} - + CUTLASS_HOST_DEVICE Params( T alpha, @@ -228,21 +228,21 @@ struct LeakyReLU > { CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs, T const & alpha_recip) const { + Array operator()(Array const &value, T const & alpha_recip) const { Array y; LeakyReLU leaky_op; CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < int(rhs.size()); ++i) { - y[i] = leaky_op(rhs[i], alpha_recip); + for (int i = 0; i < int(value.size()); ++i) { + y[i] = leaky_op(value[i], alpha_recip); } return y; } CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs, Params const ¶ms_) const { - return this->operator()(rhs, params_.leaky_alpha); + Array operator()(Array const &value, Params const ¶ms_) const { + return this->operator()(value, params_.leaky_alpha); } }; @@ -265,13 +265,13 @@ struct Tanh { template struct Tanh > { CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs) const { + Array operator()(Array const &value) const { Array y; Tanh tanh_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - y[i] = tanh_op(rhs[i]); + y[i] = tanh_op(value[i]); } return y; @@ -280,8 +280,8 @@ struct Tanh > { using Params = LinearCombinationGenericParams; CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs, Params const ¶ms_) const { - return this->operator()(rhs); + Array operator()(Array const &value, Params const ¶ms_) const { + return this->operator()(value); } }; @@ -299,8 +299,8 @@ struct Tanh> { using Params = LinearCombinationGenericParams; CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs, Params const ¶ms_) const { - return this->operator()(rhs); + Array operator()(Array const &value, Params const ¶ms_) const { + return this->operator()(value); } }; @@ -323,13 +323,13 @@ struct Sigmoid { template struct Sigmoid > { CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs) const { + Array operator()(Array const &value) const { Array y; Sigmoid sigmoid_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - y[i] = sigmoid_op(rhs[i]); + y[i] = sigmoid_op(value[i]); } return y; @@ -338,8 +338,8 @@ struct Sigmoid > { using Params = LinearCombinationGenericParams; CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs, Params const ¶ms_) const { - return this->operator()(rhs); + Array operator()(Array const &value, Params const ¶ms_) const { + return this->operator()(value); } }; @@ -398,17 +398,17 @@ struct SiLu { template struct SiLu> { CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs) const { + Array operator()(Array const &value) const { 
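A minimal usage sketch of the Array specializations whose arguments are renamed in these hunks, assuming a fragment of eight floats chosen only for illustration; the functor applies the scalar activation lane by lane:

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/epilogue/thread/activation.h"

CUTLASS_HOST_DEVICE
cutlass::Array<float, 8> apply_sigmoid(cutlass::Array<float, 8> const &frag) {
  cutlass::epilogue::thread::Sigmoid<cutlass::Array<float, 8>> sigmoid_op;
  return sigmoid_op(frag);  // elementwise sigmoid over the whole fragment
}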
Sigmoid> sigmoid_op; multiplies> mul; - return mul(rhs, sigmoid_op(rhs)); + return mul(value, sigmoid_op(value)); } using Params = LinearCombinationGenericParams; CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs, Params const ¶ms_) const { - return this->operator()(rhs); + Array operator()(Array const &value, Params const ¶ms_) const { + return this->operator()(value); } }; @@ -458,13 +458,13 @@ struct HardSwish { template struct HardSwish > { CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs) const { + Array operator()(Array const &value) const { Array y; HardSwish hardswish_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - y[i] = hardswish_op(rhs[i]); + y[i] = hardswish_op(value[i]); } return y; @@ -483,13 +483,13 @@ struct HardSwish > { using T = half_t; CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs) const { + Array operator()(Array const &value) const { minimum > mn; maximum > mx; multiplies > mul; plus > add; - return mul(mul(mn(mx(add(rhs, T(3)), T(0)), T(6)), rhs), T(0.16666667f)); + return mul(mul(mn(mx(add(value, T(3)), T(0)), T(6)), value), T(0.16666667f)); } using Params = LinearCombinationGenericParams; @@ -561,13 +561,13 @@ struct GELU { template struct GELU > { CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs) const { + Array operator()(Array const &value) const { Array y; GELU gelu_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - y[i] = gelu_op(rhs[i]); + y[i] = gelu_op(value[i]); } return y; @@ -576,8 +576,8 @@ struct GELU > { using Params = LinearCombinationGenericParams; CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs, Params const ¶ms_) const { - return this->operator()(rhs); + Array operator()(Array const &value, Params const ¶ms_) const { + return this->operator()(value); } }; @@ -601,7 +601,6 @@ struct GELU_taylor { T operator()(T const &scalar, Params const ¶ms_) const { return this->operator()(scalar); } - }; template @@ -632,8 +631,8 @@ struct GELU_taylor > { using Params = LinearCombinationGenericParams; CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs, Params const ¶ms_) const { - return this->operator()(rhs); + Array operator()(Array const &value, Params const ¶ms_) const { + return this->operator()(value); } }; @@ -641,13 +640,13 @@ template struct GELU_taylor > { static const bool kIsHeavy=true; CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs) const { + Array operator()(Array const &value) const { Array y; GELU_taylor gelu_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - y[i] = gelu_op(rhs[i]); + y[i] = gelu_op(value[i]); } return y; @@ -656,8 +655,8 @@ struct GELU_taylor > { using Params = LinearCombinationGenericParams; CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs, Params const ¶ms_) const { - return this->operator()(rhs); + Array operator()(Array const &value, Params const ¶ms_) const { + return this->operator()(value); } }; diff --git a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h index d85384f5..86408536 100644 --- a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h +++ b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h @@ -78,6 +78,9 @@ public: using ElementwiseOp = ElementwiseOp_; using BinaryOp = BinaryOp_; + // Indicates that this epilogue applies only one binary operation + static bool const kIsSingleSource = true; + using FragmentAccumulator = Array; using FragmentCompute = Array; using FragmentC = Array; diff --git 
a/include/cutlass/epilogue/thread/linear_combination_bias_relu.h b/include/cutlass/epilogue/thread/linear_combination_bias_relu.h index 63168330..64c3233a 100644 --- a/include/cutlass/epilogue/thread/linear_combination_bias_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_bias_relu.h @@ -223,6 +223,9 @@ public: using ElementwiseOp = ReLu; using BinaryOp = plus; + // Indicates that this epilogue applies only one binary operation + static bool const kIsSingleSource = true; + using FragmentAccumulator = Array; using FragmentCompute = Array; using FragmentC = Array; diff --git a/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h b/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h index 111fc53e..f79511f3 100644 --- a/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_leaky_relu.h @@ -37,6 +37,7 @@ #include "cutlass/functional.h" #include "cutlass/numeric_conversion.h" #include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/scale_type.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/linear_combination_residual_block.h b/include/cutlass/epilogue/thread/linear_combination_residual_block.h index c32ab7c4..6bb560bd 100644 --- a/include/cutlass/epilogue/thread/linear_combination_residual_block.h +++ b/include/cutlass/epilogue/thread/linear_combination_residual_block.h @@ -30,7 +30,7 @@ **************************************************************************************************/ /*! \file - \brief Epilogue functor specialized for residual blocks in deep neural network. + \brief Epilogue functor specialized for residual blocks in deep neural networks. 
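The kIsSingleSource flags added in the hunks above let kernel-level code ask, at compile time, whether an epilogue consumes one residual tensor or two; the residual-block functors that follow set it accordingly. A hedged illustration of the kind of compile-time query the trait enables; residual_operand_count and EpilogueOp are placeholder names, not CUTLASS code:

// Number of residual/source tensors an epilogue functor consumes,
// derived from its kIsSingleSource trait (illustrative sketch only).
template <typename EpilogueOp>
constexpr int residual_operand_count() {
  return EpilogueOp::kIsSingleSource ? 1 : 2;
}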
*/ #pragma once @@ -45,14 +45,24 @@ namespace cutlass { namespace epilogue { namespace thread { -// /// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual)) +namespace detail { + +/// Dummy class used to designate that the second binary operator in the epilogue is unsued +template +class NoOp {}; + +} + +/// Models a residual block of the form: UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2)) template class ActivationOp_, - template class BinaryOp_, - template class UnaryOp_> + template class BinaryOp1_, + template class UnaryOp_, + template class BinaryOp2_ = detail::NoOp> class LinearCombinationResidualBlock { public: + static bool const kIsSingleSource = false; using ElementOutput = ElementC_; using ElementC = ElementC_; @@ -62,7 +72,130 @@ public: static int const kCount = kElementsPerAccess; using UnaryOp = UnaryOp_>; - using BinaryOp = BinaryOp_>; + using BinaryOp1 = BinaryOp1_>; + using BinaryOp2 = BinaryOp2_>; + using ActivationOp = ActivationOp_>; + + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentC = Array; + using FragmentOutput = Array; + + using ElementZ = ElementOutput_; + using ElementT = ElementZ; + using FragmentZ = Array; + using FragmentT = Array; + + static bool const kIsHeavy = true; + static bool const kStoreZ = true; + static bool const kStoreT = false; + + /// Host-constructable parameters structure + struct Params { + + ElementCompute alpha; ///< scales accumulators + ElementCompute beta; ///< scales residual input + ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory + ElementCompute const *beta_ptr{nullptr}; ///< pointer to residual scalar - if not null, loads it from memory + + CUTLASS_HOST_DEVICE + Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute alpha, ElementCompute beta) + : alpha(alpha), beta(beta) {} + + CUTLASS_HOST_DEVICE + Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr) + : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {} + }; + +private: + + ElementCompute alpha_; + ElementCompute beta_; + bool skip_elementwise_; + +public: + + /// Constructor from Params + CUTLASS_HOST_DEVICE + LinearCombinationResidualBlock(Params const ¶ms) { + alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); + beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); + skip_elementwise_ = false; + } + + /// The "source" tensor corresponds to the residual input + CUTLASS_HOST_DEVICE + bool is_source_needed() const { return true; } + + /// Functionally required for serial reduction in the epilogue + /// IMPORTANT: Split-k is supported only when ActivationOp is Identity. 
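Spelled out at scalar granularity, the composition this functor models is z = UnaryOp(BinaryOp2(BinaryOp1(ActivationOp(alpha * AB + bias), beta * residual1), beta * residual2)). A minimal sketch with ReLU, addition, addition and an identity unary op standing in for the template parameters, purely for illustration:

#include <algorithm>

float residual_block_two_sources(float accum, float bias,
                                 float residual1, float residual2,
                                 float alpha, float beta) {
  float activated = std::max(alpha * accum + bias, 0.0f);  // ActivationOp = ReLU
  float combined1 = activated + beta * residual1;          // BinaryOp1 = plus
  float combined2 = combined1 + beta * residual2;          // BinaryOp2 = plus
  return combined2;                                        // UnaryOp = identity
}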
+ CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + if (k_partition) { + beta_ = ElementCompute(1); + } + + if (k_partition != k_partition_count - 1) { + skip_elementwise_ = true; + } + } + + /// Applies the operation UnaryOp(BinaryOp(BinaryOp(ActivationOp(AB + bias), residual1), residual2)) + CUTLASS_HOST_DEVICE + void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB, + FragmentC const &residual1, FragmentC const &residual2, + FragmentCompute const &bias) const { + UnaryOp unary_op; + BinaryOp1 binary_op1; + BinaryOp2 binary_op2; + ActivationOp activation; + + FragmentCompute tmp_Accum = + NumericArrayConverter()(AB); + FragmentCompute tmp_residual1 = + NumericArrayConverter()(residual1); + FragmentCompute tmp_residual2 = + NumericArrayConverter()(residual2); + + FragmentCompute z = + binary_op2(binary_op1(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual1), beta_ * tmp_residual2); + FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z); + + NumericArrayConverter convert_z; + frag_Z = convert_z(result_Z); + } + + /// Should never be called + CUTLASS_HOST_DEVICE + void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &, + FragmentCompute const &) const {} +}; + +/// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual)) +template class ActivationOp_, + template class BinaryOp1_, + template class UnaryOp_> +class LinearCombinationResidualBlock { +public: + static bool const kIsSingleSource = true; + + using ElementOutput = ElementC_; + using ElementC = ElementC_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + static int const kElementsPerAccess = ElementsPerAccess; + static int const kCount = kElementsPerAccess; + + using UnaryOp = UnaryOp_>; + using BinaryOp = BinaryOp1_>; using ActivationOp = ActivationOp_>; using FragmentAccumulator = Array; diff --git a/include/cutlass/epilogue/thread/linear_combination_residual_block_v2.h b/include/cutlass/epilogue/thread/linear_combination_residual_block_v2.h deleted file mode 100644 index 72013ab9..00000000 --- a/include/cutlass/epilogue/thread/linear_combination_residual_block_v2.h +++ /dev/null @@ -1,197 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Epilogue functor specialized for residual blocks in deep neural network. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/functional.h" -#include "cutlass/numeric_conversion.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace thread { - -// /// Models a residual block of the form: UnaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual)) -// or form UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2)) -template class ActivationOp_, - template class BinaryOp1_, - template class UnaryOp_, - template class BinaryOp2_=BinaryOp1_> -class LinearCombinationResidualBlockV2 { -public: - - using ElementOutput = ElementC_; - using ElementC = ElementC_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - static int const kElementsPerAccess = ElementsPerAccess; - static int const kCount = kElementsPerAccess; - - using UnaryOp = UnaryOp_>; - using BinaryOp1 = BinaryOp1_>; - using BinaryOp2 = BinaryOp2_>; - using ActivationOp = ActivationOp_>; - - using FragmentAccumulator = Array; - using FragmentCompute = Array; - using FragmentC = Array; - using FragmentOutput = Array; - - using ElementZ = ElementOutput_; - using ElementT = ElementZ; - using FragmentZ = Array; - using FragmentT = Array; - - static bool const kIsHeavy = true; - static bool const kStoreZ = true; - static bool const kStoreT = false; - - /// Host-constructable parameters structure - struct Params { - - ElementCompute alpha; ///< scales accumulators - ElementCompute beta; ///< scales residual input - ElementCompute const *alpha_ptr{nullptr}; ///< pointer to accumulator scalar - if not null, loads it from memory - ElementCompute const *beta_ptr{nullptr}; ///< pointer to residual scalar - if not null, loads it from memory - - CUTLASS_HOST_DEVICE - Params() : alpha(ElementCompute(1)), beta(ElementCompute(1)) {} - - CUTLASS_HOST_DEVICE - Params(ElementCompute alpha, ElementCompute beta) - : alpha(alpha), beta(beta) {} - - CUTLASS_HOST_DEVICE - Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr) - : alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {} - }; - -private: - - ElementCompute alpha_; - ElementCompute beta_; - bool skip_elementwise_; - -public: - - /// Constructor from Params - CUTLASS_HOST_DEVICE - LinearCombinationResidualBlockV2(Params const ¶ms) { - alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); - beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); - skip_elementwise_ = false; - } - - /// The "source" tensor corresponds to the residual input - CUTLASS_HOST_DEVICE - bool is_source_needed() const { return true; } - - /// Functionally required for serial reduction in the epilogue - /// IMPORTANT: Split-k is supported only when ActivationOp is Identity. - CUTLASS_HOST_DEVICE - void set_k_partition(int k_partition, int k_partition_count) { - if (k_partition) { - beta_ = ElementCompute(1); - } - - if (k_partition != k_partition_count - 1) { - skip_elementwise_ = true; - } - } - - /// Applies the operation UnaryOp(BinaryOp(ActivationOp(AB + bias), residual)) - CUTLASS_HOST_DEVICE - void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB, - FragmentC const &residual, - FragmentCompute const &bias) const { - UnaryOp unary_op; - BinaryOp1 binary_op; - ActivationOp activation; - - FragmentCompute tmp_Accum = - NumericArrayConverter()(AB); - FragmentCompute tmp_residual = - NumericArrayConverter()(residual); - - FragmentCompute z = - binary_op(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual); - FragmentCompute result_Z = skip_elementwise_ ? z : unary_op(z); - - NumericArrayConverter convert_z; - frag_Z = convert_z(result_Z); - } - - /// Applies the operation UnaryOp(BinaryOp(BinaryOp(ActivationOp(AB + bias), residual1), residual2)) - CUTLASS_HOST_DEVICE - void operator()(FragmentOutput &frag_Z, FragmentOutput &, FragmentAccumulator const &AB, - FragmentC const &residual1, FragmentC const &residual2, - FragmentCompute const &bias) const { - UnaryOp unary_op; - BinaryOp1 binary_op1; - BinaryOp2 binary_op2; - ActivationOp activation; - - FragmentCompute tmp_Accum = - NumericArrayConverter()(AB); - FragmentCompute tmp_residual1 = - NumericArrayConverter()(residual1); - FragmentCompute tmp_residual2 = - NumericArrayConverter()(residual2); - - FragmentCompute z = - binary_op2(binary_op1(activation(alpha_ * tmp_Accum + bias), beta_ * tmp_residual1), beta_ * tmp_residual2); - FragmentCompute result_Z = skip_elementwise_ ? 
z : unary_op(z); - - NumericArrayConverter convert_z; - frag_Z = convert_z(result_Z); - } - - /// Should never be called - CUTLASS_HOST_DEVICE - void operator()(FragmentOutput &, FragmentOutput &, FragmentAccumulator const &, - FragmentCompute const &) const {} -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace thread -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_simt.h b/include/cutlass/epilogue/threadblock/default_epilogue_simt.h index cfaaa8b4..75f6c7f1 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_simt.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_simt.h @@ -58,12 +58,16 @@ #include "cutlass/epilogue/warp/fragment_iterator_simt.h" #include "cutlass/epilogue/warp/tile_iterator_simt.h" #include "cutlass/epilogue/threadblock/default_thread_map_simt.h" +#include "cutlass/transform/pitch_linear_thread_map.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h" #include "cutlass/epilogue/threadblock/shared_load_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h" #include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/threadblock/epilogue_depthwise.h" #include "cutlass/layout/permute.h" @@ -314,6 +318,100 @@ struct DefaultEpilogueSimtAffineRankN { }; ///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for SimtOps. 
+template , + typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> > +struct DefaultDirectConvEpilogueSimt { + using Shape = Shape_; + using WarpMmaSimt = WarpMmaSimt_; + using WarpShape = typename WarpMmaSimt::Shape; + using OutputOp = OutputOp_; + using ThreadOutputShape = ThreadOutputShape_; + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + static int const kElementsPerAccess = ElementsPerAccess_; + + + using ElementOutput = typename OutputOp::ElementOutput; + using LayoutC = typename WarpMmaSimt::LayoutC; + using ElementAccumulator = typename WarpMmaSimt::ElementC; + + /// Number of threads total + using WarpCount = gemm::GemmShape< + Shape::kM / WarpShape::kM, + Shape::kN / WarpShape::kN + >; + + static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; + + static int const kThreads = WarpCount::kCount * kWarpSize; + + // + // Thread map + // + + using OutputTileThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccess + >; + + + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorDirectConv< + OutputTileThreadMap, + ElementOutput, + ThreadOutputShape, + ThreadBlockOutputShape + >; + + using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< + typename WarpMmaSimt::Shape, + typename WarpMmaSimt::ThreadMma, + layout::RowMajor, + typename WarpMmaSimt::Policy + >; + + using WarpTileIterator = cutlass::epilogue::warp::TileIteratorSimtDirect2dConv< + typename WarpMmaSimt::Shape, + ThreadOutputShape, + ThreadBlockOutputShape, + typename WarpMmaSimt::ThreadMma, + ElementAccumulator, + layout::RowMajor, + typename WarpMmaSimt::Policy + >; + + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorPitchLiner< + OutputTileThreadMap, + ElementAccumulator + >; + + /// Hard-coded padding elements added + using Padding = typename WarpTileIterator::Padding; + // + // Define the epilogue + // + using Epilogue = cutlass::epilogue::threadblock::EpilogueDepthwise< + Shape, + ThreadOutputShape, + ThreadBlockOutputShape, + WarpMmaSimt, + OutputTileIterator, + AccumulatorFragmentIterator, + WarpTileIterator, + SharedLoadIterator, + OutputOp, + Padding + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace threadblock } // namespace epilogue } // namespace cutlass diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h index c232a2db..bf741f6e 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h @@ -293,6 +293,141 @@ struct DefaultIteratorsTensorOp< static int const kFragmentsPerIteration = 1; }; + +/// Partial specialization for float_e4m3_t <= float x 16/8 epilogues avoids shared memory bank conflicts. +/// Threadblock::kN = 256 still has bank conflicts. 
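The WarpCount and kThreads definitions in DefaultDirectConvEpilogueSimt above follow directly from how the threadblock tile is carved into warp tiles. A small constexpr sketch with an assumed 64x64 threadblock tile and 32x32 warp tile, chosen only to make the arithmetic concrete:

// Warp and thread counts implied by the defaults above, for assumed tile sizes.
constexpr int kThreadblockM = 64, kThreadblockN = 64;
constexpr int kWarpM = 32, kWarpN = 32;
constexpr int kWarpSize = 32;

constexpr int kWarpCountM = kThreadblockM / kWarpM;               // 2
constexpr int kWarpCountN = kThreadblockN / kWarpN;               // 2
constexpr int kThreads = kWarpCountM * kWarpCountN * kWarpSize;   // 4 warps

static_assert(kThreads == 128, "2 x 2 warps of 32 threads each");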
+template < + int ElementsPerAccess, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename ThreadMap +> +struct DefaultIteratorsTensorOp< + cutlass::float_e4m3_t, + float, + ElementsPerAccess, + ThreadblockShape, + WarpShape, + InstructionShape, + ThreadMap> { + + using ElementOutput = cutlass::float_e4m3_t; + + static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8), + "ElementsPerAccess needs to be 16 or 8."); + + using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed< + WarpShape, + InstructionShape, + float, + 32, + cutlass::sizeof_bits::value, + ElementsPerAccess, + 8 + >; + + using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape, + InstructionShape, + float, + layout::RowMajor + >; + + using WarpTileIterator = typename platform::conditional< + (ThreadblockShape::kN == 256), + WarpTileIteratorNotMixed, + WarpTileIteratorMixed>::type; + + using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< + ThreadMap, + float, + 32, + cutlass::sizeof_bits::value, + ElementsPerAccess, + 8 + >; + + using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator< + ThreadMap, + float + >; + + using SharedLoadIterator = typename platform::conditional< + (ThreadblockShape::kN == 256), + SharedLoadIteratorNotMixed, + SharedLoadIteratorMixed>::type; + + static int const kFragmentsPerIteration = 1; +}; + +/// Partial specialization for float_e5m2_t <= float x 16/8 epilogues avoids shared memory bank conflicts. +/// Threadblock::kN = 256 still has bank conflicts. +template < + int ElementsPerAccess, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename ThreadMap +> +struct DefaultIteratorsTensorOp< + cutlass::float_e5m2_t, + float, + ElementsPerAccess, + ThreadblockShape, + WarpShape, + InstructionShape, + ThreadMap> { + + using ElementOutput = cutlass::float_e5m2_t; + + static_assert((ElementsPerAccess == 16 || ElementsPerAccess == 8), + "ElementsPerAccess needs to be 16 or 8."); + + using WarpTileIteratorMixed = cutlass::epilogue::warp::TileIteratorTensorOpMixed< + WarpShape, + InstructionShape, + float, + 32, + cutlass::sizeof_bits::value, + ElementsPerAccess, + 8 + >; + + using WarpTileIteratorNotMixed = cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape, + InstructionShape, + float, + layout::RowMajor + >; + + using WarpTileIterator = typename platform::conditional< + (ThreadblockShape::kN == 256), + WarpTileIteratorNotMixed, + WarpTileIteratorMixed>::type; + + using SharedLoadIteratorMixed = cutlass::epilogue::threadblock::SharedLoadIteratorMixed< + ThreadMap, + float, + 32, + cutlass::sizeof_bits::value, + ElementsPerAccess, + 8 + >; + + using SharedLoadIteratorNotMixed = cutlass::epilogue::threadblock::SharedLoadIterator< + ThreadMap, + float + >; + + using SharedLoadIterator = typename platform::conditional< + (ThreadblockShape::kN == 256), + SharedLoadIteratorNotMixed, + SharedLoadIteratorMixed>::type; + + static int const kFragmentsPerIteration = 1; +}; + } // namespace detail //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast_v2.h b/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast_v2.h deleted file mode 100644 index a99c5167..00000000 --- a/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast_v2.h +++ /dev/null @@ -1,177 +0,0 @@ 
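The two specializations above target the FP8 output formats introduced in this release: float_e4m3_t keeps one sign, four exponent and three mantissa bits, while float_e5m2_t trades a mantissa bit for exponent range, so round-tripping through either type is lossy. A small host-side sketch, assuming the FP8 types are reachable through cutlass/numeric_types.h like the other narrow numeric types and convert to and from float the same way:

#include <cstdio>
#include "cutlass/numeric_types.h"

int main() {
  cutlass::float_e4m3_t a(3.14159f);   // 4 exponent bits, 3 mantissa bits
  cutlass::float_e5m2_t b(3.14159f);   // 5 exponent bits, 2 mantissa bits

  // Both conversions round coarsely; e5m2 is coarser but covers a wider range.
  std::printf("e4m3: %f  e5m2: %f\n", double(float(a)), double(float(b)));
  return 0;
}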
-/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - The epilogue rearranges the result of a matrix product through shared memory to match canonical - tensor layouts in global memory. Epilogues support conversion and reduction operations. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/array.h" - -#include "cutlass/gemm/gemm.h" - -#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" -#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" -#include "cutlass/epilogue/threadblock/epilogue.h" -#include "cutlass/epilogue/threadblock/epilogue_with_broadcast_v2.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Defines sensible defaults for epilogues for TensorOps. 
-template < - typename Shape, - typename WarpMmaTensorOp, - int PartitionsK, - typename ElementOutput, - typename ElementTensor, - typename ElementVector, - typename OutputOp, - int ElementsPerAccess, - bool ScatterD = false -> -struct DefaultEpilogueWithBroadcastTensorOpV2 { - - /// Use defaults related to the existing epilogue - using Base = DefaultEpilogueTensorOp< - Shape, - WarpMmaTensorOp, - PartitionsK, - OutputOp, - ElementsPerAccess - >; - - // - // Stores the result z = (y = GEMM(A, B, C), broadcast) - // - using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2< - typename Base::OutputTileThreadMap, - ElementOutput, - ScatterD - >; - - // - // Additional tensor tile iterator - stores t = Elementwise(z) - // - using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2< - typename Base::OutputTileThreadMap, - ElementTensor - >; - - /// Define the epilogue - using Epilogue = EpilogueWithBroadcastV2< - Shape, - WarpMmaTensorOp, - PartitionsK, - OutputTileIterator, - TensorTileIterator, - ElementVector, - typename Base::AccumulatorFragmentIterator, - typename Base::WarpTileIterator, - typename Base::SharedLoadIterator, - OutputOp, - typename Base::Padding, - Base::kFragmentsPerIteration - >; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Defines sensible defaults for epilogues for VoltaTensorOps. -template < - typename Shape, - typename WarpMmaTensorOp, - int PartitionsK, - typename ElementOutput, - typename ElementTensor, - typename ElementVector, - typename OutputOp, - int ElementsPerAccess -> -struct DefaultEpilogueWithBroadcastVoltaTensorOpV2 { - - /// Use defaults related to the existing epilogue - using Base = DefaultEpilogueVoltaTensorOp< - Shape, - WarpMmaTensorOp, - PartitionsK, - OutputOp, - ElementsPerAccess - >; - - // - // Stores the result z = (y = GEMM(A, B, C), broadcast) - // - using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2< - typename Base::OutputTileThreadMap, - ElementOutput - >; - - // - // Additional tensor tile iterator - stores t = Elementwise(z) - // - using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorV2< - typename Base::OutputTileThreadMap, - ElementTensor - >; - - /// Define the epilogue - using Epilogue = EpilogueWithBroadcastV2< - Shape, - WarpMmaTensorOp, - PartitionsK, - OutputTileIterator, - TensorTileIterator, - ElementVector, - typename Base::AccumulatorFragmentIterator, - typename Base::WarpTileIterator, - typename Base::SharedLoadIterator, - OutputOp, - typename Base::Padding - >; -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/epilogue.h b/include/cutlass/epilogue/threadblock/epilogue.h index 9ee85637..d334702b 100644 --- a/include/cutlass/epilogue/threadblock/epilogue.h +++ b/include/cutlass/epilogue/threadblock/epilogue.h @@ -34,6 +34,7 @@ The epilogue rearranges the result of a matrix product through shared memory to match canonical tensor layouts in global memory. Epilogues support conversion and reduction operations. + The shared memory resource is time-sliced across warps. 
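The sentence added above summarizes the epilogue's core mechanism: accumulators are staged in a shared-memory tile between barriers so they can be read back under a different thread-to-data mapping before the global store, and the same tile is reused (time-sliced) across iterations. A toy CUDA kernel, unrelated to CUTLASS types and with an assumed 128-thread block, showing just that staging pattern:

__global__ void stage_through_smem(float const *in, float *out) {
  __shared__ float tile[128];
  int t = threadIdx.x;
  int base = blockIdx.x * 128;

  // Phase 1: each thread deposits its "accumulator" into the shared tile.
  tile[t] = in[base + t] * 2.0f;
  __syncthreads();

  // Phase 2: read the tile back under a different mapping (reversed here)
  // and stream the result to global memory.
  out[base + (127 - t)] = tile[127 - t];
}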
*/ #pragma once @@ -59,8 +60,9 @@ #include "cutlass/transform/threadblock/regular_tile_iterator.h" #include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -#include "cutlass/numeric_types.h" +#include "cutlass/util/index_sequence.h" //////////////////////////////////////////////////////////////////////////////// @@ -68,6 +70,7 @@ namespace cutlass { namespace epilogue { namespace threadblock { + //////////////////////////////////////////////////////////////////////////////// /// Epilogue operator @@ -85,27 +88,39 @@ template < int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large (!IsEpilogueFunctorHeavy::value) > -class Epilogue : +class Epilogue : public EpilogueBase< - Shape_, - typename WarpMmaOperator_::Shape, - PartitionsK, - AccumulatorFragmentIterator_, - WarpTileIterator_, + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, Padding_, - FragmentsPerPartition> { + FragmentsPerPartition>, + public EpilogueBaseStreamK< + Shape_, + PartitionsK, + WarpMmaOperator_, + AccumulatorFragmentIterator_> +{ public: using Base = EpilogueBase< - Shape_, - typename WarpMmaOperator_::Shape, - PartitionsK, - AccumulatorFragmentIterator_, - WarpTileIterator_, + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, Padding_, FragmentsPerPartition>; + using BaseStreamK = EpilogueBaseStreamK< + Shape_, + PartitionsK, + WarpMmaOperator_, + AccumulatorFragmentIterator_>; + using Shape = Shape_; using WarpMmaOperator = WarpMmaOperator_; static int const kPartitionsK = PartitionsK; @@ -115,15 +130,23 @@ public: using SharedLoadIterator = SharedLoadIterator_; using OutputOp = OutputOp_; using Padding = Padding_; - using Layout = layout::RowMajor; using LongIndex = typename Layout::LongIndex; - /// The complete warp-level accumulator tile + /// Number of warps per block + using WarpCount = typename Base::WarpCount; + + /// Number of threads per block + static int const kBlockThreads = 32 * WarpCount::kCount; + + /// Per-thread accumulator tile type using AccumulatorTile = typename Base::AccumulatorTile; - /// Accumulator element - using ElementAccumulator = typename WarpTileIterator::Element; + /// Numerical accumulation element type + using ElementAccumulator = typename WarpMmaOperator::ElementC; + + /// Fragment type used by the accumulator tile's fragment iterator + using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment; /// Output element using ElementOutput = typename OutputTileIterator::Element; @@ -140,21 +163,20 @@ public: /// Const tensor reference to source tensor using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; - /// Array type used to output + /// Vector type used by the global output iterator using OutputAccessType = Array< typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; - /// Array type used by output functor - using AccumulatorAccessType = Array; - - /// Number of warps - using WarpCount = typename Base::WarpCount; + /// Vector type used by the shared output iterator + using AccumulatorAccessType = Array; static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? 
Base::kFragmentsPerIteration : kPartitionsK; + static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; public: + static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, "Mismatch between shared load iterator and output tile iterator."); @@ -163,144 +185,177 @@ public: static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), "Divisibility"); + static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1."); + private: /// Loads fragment from shared memory aligned with output tensor SharedLoadIterator shared_load_iterator_; + /// Thread index in the threadblock + int thread_idx; + + /// Warp index in the threadblock + int warp_idx; + public: /// Constructor CUTLASS_DEVICE Epilogue( - typename Base::SharedStorage &shared_storage, ///< Shared storage object - int thread_idx, ///< ID of a thread within the threadblock - int warp_idx, ///< ID of warp within threadblock - int lane_idx ///< Id of thread within warp - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), - shared_load_iterator_(shared_storage.reference(), thread_idx) + typename Base::SharedStorage &shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx) ///< Id of thread within warp + : + Base(shared_storage, thread_idx, warp_idx, lane_idx), + BaseStreamK(thread_idx), + shared_load_iterator_(shared_storage.reference(), thread_idx), + thread_idx(thread_idx), + warp_idx(warp_idx) + {} + + + /// Aggregates the accumulator sets shared by peer blocks in the global workspace, + /// performing epilogue computations, writing to output + CUTLASS_DEVICE + void reduce( + int peer_idx_begin, + int peer_idx_end, + int reduce_fragment_idx, + ElementAccumulator *element_workspace, + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + OutputTileIterator source_iterator) ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) { - + // Redcuce peer accumulator fragments into one fragment + AccumulatorFragment accum_fragment; + BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace); + + // Store fragment to shared memory + this->warp_tile_iterator_.store(accum_fragment); + + __syncthreads(); + + // Initialize/load source-fragment data + typename OutputTileIterator::Fragment source_fragment; + source_fragment.clear(); + + if (output_op.is_source_needed()) + { + source_iterator += reduce_fragment_idx; + source_iterator.load(source_fragment); + } + + // Load fragment from shared memory + typename SharedLoadIterator::Fragment aligned_accum_fragment; + shared_load_iterator_.load(aligned_accum_fragment); + + // Add fragments shared by other k partitions + if (kPartitionsK > 1) + { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + typename SharedLoadIterator::Fragment aligned_addend_fragment; + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_addend_fragment); + aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_addend_fragment); + } + } + + // Compute the output result + typename OutputTileIterator::Fragment output_fragment; + + // Apply the output operator + apply_output_operator(output_fragment, output_op, 
aligned_accum_fragment, source_fragment); + + // Store the final result + destination_iterator += reduce_fragment_idx; + destination_iterator.store(output_fragment); } + /// Streams the result to global memory CUTLASS_DEVICE void operator()( - OutputOp const &output_op, ///< Output operator - OutputTileIterator destination_iterator, ///< Tile iterator for destination - AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile - OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) - - if (!output_op.is_source_needed()) { - compute_source_not_needed_(output_op, destination_iterator, accumulators); + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator ) ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + { + if (!output_op.is_source_needed()) + { + source_iterator.clear_mask(); + __syncthreads(); // Dummy (CUDA 11.0) } - else { - compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator); - } - } -private: + // Source-fragment data (zero-initialized for scenarios where the + // output operator allows us to skip loading it from global input) + typename OutputTileIterator::Fragment source_fragment; + source_fragment.clear(); - template - struct acc2smem_source_not_needed; + // Iterator over warp-level accumulator fragment + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); - template - struct acc2smem_source_not_needed> { - template - CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, - WarpTileIterator &warp_tile_iterator) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Advance; i++) { - ++accum_fragment_iterator; - } + // + // Iterate over accumulator tile + // + + #pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) + { + + // + // Convert and store fragment + // + + __syncthreads(); CUTLASS_PRAGMA_UNROLL - for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) + { typename AccumulatorFragmentIterator::Fragment accum_fragment; accum_fragment_iterator.load(accum_fragment); ++accum_fragment_iterator; - warp_tile_iterator.store(accum_fragment); + this->warp_tile_iterator_.store(accum_fragment); + if (p < Base::kFragmentsPerIteration - 1) { - warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset); } } if (Base::kFragmentsPerIteration > 1) { - warp_tile_iterator.add_pointer_offset(kSmemPointerOffset * - (1 - Base::kFragmentsPerIteration)); + this->warp_tile_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); } - } - CUTLASS_DEVICE - static void push(size_t pos, - AccumulatorFragmentIterator const &iterator_begin, - WarpTileIterator &warp_tile_iterator) { - int dummy[] = { - (pos == (Seq * Base::kFragmentsPerIteration)) && - (helper(iterator_begin, warp_tile_iterator), 0)...}; - - CUTLASS_UNUSED(dummy[0]); - } - }; - - static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1."); - - /// Streams the result to global memory - CUTLASS_DEVICE - void compute_source_not_needed_( - OutputOp const &output_op, ///< Output operator - OutputTileIterator destination_iterator, ///< Tile iterator for destination - AccumulatorTile const &accumulators ///< Complete warp-level accumulator tile - ) { - - // - // Iterator over warp-level accumulator fragment - // - - AccumulatorFragmentIterator accum_fragment_iterator(accumulators); - - // - // Iterate over accumulator tile - // - - #pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1) - for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) { - - // - // Convert and store fragment - // - - __syncthreads(); - - - acc2smem_source_not_needed< - cutlass::make_index_sequence>::push(iter, - accum_fragment_iterator, - this->warp_tile_iterator_); - - __syncthreads(); // // Load fragments from shared memory // - CUTLASS_PRAGMA_UNROLL - for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + __syncthreads(); + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) + { + // Load addend source fragment from global memory + source_iterator.load(source_fragment); + ++source_iterator; typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; shared_load_iterator_.load(aligned_accum_fragment[0]); - if (p < Base::kFragmentsPerIteration - 1) { + if (p < Base::kFragmentsPerIteration - 1) + { shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); } - else if (kPartitionsK > 1) { - + else if (kPartitionsK > 1) + { plus add_fragments; CUTLASS_PRAGMA_UNROLL @@ -318,9 +373,7 @@ private: // typename OutputTileIterator::Fragment output_fragment; - - apply_output_operator_source_not_needed_(output_fragment, output_op, aligned_accum_fragment[0]); - + apply_output_operator(output_fragment, output_op, aligned_accum_fragment[0], source_fragment); // // Store the final result @@ -336,170 +389,37 @@ private: } } - template - struct acc2smem_source_needed; - - template - struct acc2smem_source_needed> { - template - CUTLASS_DEVICE - static void helper(AccumulatorFragmentIterator accum_fragment_iterator, - WarpTileIterator &warp_tile_iterator) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Advance; i++) { - ++accum_fragment_iterator; - } - - typename AccumulatorFragmentIterator::Fragment accum_fragment; - accum_fragment_iterator.load(accum_fragment); - warp_tile_iterator.store(accum_fragment); - } - - CUTLASS_DEVICE - static void push(size_t pos, - AccumulatorFragmentIterator const &iterator_begin, - WarpTileIterator &warp_tile_iterator) { - int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; - } - }; - - /// Streams the result to global memory - CUTLASS_DEVICE - void compute_source_needed_( - OutputOp const &output_op, ///< Output operator - OutputTileIterator destination_iterator, ///< Tile iterator for destination - AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile - OutputTileIterator source_iterator ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) - ) { - - typename OutputTileIterator::Fragment source_fragment; - - source_fragment.clear(); - - // - // Iterator over warp-level accumulator fragment - // - - AccumulatorFragmentIterator accum_fragment_iterator(accumulators); - - // - // Iterate over accumulator tile - // - - #pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations : 1) - for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { - - // - // Load the source - // - - source_iterator.load(source_fragment); - ++source_iterator; - - // - // Convert and store fragment - // - - __syncthreads(); - - acc2smem_source_needed>::push( - iter, accum_fragment_iterator, this->warp_tile_iterator_); - - __syncthreads(); - - // - // Load fragments from shared memory - // - - typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; - - shared_load_iterator_.load(aligned_accum_fragment[0]); - - // If the number of k-slices is > 1 - perform a reduction amongst the k-slices - if (kPartitionsK > 1) { - - plus add_fragments; - - CUTLASS_PRAGMA_UNROLL - for ( int i = 1; i < kPartitionsK; ++i) { - shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); - shared_load_iterator_.load(aligned_accum_fragment[i]); - aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); - } - - shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); - } - - // - // Compute the output result - // - - typename OutputTileIterator::Fragment output_fragment; - - apply_output_operator_(output_fragment, output_op, aligned_accum_fragment[0], source_fragment); - - - // - // Store the final result - // - - destination_iterator.store(output_fragment); - ++destination_iterator; - - } - } +private: /// Helper to invoke the output functor over each vector of output CUTLASS_DEVICE - void apply_output_operator_( + void apply_output_operator( typename OutputTileIterator::Fragment &output_fragment, OutputOp const &output_op, ///< Output operator typename SharedLoadIterator::Fragment const &aligned_accum_fragment, - typename OutputTileIterator::Fragment const &source_fragment) { - - OutputAccessType *output_frag_ptr = + typename OutputTileIterator::Fragment const &source_fragment) + { + + OutputAccessType *output_frag_ptr = reinterpret_cast(&output_fragment); - AccumulatorAccessType const *compute_frag_ptr = + AccumulatorAccessType const *compute_frag_ptr = reinterpret_cast(&aligned_accum_fragment); - OutputAccessType const *source_frag_ptr = + OutputAccessType const *source_frag_ptr = reinterpret_cast(&source_fragment); - int const kOutputOpIterations = + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kOutputOpIterations; ++i) { - + for (int i = 0; i < kOutputOpIterations; ++i) + { // Call the output operator output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]); } } - /// Helper to invoke the output functor over each vector of output - CUTLASS_DEVICE - void apply_output_operator_source_not_needed_( - typename OutputTileIterator::Fragment &output_fragment, - OutputOp const &output_op, ///< Output operator - typename SharedLoadIterator::Fragment const &aligned_accum_fragment) { - - OutputAccessType *output_frag_ptr = - reinterpret_cast(&output_fragment); - - AccumulatorAccessType const *compute_frag_ptr = - reinterpret_cast(&aligned_accum_fragment); - - int const kOutputOpIterations = - OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kOutputOpIterations; ++i) { - - // Call the output operator - output_frag_ptr[i] = output_op(compute_frag_ptr[i]); - } - } }; //////////////////////////////////////////////////////////////////////////////// diff --git 
a/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h b/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h new file mode 100644 index 00000000..45b0fd27 --- /dev/null +++ b/include/cutlass/epilogue/threadblock/epilogue_base_streamk.h @@ -0,0 +1,191 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Basic subset of epilogue functionality for supporting StreamK decompositions +*/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" +#include "cutlass/block_striped.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + + +/// StreamK epilogue functionality for cross-block accumulator fragment reduction +template < + typename Shape, ///< Shape of threadblock tile (concept: GemmShape) + int PartitionsK, + typename WarpMmaOperator, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + typename AccumulatorFragmentIterator> ///< Fragment iterator selecting accumulators +class EpilogueBaseStreamK +{ + +protected: + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; + + /// Number of warps + using WarpCount = gemm::GemmShape< + Shape::kM / WarpMmaOperator::Shape::kM, + Shape::kN / WarpMmaOperator::Shape::kN, PartitionsK>; + + /// Number of threads per block + static int const kBlockThreads = 32 * WarpCount::kCount; + + /// Numerical accumulation element type + using ElementAccumulator = typename WarpMmaOperator::ElementC; + + /// Fragment type used by the accumulator tile's fragment iterator + using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment; + + /// Block-striped transfer utility for sharing AccumulatorFragment + using BlockStripedT = BlockStriped; + + /// Number of elements per fragment + static int const kFragmentElements = sizeof(AccumulatorFragment) / sizeof(ElementAccumulator); + +public: + + /// Number of fragments per accumulator tile + static int const kAccumulatorFragments = AccumulatorFragmentIterator::Policy::kIterations; + + /// Number of workspace accumulation elements shared per output tile + static int const kPeerAccumulators = WarpMmaOperator::Shape::kMN * WarpCount::kCount; + +protected: + + /// ElementAccumulator stride in the shared workspace between different peer blocks (two: each peer block can share accumulators for up to two tiles) + static const int kPeerStride = kPeerAccumulators * 2; + + +public: + + /// Thread index in the threadblock + int thread_idx; + +public: + + /// Constructor + CUTLASS_DEVICE + EpilogueBaseStreamK( + int thread_idx) ///< ID of a thread within the threadblock + : + thread_idx(thread_idx) + {} + + + /// Aggregates the accumulator sets shared by peer blocks in the global workspace + CUTLASS_DEVICE + void reduce( + AccumulatorFragment &accum_fragment, ///< [out] sum of all shared accumulator fragments for these peer partials + int peer_idx_begin, + int peer_idx_end, + int reduce_fragment_idx, + ElementAccumulator *element_workspace) + { + plus add_fragments; + + int accum_set_offset = + (peer_idx_begin * kPeerStride) + + (reduce_fragment_idx * kBlockThreads * kFragmentElements); + + // Load first peer fragment + BlockStripedT::load(accum_fragment, element_workspace + accum_set_offset, this->thread_idx); + + accum_set_offset += kPeerStride; // Move to next peer + accum_set_offset += kPeerAccumulators; // Move to non-starting accumulator set for peer + + // Reduce additional peer fragments + #pragma unroll 2 + while (accum_set_offset < peer_idx_end * kPeerStride) + { + AccumulatorFragment addend_fragment; + BlockStripedT::load(addend_fragment, element_workspace + accum_set_offset, this->thread_idx); + 
accum_set_offset += kPeerStride; + + accum_fragment = add_fragments(accum_fragment, addend_fragment); + } + } + + + /// Shares the accumulator set with peers in the global workspace + CUTLASS_DEVICE + void share( + int peer_idx, + ElementAccumulator *element_workspace, ///< Output pointer for writing this block's accumulator set to + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + bool started_tile) + { + int accum_set_offset = peer_idx * kPeerStride; + + if (!started_tile) { + // Move to non-starting accumulator set + accum_set_offset += kPeerAccumulators; + } + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kAccumulatorFragments; ++iter) + { + // Acquire reordered accumulator fragment + AccumulatorFragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + // Store accumulator fragment + BlockStripedT::store(element_workspace + accum_set_offset, accum_fragment, this->thread_idx); + + accum_set_offset += (kFragmentElements * kBlockThreads); + } + } + +}; + + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/epilogue_depthwise.h b/include/cutlass/epilogue/threadblock/epilogue_depthwise.h new file mode 100644 index 00000000..1d013a31 --- /dev/null +++ b/include/cutlass/epilogue/threadblock/epilogue_depthwise.h @@ -0,0 +1,335 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
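The reduce() and share() methods above operate on a striped global workspace in which each peer CTA owns a slot of kPeerStride = 2 * kPeerAccumulators elements, holding a "starting" and a "non-starting" accumulator set. What follows is a minimal host-side C++ sketch of that reduction under simplifying assumptions: the cooperative BlockStriped transfers and the reduce_fragment_idx offset are collapsed into plain loops, a single float element type stands in for ElementAccumulator, and the function name and signature are hypothetical.

#include <vector>

// Sum the accumulator sets shared by peers in [peer_idx_begin, peer_idx_end),
// mirroring the offset arithmetic of EpilogueBaseStreamK::reduce() above.
void reduce_peer_partials(std::vector<float> const &workspace,   // striped global workspace
                          std::vector<float> &accum_fragment,    // [out] reduced fragment
                          int peer_idx_begin,
                          int peer_idx_end,
                          int peer_accumulators) {                // plays the role of kPeerAccumulators
  int const peer_stride = 2 * peer_accumulators;                  // kPeerStride: two sets per peer slot
  int const n = static_cast<int>(accum_fragment.size());

  // Load the first peer's "starting" accumulator set.
  int offset = peer_idx_begin * peer_stride;
  for (int i = 0; i < n; ++i) {
    accum_fragment[i] = workspace[offset + i];
  }

  // Remaining peers contribute their "non-starting" set, one peer slot apart.
  offset += peer_stride + peer_accumulators;
  while (offset < peer_idx_end * peer_stride) {
    for (int i = 0; i < n; ++i) {
      accum_fragment[i] += workspace[offset + i];
    }
    offset += peer_stride;
  }
}

share() is the mirror image: a CTA that started the output tile writes its accumulator set at the front of its slot, otherwise at the kPeerAccumulators offset, which is what lets reduce() pick up each contribution purely by position.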
\file + \brief Epilogue for Depthwise convoltuion + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/conversion_op.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/reduction_op.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template +class EpilogueDepthwise { + public: + using Shape = Shape_; + using WarpShape = typename WarpMmaOperator_::Shape; + using ThreadOutputShape = ThreadOutputShape_; + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + using WarpMmaOperator = WarpMmaOperator_; + using OutputTileIterator = OutputTileIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = + Array; + + /// Array type used by output functor + using AccumulatorAccessType = + Array; + + /// Number of warps + using WarpCount = + gemm::GemmShape; + + public: + static_assert(SharedLoadIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, + "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + + /// Shared storage allocation needed by the epilogue + struct SharedStorage { + // + // Type definitions + // + + /// Element type of shared memory + using Element = typename WarpTileIterator::Element; + + /// Tensor reference to shared memory allocation + using TensorRef = typename WarpTileIterator::TensorRef; + + /// Layout of shared memory allocation + using Layout = typename WarpTileIterator::Layout; + + /// Logical shape of the shared memory tile written to by all warps. 
+ using Shape = MatrixShape; + + /// Shape of the shared memory allocation for the epilogue + using StorageShape = MatrixShape; + + // + // Data members + // + + AlignedBuffer storage; + + // + // Methods + // + + /// Returns a pointer to the shared memory buffer + CUTLASS_DEVICE + Element *data() { return storage.data(); } + + /// Returns a tensor reference to the shared memory buffer + CUTLASS_DEVICE + TensorRef reference() { + return TensorRef(storage.data(), Layout::packed({StorageShape::kRow, StorageShape::kColumn})); + } + }; + + private: + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + /// Stores a warp's fragment of accumulators to SMEM + WarpTileIterator warp_tile_iterator_; + + LongIndex warp_offset; + int thread_idx; + int warp_idx; + int lane_idx; + int warp_m, warp_n; // warp coordinates within a cta + int tid_m, tid_n; // thread coordinates within a warp + + public: + /// Constructor + CUTLASS_DEVICE + EpilogueDepthwise(SharedStorage &shared_storage, ///< Shared storage object + int thread_idx_, ///< ID of a thread within the threadblock + int warp_idx_, ///< ID of warp within threadblock + int lane_idx_ ///< Id of thread within warp + ) + : thread_idx(thread_idx_), + warp_idx(warp_idx_), + lane_idx(lane_idx_), + shared_load_iterator_(shared_storage.reference(), thread_idx_), + warp_tile_iterator_(shared_storage.reference(), thread_idx_, lane_idx_) {} + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()(OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in + ///< units of threadblock tiles) + const int smem_base_offset) { ///< SMEM base offset for epilogue operation + // initiate the smem base offset for different output tile. 
+ warp_tile_iterator_.set_smem_base_address(smem_base_offset); + + shared_load_iterator_.set_smem_base_address(smem_base_offset); + + if (!output_op.is_source_needed()) { + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } else { + compute_source_needed_(output_op, destination_iterator, accumulators, source_iterator); + } + } + + private: + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + typename OutputTileIterator::Fragment source_fragment; + + source_fragment.clear(); + + source_iterator.load(source_fragment); + + // store to smem + warp_tile_iterator_.store(accumulators); + + __syncthreads(); + + typename SharedLoadIterator::Fragment aligned_accum_fragment; + + // load from smem + shared_load_iterator_.load(aligned_accum_fragment); + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_(output_fragment, output_op, aligned_accum_fragment, source_fragment); + + // Store to GMEM + destination_iterator.store(output_fragment); + } + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + // store to smem + warp_tile_iterator_.store(accumulators); + + __syncthreads(); + + typename SharedLoadIterator::Fragment aligned_accum_fragment; + + // load from smem + shared_load_iterator_.load(aligned_accum_fragment); + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_source_not_needed_(output_fragment, output_op, aligned_accum_fragment); + + // Store to GMEM + destination_iterator.store(output_fragment); + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + typename OutputTileIterator::Fragment &output_fragment, + OutputOp const &output_op, ///< Output operator + typename SharedLoadIterator::Fragment const &aligned_accum_fragment, + typename OutputTileIterator::Fragment const &source_fragment) { + + OutputAccessType *output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const *compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + OutputAccessType const *source_frag_ptr = + reinterpret_cast(&source_fragment); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]); + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + typename OutputTileIterator::Fragment &output_fragment, + OutputOp const &output_op, ///< Output operator + typename SharedLoadIterator::Fragment const &aligned_accum_fragment) { + OutputAccessType *output_frag_ptr = reinterpret_cast(&output_fragment); + + AccumulatorAccessType const *compute_frag_ptr = + 
reinterpret_cast(&aligned_accum_fragment); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = output_op(compute_frag_ptr[i]); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/epilogue_direct_store.h b/include/cutlass/epilogue/threadblock/epilogue_direct_store.h index 4bbfafaf..5ce9719a 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_direct_store.h +++ b/include/cutlass/epilogue/threadblock/epilogue_direct_store.h @@ -77,7 +77,6 @@ public: using OutputTileIterator = OutputTileIterator_; using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; using WarpTileIterator = WarpTileIterator_; - using SharedLoadIterator = SharedLoadIterator_; using OutputOp = OutputOp_; using Padding = MatrixShape<0, 0>; diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h index 3fafcbc4..d3a7d7e2 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h @@ -133,7 +133,8 @@ struct EpilogueWithBroadcastOpBase { FragmentZ &frag_Z, FragmentT &frag_T, FragmentAccumulator const &AB, - FragmentC const &frag_C, + FragmentC const &frag_C1, + FragmentC const &frag_C2, FragmentCompute const &V) const { } @@ -180,9 +181,42 @@ template < typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large - (!IsEpilogueFunctorHeavy::value) + (!IsEpilogueFunctorHeavy::value), + bool IsSingleSource = OutputOp_::kIsSingleSource > -class EpilogueWithBroadcast : +class EpilogueWithBroadcast; + +template < + typename Shape_, + typename WarpMmaOperator_, + int PartitionsK, + typename OutputTileIterator_, + typename TensorTileIterator_, + typename ElementVector_, + typename AccumulatorFragmentIterator_, + typename WarpTileIterator_, + typename SharedLoadIterator_, + typename OutputOp_, + typename Padding_, + int FragmentsPerPartition, + int IterationsUnroll +> +class EpilogueWithBroadcast< + Shape_, + WarpMmaOperator_, + PartitionsK, + OutputTileIterator_, + TensorTileIterator_, + ElementVector_, + AccumulatorFragmentIterator_, + WarpTileIterator_, + SharedLoadIterator_, + OutputOp_, + Padding_, + FragmentsPerPartition, + IterationsUnroll, + false +> : public EpilogueBase< Shape_, typename WarpMmaOperator_::Shape, @@ -203,6 +237,7 @@ public: Padding_, FragmentsPerPartition>; + static bool const kIsSingleSource = false; using Shape = Shape_; using WarpMmaOperator = WarpMmaOperator_; static int const kPartitionsK = PartitionsK; @@ -383,7 +418,687 @@ public: CUTLASS_DEVICE void operator()( OutputOp const &output_op, ///< Output operator - ElementVector const * broadcast_ptr, ///< Broadcast vector + ElementVector const * broadcast_ptr, ///< Broadcast vector + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, 
///< Complete warp-level accumulator tile + OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix + OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix + TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand + MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord(Shape::kM, Shape::kN), + MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space + MatrixCoord()) { + + BroadcastFragment broadcast_fragment; + + load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset); + + if (!output_op.is_source_needed()) { + compute_source_not_needed_( + output_op, + broadcast_fragment, + destination_iterator, + accumulators, + tensor_iterator); + } + else { + compute_source_needed_( + output_op, + broadcast_fragment, + destination_iterator, + accumulators, + source_iterator1, + source_iterator2, + tensor_iterator); + } + } + +private: + + CUTLASS_DEVICE + void load_broadcast_fragment_( + BroadcastFragment & broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + ElementVector const * broadcast_ptr, ///< Broadcast vector + MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space + ) { + + broadcast_fragment.clear(); + + // If no pointer is supplied, set with all zeros and avoid memory accesses + if (!broadcast_ptr) { + return; + } + + int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column(); + + int thread_column_idx = threadblock_offset.column() + thread_initial_column; + broadcast_ptr += thread_initial_column; + + NumericArrayConverter converter; + using AccessType = AlignedArray; + using ComputeFragmentType = Array; + + ComputeFragmentType *frag_ptr = reinterpret_cast(&broadcast_fragment); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) { + + AccessType loaded; + + loaded.clear(); + + if (thread_column_idx < problem_size.column()) { + loaded = *reinterpret_cast(broadcast_ptr); + } + + ComputeFragmentType cvt = converter(loaded); + frag_ptr[j] = cvt; + + thread_column_idx += ThreadMap::Delta::kColumn; + broadcast_ptr += ThreadMap::Delta::kColumn; + } + } + + template + struct acc2smem_source_not_needed; + + template + struct acc2smem_source_not_needed> { + template + CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator &warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p < Base::kFragmentsPerIteration - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset * + (1 - Base::kFragmentsPerIteration)); + } + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const &iterator_begin, + WarpTileIterator &warp_tile_iterator) { + int dummy[] = { + (pos == (Seq * 
Base::kFragmentsPerIteration)) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; + + CUTLASS_UNUSED(dummy[0]); + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const &output_op, ///< Output operator + BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand + ) { + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + + // CUTLASS_PRAGMA_UNROLL + #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) { + + // + // Convert and store fragment + // + + + __syncthreads(); + + acc2smem_source_not_needed< + cutlass::make_index_sequence>::push(iter, + accum_fragment_iterator, + this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + + + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + if (p < Base::kFragmentsPerIteration - 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + } + else if (kPartitionsK > 1) { + + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Apply output operation + // + + typename OutputTileIterator::Fragment frag_Z; + typename TensorTileIterator::Fragment frag_T; + + apply_output_operator_source_not_needed_( + frag_Z, + frag_T, + output_op, + aligned_accum_fragment[0], + broadcast_fragment); + + // + // Conditionally store fragments + // + + if (OutputOp::kStoreZ) { + destination_iterator.store(frag_Z); + ++destination_iterator; + } + + if (OutputOp::kStoreT) { + tensor_iterator.store(frag_T); + ++tensor_iterator; + } + } + + if (Base::kFragmentsPerIteration > 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + } + + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE + static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator &warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const &iterator_begin, + WarpTileIterator &warp_tile_iterator) { + int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + + /// Streams the result 
to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const &output_op, ///< Output operator + BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator1, ///< Tile iterator for first source accumulator matrix + OutputTileIterator source_iterator2, ///< Tile iterator for second source accumulator matrix + TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand + ) { + + typename OutputTileIterator::Fragment source_fragment1; + source_fragment1.clear(); + typename OutputTileIterator::Fragment source_fragment2; + source_fragment2.clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + + #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + // + // Load the source + // + + source_iterator1.load(source_fragment1); + ++source_iterator1; + + source_iterator2.load(source_fragment2); + ++source_iterator2; + + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_needed>::push( + iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the k-slices + if (kPartitionsK > 1) + { + plus add_fragments; + const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_tile_offset({tile_row_offset , 0}); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0}); + } + + // + // Apply output operation + // + + typename OutputTileIterator::Fragment frag_Z; + typename TensorTileIterator::Fragment frag_T; + + apply_output_operator_( + frag_Z, + frag_T, + output_op, + aligned_accum_fragment[0], + source_fragment1, + source_fragment2, + broadcast_fragment); + + // + // Conditionally store fragments + // + + if (OutputOp::kStoreZ) { + destination_iterator.store(frag_Z); + ++destination_iterator; + } + + if (OutputOp::kStoreT) { + tensor_iterator.store(frag_T); + ++tensor_iterator; + } + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + typename OutputTileIterator::Fragment &frag_Z, + typename TensorTileIterator::Fragment &frag_T, + OutputOp const &output_op, + typename SharedLoadIterator::Fragment const &frag_AB, + typename OutputTileIterator::Fragment const &frag_C1, + typename OutputTileIterator::Fragment const &frag_C2, + BroadcastFragment const &frag_Broadcast) { + + using AccessTypeZ = Array; + using AccessTypeT = Array; + using AccessTypeBroadcast = Array; + + AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); + AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); + + AccumulatorAccessType 
const *frag_AB_ptr = + reinterpret_cast(&frag_AB); + + OutputAccessType const *frag_C1_ptr = + reinterpret_cast(&frag_C1); + + OutputAccessType const *frag_C2_ptr = + reinterpret_cast(&frag_C2); + + AccessTypeBroadcast const *frag_Broadcast_ptr = + reinterpret_cast(&frag_Broadcast); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + output_op( + frag_Z_ptr[i], + frag_T_ptr[i], + frag_AB_ptr[i], + frag_C1_ptr[i], + frag_C2_ptr[i], + frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + typename OutputTileIterator::Fragment &frag_Z, + typename TensorTileIterator::Fragment &frag_T, + OutputOp const &output_op, + typename SharedLoadIterator::Fragment const &frag_AB, + BroadcastFragment const &frag_Broadcast) { + + using AccessTypeZ = Array; + using AccessTypeT = Array; + using AccessTypeBroadcast = Array; + + AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); + AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); + + AccumulatorAccessType const *frag_AB_ptr = + reinterpret_cast(&frag_AB); + + AccessTypeBroadcast const *frag_Broadcast_ptr = + reinterpret_cast(&frag_Broadcast); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + + output_op( + frag_Z_ptr[i], + frag_T_ptr[i], + frag_AB_ptr[i], + frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); + } + } +}; + + +template < + typename Shape_, + typename WarpMmaOperator_, + int PartitionsK, + typename OutputTileIterator_, + typename TensorTileIterator_, + typename ElementVector_, + typename AccumulatorFragmentIterator_, + typename WarpTileIterator_, + typename SharedLoadIterator_, + typename OutputOp_, + typename Padding_, + int FragmentsPerPartition, + int IterationsUnroll +> +class EpilogueWithBroadcast< + Shape_, + WarpMmaOperator_, + PartitionsK, + OutputTileIterator_, + TensorTileIterator_, + ElementVector_, + AccumulatorFragmentIterator_, + WarpTileIterator_, + SharedLoadIterator_, + OutputOp_, + Padding_, + FragmentsPerPartition, + IterationsUnroll, + true +> : + public EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition> { + +public: + + using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>; + + static bool const kIsSingleSource = true; + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using TensorTileIterator = TensorTileIterator_; + using ElementVector = ElementVector_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename 
WarpTileIterator::Element; + + /// Compute data type produced by the output op + using ElementCompute = typename OutputOp::ElementCompute; + + /// Compute fragment + using FragmentCompute = Array; + + /// Thread map used by output tile iterators + using ThreadMap = typename OutputTileIterator::ThreadMap; + + /// Fragment object used to store the broadcast values + using BroadcastFragment = Array< + ElementCompute, + ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Data type of additional tensor + using ElementTensor = typename TensorTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array; + + /// Array type used by output functor + using ComputeAccessType = Array; + + /// Tensor access type + using TensorAccessType = Array; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + /// Shared memory allocation from epilogue base class + using BaseSharedStorage = typename Base::SharedStorage; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; + static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; + + /// Used for the broadcast + struct BroadcastDetail { + + /// Number of threads per warp + static int const kWarpSize = 32; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + /// Number of distinct scalar column indices handled by each thread + static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; + + /// Number of distinct scalar row indices handled by each thread + static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; + + /// Number of threads per threadblock + static int const kThreadCount = kWarpSize * WarpCount::kCount; + + /// Number of distinct threads per row of output tile + static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread); + + /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. + static int const kThreadRows = kThreadCount / kThreadsPerRow; + + /// I'm not sure what I meant here. 
+ static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); + + /// Shape of the shared memory allocation for the epilogue + using StorageShape = MatrixShape< + kThreadRows, + Shape::kN + >; + + /// Debug printing + CUTLASS_DEVICE + static void print() { +#if 0 + printf("BroadcastDetail {\n"); + printf( + " kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n" + "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n", + kColumnsPerThread, + kRowsPerThread, + kThreadCount, + kThreadsPerRow, + kThreadRows, + kThreadAccessesPerRow, + StorageShape::kRow, + StorageShape::kColumn, + StorageShape::kCount + ); + printf("};\n"); +#endif + } + }; + + /// Shared storage structure (shadows base) with additional SMEM buffer for reduction + struct SharedStorage { + union { + BaseSharedStorage base; + }; + + CUTLASS_HOST_DEVICE + SharedStorage() { } + }; + +public: + + + static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + +private: + + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + /// Thread index within the threadblock + int thread_idx_; + +public: + + /// Constructor + CUTLASS_DEVICE + EpilogueWithBroadcast( + SharedStorage &shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ): + Base(shared_storage.base, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.base.reference(), thread_idx), + thread_idx_(thread_idx) + { + + } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + ElementVector const * broadcast_ptr, ///< Broadcast vector OutputTileIterator destination_iterator, ///< Tile iterator for destination AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix @@ -646,7 +1361,7 @@ private: BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns OutputTileIterator destination_iterator, ///< Tile iterator for destination AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile - OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand ) { @@ -759,7 +1474,7 @@ private: AccumulatorAccessType const *frag_AB_ptr = reinterpret_cast(&frag_AB); - OutputAccessType const *frag_C_ptr = + OutputAccessType const *frag_C_ptr = reinterpret_cast(&frag_C); AccessTypeBroadcast const *frag_Broadcast_ptr = @@ -770,13 +1485,12 @@ private: CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kOutputOpIterations; ++i) { - - output_op( - frag_Z_ptr[i], - frag_T_ptr[i], - frag_AB_ptr[i], - frag_C_ptr[i], - frag_Broadcast_ptr[i % 
ThreadMap::Iterations::kColumn]); + output_op( + frag_Z_ptr[i], + frag_T_ptr[i], + frag_AB_ptr[i], + frag_C_ptr[i], + frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); } } diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast_v2.h b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast_v2.h deleted file mode 100644 index 65b08638..00000000 --- a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast_v2.h +++ /dev/null @@ -1,847 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - The epilogue rearranges the result of a matrix product through shared memory to match canonical - tensor layouts in global memory. Epilogues support conversion and reduction operations. 
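The two-source EpilogueWithBroadcast specialization above (selected when OutputOp::kIsSingleSource is false) invokes its output functor as output_op(frag_Z, frag_T, frag_AB, frag_C1, frag_C2, frag_Broadcast). Below is a minimal sketch of a functor with that shape, assuming a single element type throughout and a hypothetical combination Z = alpha * AB + beta1 * C1 + beta2 * C2 + V with T mirroring Z; the Params struct, set_k_partition(), and the source-not-needed overload expected by EpilogueWithBroadcastOpBase are omitted for brevity.

#include "cutlass/cutlass.h"
#include "cutlass/array.h"

// Illustrative two-source broadcast output functor (names and scaling scheme are assumptions).
template <typename Element, int kElementsPerAccess>
struct TwoSourceBroadcastOpSketch {
  using Fragment = cutlass::Array<Element, kElementsPerAccess>;

  static bool const kIsSingleSource = false;  // selects the two-source specialization
  static bool const kStoreZ = true;           // epilogue stores frag_Z when true
  static bool const kStoreT = false;          // frag_T is computed but not written out

  Element alpha, beta1, beta2;

  CUTLASS_HOST_DEVICE
  TwoSourceBroadcastOpSketch(Element alpha_, Element beta1_, Element beta2_)
      : alpha(alpha_), beta1(beta1_), beta2(beta2_) {}

  CUTLASS_HOST_DEVICE
  bool is_source_needed() const { return true; }

  CUTLASS_HOST_DEVICE
  void operator()(Fragment &frag_Z, Fragment &frag_T,
                  Fragment const &AB,
                  Fragment const &frag_C1, Fragment const &frag_C2,
                  Fragment const &V) const {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kElementsPerAccess; ++i) {
      // Combine accumulators, both sources, and the broadcast vector element-wise.
      frag_Z[i] = alpha * AB[i] + beta1 * frag_C1[i] + beta2 * frag_C2[i] + V[i];
      frag_T[i] = frag_Z[i];
    }
  }
};

Keeping kStoreZ/kStoreT as compile-time constants lets the epilogue elide the corresponding global stores entirely, which is why the conditional stores above test OutputOp::kStoreZ and OutputOp::kStoreT rather than runtime flags.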
-*/ - -#pragma once - -#include -#if defined(__CUDACC_RTC__) -#include -#else -#include -#endif - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/tensor_coord.h" -#include "cutlass/aligned_buffer.h" -#include "cutlass/functional.h" -#include "cutlass/fast_math.h" -#include "cutlass/layout/vector.h" -#include "cutlass/layout/tensor.h" - -#include "cutlass/gemm/gemm.h" - -#include "cutlass/transform/pitch_linear_thread_map.h" -#include "cutlass/transform/threadblock/regular_tile_iterator.h" - -#include "cutlass/epilogue/threadblock/epilogue_base.h" -#include "cutlass/epilogue/threadblock/predicated_tile_iterator_v2.h" - -#include "cutlass/util/index_sequence.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// This base class is meant to define the concept required of the -/// EpilogueWithBroadcast::OutputOp -template < - typename ElementC_, - typename ElementAccumulator_, - typename ElementCompute_, - typename ElementZ_, - typename ElementT_, - int ElementsPerAccess, - bool StoreZ = true, - bool StoreT = true -> -struct EpilogueWithBroadcastOpBaseV2 { - - using ElementOutput = ElementC_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - using ElementZ = ElementZ_; - using ElementT = ElementT_; - static int const kElementsPerAccess = ElementsPerAccess; - - using FragmentAccumulator = Array; - using FragmentCompute = Array; - using FragmentC = Array; - using FragmentZ = Array; - using FragmentT = Array; - - /// If true, the 'Z' tensor is stored - static bool const kStoreZ = StoreZ; - - /// If true, the 'T' tensor is stored - static bool const kStoreT = StoreT; - - /// Parameters structure - required - struct Params { }; - - // - // Methods - // - - /// Constructor from Params - EpilogueWithBroadcastOpBaseV2(Params const ¶ms_) { } - - /// Determine if the source is needed. May return false if - bool is_source_needed() const { - return true; - } - - CUTLASS_HOST_DEVICE - void set_k_partition(int k_partition, int k_partition_count) { } - - /// Applies the operation when is_source_needed() is true - CUTLASS_HOST_DEVICE - void operator()( - FragmentZ &frag_Z, - FragmentT &frag_T, - FragmentAccumulator const &AB, - FragmentC const &frag_C1, - FragmentC const &frag_C2, - FragmentCompute const &V) const { - - } - - /// Applies the operation when is_source_needed() is false - CUTLASS_HOST_DEVICE - void operator()( - FragmentZ &frag_Z, - FragmentT &frag_T, - FragmentAccumulator const &AB, - FragmentCompute const &V) const { - - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Epilogue operator with bias vector broadcast over columns. 
-/// -/// Computes the following: -/// -/// -/// Z, T = OutputOp(AB, C, Broadcast) -/// -/// if (ElementwiseOp::kStoreZ) { -/// store(converted_u); -/// } -/// -/// if (ElementwiseOp::kStoreT) { -/// store(v); -/// } -/// -template < - typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) - typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) - int PartitionsK, ///< Number of partitions of the K dimension - typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors (z) - typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands (t) - typename ElementVector_, ///< Pointer to broadcast vector - typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators - typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM - typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM - typename OutputOp_, ///< Output operator - concept is EpilogueWithBroadcastOp - typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) - int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity - int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large - (!IsEpilogueFunctorHeavy::value) -> -class EpilogueWithBroadcastV2 : - public EpilogueBase< - Shape_, - typename WarpMmaOperator_::Shape, - PartitionsK, - AccumulatorFragmentIterator_, - WarpTileIterator_, - Padding_, - FragmentsPerPartition> { - -public: - - using Base = EpilogueBase< - Shape_, - typename WarpMmaOperator_::Shape, - PartitionsK, - AccumulatorFragmentIterator_, - WarpTileIterator_, - Padding_, - FragmentsPerPartition>; - - using Shape = Shape_; - using WarpMmaOperator = WarpMmaOperator_; - static int const kPartitionsK = PartitionsK; - using OutputTileIterator = OutputTileIterator_; - using TensorTileIterator = TensorTileIterator_; - using ElementVector = ElementVector_; - using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; - using WarpTileIterator = WarpTileIterator_; - using SharedLoadIterator = SharedLoadIterator_; - using OutputOp = OutputOp_; - using Padding = Padding_; - - using Layout = layout::RowMajor; - using LongIndex = typename Layout::LongIndex; - - /// The complete warp-level accumulator tile - using AccumulatorTile = typename Base::AccumulatorTile; - - /// Accumulator element - using ElementAccumulator = typename WarpTileIterator::Element; - - /// Compute data type produced by the output op - using ElementCompute = typename OutputOp::ElementCompute; - - /// Compute fragment - using FragmentCompute = Array; - - /// Thread map used by output tile iterators - using ThreadMap = typename OutputTileIterator::ThreadMap; - - /// Fragment object used to store the broadcast values - using BroadcastFragment = Array< - ElementCompute, - ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>; - - /// Output element - using ElementOutput = typename OutputTileIterator::Element; - - /// Data type of additional tensor - using ElementTensor = typename TensorTileIterator::Element; - - /// Output access size - static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; - - /// Tensor reference to destination tensor - using TensorRef = typename OutputTileIterator::TensorRef; - - /// Tensor reference to sync tensor - using SyncTensorRef = typename cutlass::TensorRef; - - /// Const tensor reference to source tensor - using ConstTensorRef = 
typename OutputTileIterator::ConstTensorRef; - - /// Array type used to output - using OutputAccessType = Array< - typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; - - /// Array type used by output functor - using AccumulatorAccessType = Array; - - /// Array type used by output functor - using ComputeAccessType = Array; - - /// Tensor access type - using TensorAccessType = Array; - - /// Number of warps - using WarpCount = typename Base::WarpCount; - - /// Shared memory allocation from epilogue base class - using BaseSharedStorage = typename Base::SharedStorage; - - static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; - static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; - - /// Used for the broadcast - struct BroadcastDetail { - - /// Number of threads per warp - static int const kWarpSize = 32; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - - /// Number of distinct scalar column indices handled by each thread - static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; - - /// Number of distinct scalar row indices handled by each thread - static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; - - /// Number of threads per threadblock - static int const kThreadCount = kWarpSize * WarpCount::kCount; - - /// Number of distinct threads per row of output tile - static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread); - - /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. - static int const kThreadRows = kThreadCount / kThreadsPerRow; - - /// I'm not sure what I meant here. 
- static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); - - /// Shape of the shared memory allocation for the epilogue - using StorageShape = MatrixShape< - kThreadRows, - Shape::kN - >; - - /// Debug printing - CUTLASS_DEVICE - static void print() { -#if 0 - printf("BroadcastDetail {\n"); - printf( - " kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n" - "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n", - kColumnsPerThread, - kRowsPerThread, - kThreadCount, - kThreadsPerRow, - kThreadRows, - kThreadAccessesPerRow, - StorageShape::kRow, - StorageShape::kColumn, - StorageShape::kCount - ); - printf("};\n"); -#endif - } - }; - - /// Shared storage structure (shadows base) with additional SMEM buffer for reduction - struct SharedStorage { - union { - BaseSharedStorage base; - }; - - CUTLASS_HOST_DEVICE - SharedStorage() { } - }; - -public: - - - static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, - "Mismatch between shared load iterator and output tile iterator."); - - static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); - - static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), - "Divisibility"); - -private: - - /// Loads fragment from shared memory aligned with output tensor - SharedLoadIterator shared_load_iterator_; - - /// Thread index within the threadblock - int thread_idx_; - -public: - - /// Constructor - CUTLASS_DEVICE - EpilogueWithBroadcastV2( - SharedStorage &shared_storage, ///< Shared storage object - int thread_idx, ///< ID of a thread within the threadblock - int warp_idx, ///< ID of warp within threadblock - int lane_idx ///< Id of thread within warp - ): - Base(shared_storage.base, thread_idx, warp_idx, lane_idx), - shared_load_iterator_(shared_storage.base.reference(), thread_idx), - thread_idx_(thread_idx) - { - - } - - /// Streams the result to global memory - CUTLASS_DEVICE - void operator()( - OutputOp const &output_op, ///< Output operator - ElementVector const * broadcast_ptr, ///< Broadcast vector - OutputTileIterator destination_iterator, ///< Tile iterator for destination - AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile - OutputTileIterator source_iterator1, ///< Tile iterator for source accumulator matrix - OutputTileIterator source_iterator2, ///< Tile iterator for source accumulator matrix - TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand - MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses - MatrixCoord(Shape::kM, Shape::kN), - MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space - MatrixCoord()) { - - BroadcastFragment broadcast_fragment; - - load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset); - - if (!output_op.is_source_needed()) { - compute_source_not_needed_( - output_op, - broadcast_fragment, - destination_iterator, - accumulators, - tensor_iterator); - } - else { - compute_source_needed_( - output_op, - broadcast_fragment, - destination_iterator, - accumulators, - source_iterator1, - source_iterator2, - tensor_iterator); - } - } - -private: - - CUTLASS_DEVICE - void load_broadcast_fragment_( - BroadcastFragment & broadcast_fragment, ///< Fragment containing the accumulated 
partial reduction over columns - ElementVector const * broadcast_ptr, ///< Broadcast vector - MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses - MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space - ) { - - broadcast_fragment.clear(); - - // If no pointer is supplied, set with all zeros and avoid memory accesses - if (!broadcast_ptr) { - return; - } - - int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column(); - - int thread_column_idx = threadblock_offset.column() + thread_initial_column; - broadcast_ptr += thread_initial_column; - - NumericArrayConverter converter; - using AccessType = AlignedArray; - using ComputeFragmentType = Array; - - ComputeFragmentType *frag_ptr = reinterpret_cast(&broadcast_fragment); - - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) { - - AccessType loaded; - - loaded.clear(); - - if (thread_column_idx < problem_size.column()) { - loaded = *reinterpret_cast(broadcast_ptr); - } - - ComputeFragmentType cvt = converter(loaded); - frag_ptr[j] = cvt; - - thread_column_idx += ThreadMap::Delta::kColumn; - broadcast_ptr += ThreadMap::Delta::kColumn; - } - } - - template - struct acc2smem_source_not_needed; - - template - struct acc2smem_source_not_needed> { - template - CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, - WarpTileIterator &warp_tile_iterator) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Advance; i++) { - ++accum_fragment_iterator; - } - - CUTLASS_PRAGMA_UNROLL - for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { - typename AccumulatorFragmentIterator::Fragment accum_fragment; - - accum_fragment_iterator.load(accum_fragment); - ++accum_fragment_iterator; - - warp_tile_iterator.store(accum_fragment); - if (p < Base::kFragmentsPerIteration - 1) { - warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); - } - } - - if (Base::kFragmentsPerIteration > 1) { - warp_tile_iterator.add_pointer_offset(kSmemPointerOffset * - (1 - Base::kFragmentsPerIteration)); - } - } - - CUTLASS_DEVICE - static void push(size_t pos, - AccumulatorFragmentIterator const &iterator_begin, - WarpTileIterator &warp_tile_iterator) { - int dummy[] = { - (pos == (Seq * Base::kFragmentsPerIteration)) && - (helper(iterator_begin, warp_tile_iterator), 0)...}; - - CUTLASS_UNUSED(dummy[0]); - } - }; - - /// Streams the result to global memory - CUTLASS_DEVICE - void compute_source_not_needed_( - OutputOp const &output_op, ///< Output operator - BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns - OutputTileIterator destination_iterator, ///< Tile iterator for destination - AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile - TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand - ) { - - // - // Iterator over warp-level accumulator fragment - // - - AccumulatorFragmentIterator accum_fragment_iterator(accumulators); - - // - // Iterate over accumulator tile - // - - // CUTLASS_PRAGMA_UNROLL - #pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1) - for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) { - - // - // Convert and store fragment - // - - - __syncthreads(); - - acc2smem_source_not_needed< - cutlass::make_index_sequence>::push(iter, - accum_fragment_iterator, - this->warp_tile_iterator_); - - __syncthreads(); - - // - // Load fragments from shared memory - // - - CUTLASS_PRAGMA_UNROLL - for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { - - - typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; - - shared_load_iterator_.load(aligned_accum_fragment[0]); - - if (p < Base::kFragmentsPerIteration - 1) { - shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); - } - else if (kPartitionsK > 1) { - - plus add_fragments; - - CUTLASS_PRAGMA_UNROLL - for ( int i = 1; i < kPartitionsK; ++i) { - shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); - shared_load_iterator_.load(aligned_accum_fragment[i]); - aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); - } - - shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); - } - - // - // Apply output operation - // - - typename OutputTileIterator::Fragment frag_Z; - typename TensorTileIterator::Fragment frag_T; - - apply_output_operator_source_not_needed_( - frag_Z, - frag_T, - output_op, - aligned_accum_fragment[0], - broadcast_fragment); - - // - // Conditionally store fragments - // - - if (OutputOp::kStoreZ) { - destination_iterator.store(frag_Z); - ++destination_iterator; - } - - if (OutputOp::kStoreT) { - tensor_iterator.store(frag_T); - ++tensor_iterator; - } - } - - if (Base::kFragmentsPerIteration > 1) { - shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); - } - } - } - - - template - struct acc2smem_source_needed; - - template - struct acc2smem_source_needed> { - template - CUTLASS_DEVICE - static void helper(AccumulatorFragmentIterator accum_fragment_iterator, - WarpTileIterator &warp_tile_iterator) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Advance; i++) { - ++accum_fragment_iterator; - } - - typename AccumulatorFragmentIterator::Fragment accum_fragment; - accum_fragment_iterator.load(accum_fragment); - warp_tile_iterator.store(accum_fragment); - } - - CUTLASS_DEVICE - static void push(size_t pos, - AccumulatorFragmentIterator const &iterator_begin, - WarpTileIterator &warp_tile_iterator) { - int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; - } - }; - - - /// Streams the result to global memory - CUTLASS_DEVICE - void compute_source_needed_( - OutputOp const &output_op, ///< Output operator - BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns - OutputTileIterator destination_iterator, ///< Tile iterator for destination - AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile - OutputTileIterator source_iterator1, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) - OutputTileIterator source_iterator2, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) - TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand - ) { - - typename OutputTileIterator::Fragment source_fragment1; - source_fragment1.clear(); - typename OutputTileIterator::Fragment source_fragment2; - source_fragment2.clear(); - - // - 
// Iterator over warp-level accumulator fragment - // - - AccumulatorFragmentIterator accum_fragment_iterator(accumulators); - - // - // Iterate over accumulator tile - // - - #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) - for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { - - // - // Load the source - // - - source_iterator1.load(source_fragment1); - ++source_iterator1; - - if (source_iterator2.enabled()) { - source_iterator2.load(source_fragment2); - ++source_iterator2; - } - - // - // Convert and store fragment - // - - __syncthreads(); - - acc2smem_source_needed>::push( - iter, accum_fragment_iterator, this->warp_tile_iterator_); - - __syncthreads(); - - // - // Load fragments from shared memory - // - - typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; - - shared_load_iterator_.load(aligned_accum_fragment[0]); - - // If the number of k-slices is > 1 - perform a reduction amongst the k-slices - if (kPartitionsK > 1) - { - plus add_fragments; - const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK; - - CUTLASS_PRAGMA_UNROLL - for ( int i = 1; i < kPartitionsK; ++i) { - shared_load_iterator_.add_tile_offset({tile_row_offset , 0}); - shared_load_iterator_.load(aligned_accum_fragment[i]); - aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); - } - - shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0}); - } - - // - // Apply output operation - // - - typename OutputTileIterator::Fragment frag_Z; - typename TensorTileIterator::Fragment frag_T; - - apply_output_operator_( - frag_Z, - frag_T, - output_op, - aligned_accum_fragment[0], - source_fragment1, - source_fragment2, - broadcast_fragment, - source_iterator2.enabled()); - // - // Conditionally store fragments - // - - if (OutputOp::kStoreZ) { - destination_iterator.store(frag_Z); - ++destination_iterator; - } - - if (OutputOp::kStoreT) { - tensor_iterator.store(frag_T); - ++tensor_iterator; - } - } - } - - /// Helper to invoke the output functor over each vector of output - CUTLASS_DEVICE - void apply_output_operator_( - typename OutputTileIterator::Fragment &frag_Z, - typename TensorTileIterator::Fragment &frag_T, - OutputOp const &output_op, - typename SharedLoadIterator::Fragment const &frag_AB, - typename OutputTileIterator::Fragment const &frag_C1, - typename OutputTileIterator::Fragment const &frag_C2, - BroadcastFragment const &frag_Broadcast, - bool frag_C2_enabled) { - - using AccessTypeZ = Array; - using AccessTypeT = Array; - using AccessTypeBroadcast = Array; - - AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); - AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); - - AccumulatorAccessType const *frag_AB_ptr = - reinterpret_cast(&frag_AB); - - OutputAccessType const *frag_C1_ptr = - reinterpret_cast(&frag_C1); - OutputAccessType const *frag_C2_ptr = - reinterpret_cast(&frag_C2); - - AccessTypeBroadcast const *frag_Broadcast_ptr = - reinterpret_cast(&frag_Broadcast); - - int const kOutputOpIterations = - OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kOutputOpIterations; ++i) { - if (frag_C2_enabled) { - output_op( - frag_Z_ptr[i], - frag_T_ptr[i], - frag_AB_ptr[i], - frag_C1_ptr[i], - frag_C2_ptr[i], - frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); - } else { - output_op( - frag_Z_ptr[i], - frag_T_ptr[i], - frag_AB_ptr[i], - frag_C1_ptr[i], - 
frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); - } - } - } - - /// Helper to invoke the output functor over each vector of output - CUTLASS_DEVICE - void apply_output_operator_source_not_needed_( - typename OutputTileIterator::Fragment &frag_Z, - typename TensorTileIterator::Fragment &frag_T, - OutputOp const &output_op, - typename SharedLoadIterator::Fragment const &frag_AB, - BroadcastFragment const &frag_Broadcast) { - - using AccessTypeZ = Array; - using AccessTypeT = Array; - using AccessTypeBroadcast = Array; - - AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); - AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); - - AccumulatorAccessType const *frag_AB_ptr = - reinterpret_cast(&frag_AB); - - AccessTypeBroadcast const *frag_Broadcast_ptr = - reinterpret_cast(&frag_Broadcast); - - int const kOutputOpIterations = - OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kOutputOpIterations; ++i) { - - output_op( - frag_Z_ptr[i], - frag_T_ptr[i], - frag_AB_ptr[i], - frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); - } - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h b/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h index 685777c6..99a534f0 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h @@ -124,6 +124,8 @@ public: using Layout = layout::RowMajor; using LongIndex = typename Layout::LongIndex; + static bool const kIsSingleSource = true; + /// The complete warp-level accumulator tile using AccumulatorTile = typename Base::AccumulatorTile; @@ -294,7 +296,7 @@ public: CUTLASS_DEVICE void operator()( OutputOp const &output_op, ///< Output operator - ElementVector * reduction_output_ptr, ///< Reduction output vector + ElementVector * reduction_output_ptr, ///< Reduction output vector OutputTileIterator destination_iterator, ///< Tile iterator for destination AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix diff --git a/include/cutlass/epilogue/threadblock/interleaved_epilogue.h b/include/cutlass/epilogue/threadblock/interleaved_epilogue.h index 685b512e..0bf9cddf 100644 --- a/include/cutlass/epilogue/threadblock/interleaved_epilogue.h +++ b/include/cutlass/epilogue/threadblock/interleaved_epilogue.h @@ -51,7 +51,7 @@ #include "cutlass/transform/pitch_linear_thread_map.h" #include "cutlass/transform/threadblock/regular_tile_iterator.h" -#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/epilogue_base_streamk.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" //////////////////////////////////////////////////////////////////////////////// @@ -78,8 +78,21 @@ template < typename OutputOp_, /// Number of interleaved k int InterleavedK> -class InterleavedEpilogue { - public: +class InterleavedEpilogue : + public EpilogueBaseStreamK< + Shape_, + PartitionsK, + WarpMmaOperator_, + AccumulatorFragmentIterator_> +{ +public: + + using BaseStreamK = EpilogueBaseStreamK< + Shape_, + PartitionsK, + WarpMmaOperator_, + AccumulatorFragmentIterator_>; 
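The stream-K path introduced here aggregates partial accumulators that peer threadblocks have spilled to a global workspace before the usual output-operator/store step runs; that is what the reduce() member defined just below does. As orientation only, the following is a minimal, hypothetical caller sketch. Every name in it (consume_peer_partials, workspace, the peer index range) is an assumption for illustration and is not part of this patch or of the CUTLASS stream-K kernels.

// Hedged sketch: how a stream-K consumer block might drive the reduce()
// member added below. Illustration only; not code from this patch.
#include "cutlass/cutlass.h"   // for CUTLASS_DEVICE

template <
  typename Epilogue,
  typename OutputOp,
  typename OutputTileIterator,
  typename ElementAccumulator>
CUTLASS_DEVICE
void consume_peer_partials(
    Epilogue &epilogue,
    OutputOp const &output_op,
    OutputTileIterator destination_iterator,
    OutputTileIterator source_iterator,
    ElementAccumulator *workspace,   // assumed global scratch holding peer partial accumulators
    int peer_idx_begin,              // assumed: first peer block covering this output tile
    int peer_idx_end,                // assumed: one past the last peer block
    int reduce_fragment_idx)         // assumed: which output fragment this block finalizes
{
  // Aggregate the peers' partials, apply the output operator, and store the
  // finished fragment; the argument order mirrors reduce() below.
  epilogue.reduce(
      peer_idx_begin,
      peer_idx_end,
      reduce_fragment_idx,
      workspace,
      output_op,
      destination_iterator,
      source_iterator);
}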
+ using Shape = Shape_; using WarpMmaOperator = WarpMmaOperator_; static int const kPartitionsK = PartitionsK; @@ -90,6 +103,9 @@ class InterleavedEpilogue { /// The complete warp-level accumulator tile using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; + /// Fragment type used by the accumulator tile's fragment iterator + using AccumulatorFragment = typename AccumulatorFragmentIterator::Fragment; + /// Accumulator element using ElementAccumulator = typename AccumulatorTile::Element; @@ -122,7 +138,8 @@ class InterleavedEpilogue { gemm::GemmShape; - public: +public: + static_assert(OutputTileIterator::kElementsPerAccess, "This must not be zero."); @@ -134,15 +151,58 @@ class InterleavedEpilogue { struct SharedStorage {}; - public: +public: + /// Constructor CUTLASS_DEVICE InterleavedEpilogue( SharedStorage &shared_storage, ///< Shared storage object int thread_idx, ///< ID of a thread within the threadblock int warp_idx, ///< ID of warp within threadblock - int lane_idx ///< Id of thread within warp - ) {} + int lane_idx) ///< Id of thread within warp + : + BaseStreamK(thread_idx) + {} + + + /// Aggregates the accumulator sets shared by peer blocks in the global workspace, + /// performing epilogue computations, writing to output + CUTLASS_DEVICE + void reduce( + int peer_idx_begin, + int peer_idx_end, + int reduce_fragment_idx, + ElementAccumulator *element_workspace, + OutputOp const &output_op, ///< Output operator + OutputTileIterator destination_iterator, ///< Tile iterator for destination + OutputTileIterator source_iterator) ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + { + // Redcuce peer accumulator fragments into one fragment + AccumulatorFragment accum_fragment; + BaseStreamK::reduce(accum_fragment, peer_idx_begin, peer_idx_end, reduce_fragment_idx, element_workspace); + + // Source-fragment data (zero-initialized for scenarios where the + // output operator allows us to skip loading it from global input) + typename OutputTileIterator::Fragment source_fragment; + source_fragment.clear(); + + if (output_op.is_source_needed()) + { + source_iterator += reduce_fragment_idx; + source_iterator.load(source_fragment); + } + + // Compute the output result + typename OutputTileIterator::Fragment output_fragment; + + // Apply the output operator + apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment); + + // Store the final result + destination_iterator += reduce_fragment_idx; + destination_iterator.store(output_fragment); + } + /// Streams the result to global memory CUTLASS_DEVICE @@ -194,7 +254,7 @@ class InterleavedEpilogue { // typename OutputTileIterator::Fragment output_fragment; - apply_output_operator_source_not_needed_(output_op, output_fragment, accum_fragment); + apply_output_operator_source_not_needed(output_fragment, output_op, accum_fragment); // // Store the final result @@ -257,7 +317,7 @@ class InterleavedEpilogue { // typename OutputTileIterator::Fragment output_fragment; - apply_output_operator_source_needed_(output_op, output_fragment, accum_fragment, source_fragment); + apply_output_operator(output_fragment, output_op, accum_fragment, source_fragment); // // Store the final result @@ -269,15 +329,16 @@ class InterleavedEpilogue { } } - private: +protected: + /// Helper to invoke the output functor over each vector of output CUTLASS_DEVICE - void apply_output_operator_source_needed_( - OutputOp const &output_op, ///< Output operator - typename OutputTileIterator::Fragment 
&output_fragment, - typename AccumulatorFragmentIterator::Fragment const - &aligned_accum_fragment, - typename OutputTileIterator::Fragment const &source_fragment) { + void apply_output_operator( + typename OutputTileIterator::Fragment &output_fragment, + OutputOp const &output_op, + typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment, + typename OutputTileIterator::Fragment const &source_fragment) + { OutputAccessType *output_frag_ptr = reinterpret_cast(&output_fragment); @@ -300,11 +361,11 @@ class InterleavedEpilogue { /// Helper to invoke the output functor over each vector of output CUTLASS_DEVICE - void apply_output_operator_source_not_needed_( - OutputOp const &output_op, ///< Output operator - typename OutputTileIterator::Fragment &output_fragment, - typename AccumulatorFragmentIterator::Fragment const - &aligned_accum_fragment) { + void apply_output_operator_source_not_needed( + typename OutputTileIterator::Fragment &output_fragment, + OutputOp const &output_op, + typename AccumulatorFragmentIterator::Fragment const &aligned_accum_fragment) + { OutputAccessType *output_frag_ptr = reinterpret_cast(&output_fragment); diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h index f7e5c2fd..4bd6ad50 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h @@ -680,6 +680,9 @@ public: state_[2] = 0; byte_pointer_ += params_.advance_tile; store_byte_pointer_ += params_.advance_tile; + + thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow + * ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile; } } } @@ -687,6 +690,60 @@ public: return *this; } + /// Advances a number of positions to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIterator &operator+=(int increment) + { + // Row + state_[0] += increment; + int increment_row = state_[0] / ThreadMap::Count::kRow; + state_[0] = state_[0] % ThreadMap::Count::kRow; + + byte_pointer_ += (params_.advance_row * increment); + store_byte_pointer_ += (params_.advance_row * increment); + thread_start_row_ += (ThreadMap::Shape::kRow * increment); + + // Group + state_[1] += increment_row; + int increment_group = state_[1] / ThreadMap::Count::kGroup; + state_[1] = state_[1] % ThreadMap::Count::kGroup; + + byte_pointer_ += (params_.advance_group * increment_row); + store_byte_pointer_ += (params_.advance_group * increment_row); + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * + ThreadMap::Count::kRow * + increment_row; + + + // Cluster + state_[2] += increment_group; + int increment_cluster = state_[2] / ThreadMap::Count::kCluster; + state_[2] = state_[2] % ThreadMap::Count::kCluster; + + byte_pointer_ += (params_.advance_cluster * increment_group); + store_byte_pointer_ += (params_.advance_cluster * increment_group); + thread_start_row_ += + ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * + ThreadMap::Shape::kRow * + increment_group; + + // Tile + byte_pointer_ += (params_.advance_tile * increment_cluster); + store_byte_pointer_ += (params_.advance_tile * increment_cluster); + thread_start_row_ += + ThreadMap::Shape::kGroup * + ThreadMap::Shape::kRow * + ThreadMap::Shape::kCluster * + ThreadMap::Shape::kTile * + increment_cluster; + + return *this; + } + ///< Efficiently disables all accesses guarded by mask CUTLASS_DEVICE void clear_mask() { mask_.clear(); @@ 
-944,6 +1001,23 @@ public: return *this; } + /// Advances a number of positions to load or store + CUTLASS_HOST_DEVICE + InterleavedPredicatedTileIterator &operator+=(int increment) + { + // Contiguous + iteration_contiguous_ += increment; + int increment_strided = iteration_contiguous_ / ThreadMap::Iterations::kContiguous; + iteration_contiguous_ = iteration_contiguous_ % ThreadMap::Iterations::kContiguous; + byte_pointer_ += (params_.advance_row * increment); + + // Strided + iteration_strided_ += increment_strided; + byte_pointer_ += (params_.advance_column * increment_strided); + + return *this; + } + ///< Efficiently disables all accesses guarded by mask CUTLASS_DEVICE void clear_mask() { mask_.clear(); diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h new file mode 100644 index 00000000..6050a5bc --- /dev/null +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h @@ -0,0 +1,445 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
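The operator+=(int increment) overloads added above for PredicatedTileIterator and InterleavedPredicatedTileIterator advance the iterator several positions in one call. The bookkeeping is a mixed-radix carry: the row counter absorbs the increment, whole-row overflow carries into the group counter, group overflow into the cluster counter, and cluster overflow advances the tile. The standalone sketch below shows only that carry pattern with made-up Count values; the byte-pointer and thread_start_row_ updates of the real iterators are deliberately omitted, so treat it as an illustration rather than a drop-in equivalent.

// Hedged sketch: the carry pattern behind the new operator+=(int) overloads,
// reduced to plain index arithmetic. CountRow/CountGroup/CountCluster stand
// in for ThreadMap::Count::kRow/kGroup/kCluster; values below are made up.
#include <cassert>

struct TileWalkState {
  int row = 0, group = 0, cluster = 0, tile = 0;

  // Advance by 'increment' row-iterations, carrying overflow
  // row -> group -> cluster -> tile, exactly as repeated ++ would.
  void advance(int increment, int CountRow, int CountGroup, int CountCluster) {
    row += increment;
    int carry_group = row / CountRow;        // whole groups crossed
    row %= CountRow;

    group += carry_group;
    int carry_cluster = group / CountGroup;  // whole clusters crossed
    group %= CountGroup;

    cluster += carry_cluster;
    int carry_tile = cluster / CountCluster; // whole tiles crossed
    cluster %= CountCluster;

    tile += carry_tile;
  }
};

int main() {
  // Advancing by 7 with Count = {4, 2, 3} matches seven single steps.
  TileWalkState batched, stepped;
  batched.advance(7, 4, 2, 3);
  for (int i = 0; i < 7; ++i) {
    stepped.advance(1, 4, 2, 3);
  }
  assert(batched.row == stepped.row && batched.group == stepped.group &&
         batched.cluster == stepped.cluster && batched.tile == stepped.tile);
  return 0;
}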
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/permute.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" +#include "cutlass/conv/conv2d_problem_size.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: PitchLinearThreadMap) + typename Element_, ///< Element data type + typename ThreadOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1>, + typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> +> +class PredicatedTileIteratorDirectConv { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + using ThreadOutputShape = ThreadOutputShape_; + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + + using ConvProblemSize = typename cutlass::conv::Conv2dProblemSize; + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + static int const kLoadsPerAccess = AccessType::kElements / AccessType::kElements; + + using ThreadTileCount = MatrixShape< + ThreadBlockOutputShape::kH / ThreadOutputShape::kH, + ThreadBlockOutputShape::kW / ThreadOutputShape::kW + >; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorDirect2dConvParams { + using Base = PredicatedTileIteratorDirect2dConvParams; + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout, cutlass::conv::Conv2dProblemSize const &problem_size): + PredicatedTileIteratorDirect2dConvParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + problem_size, + {ThreadBlockOutputShape::kH, ThreadBlockOutputShape::kW} + ) + { } + + CUTLASS_HOST_DEVICE + Params(Base const &base) : + Base(base) { } + }; + + /// Mask object + struct Mask { + + static int const kCount = ThreadMap::Iterations::kContiguous; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + 
CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + +private: + + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + PredicatedTileIteratorDirect2dConvParams params_; + + /// Byte-level pointer + uint8_t *byte_pointer_; + + /// + Element *pointer_; + + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Initial thread ouput location + int thread_start_n_, thread_start_p_, thread_start_q_; + + /// Current threadblock tile index + int tile_index_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorDirect2dConvParams::stride) == 8, "Expected 64b strides"); + +private: + + // + // Methods + // + + + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorDirectConv( + PredicatedTileIteratorDirect2dConvParams const & params, + Element *pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord() + ): + params_(params), pointer_(pointer) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + // stride dim (PQ) + thread_start_row_ = thread_offset.column(); + // contiguous dim (Channels) + thread_start_column_ = threadblock_offset.column() + thread_offset.row(); + + tile_index_ = threadblock_offset.row(); + + set_tile_index(0); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void set_tile_index(const int index) { + + int residual; + params_.pq_divmod(thread_start_n_, residual, tile_index_ + index); + params_.q_divmod(thread_start_p_, thread_start_q_, residual); + + // Compute the base output coord of ThreadBlock + thread_start_p_ *= ThreadBlockOutputShape::kH; + thread_start_q_ *= ThreadBlockOutputShape::kW; + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + mask_.predicates[c] = ((thread_start_column_ + + c * ThreadMap::Delta::kContiguous) < extent_column_); + } + + // Null pointer performs no accesses + if (!pointer_) { + mask_.clear(); + } + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c; + + int current_row = thread_start_row_ + s * ThreadMap::Delta::kStrided; + int p = current_row / ThreadBlockOutputShape::kW; + int q = current_row % ThreadBlockOutputShape::kW; + + int current_p = thread_start_p_ + p; + int current_q = thread_start_q_ + q; + + bool row_guard = 
(current_p) < params_.P && (current_q) < params_.Q && + (thread_start_n_ < params_.N) && current_row < ThreadMap::Shape::kStrided; + + int output_row_offset = + thread_start_n_ * params_.stride_n + current_p * params_.stride_p + current_q; + + uint8_t *byte_pointer = + reinterpret_cast(pointer_) + + LongIndex(output_row_offset) * LongIndex(params_.stride) + + LongIndex(thread_start_column_ + c * ThreadMap::Delta::kContiguous) * + sizeof(AccessType) / kElementsPerAccess; + + AccessType *frag_ptr = reinterpret_cast(&frag); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + bool guard = row_guard && mask_.predicates[c]; + + cutlass::arch::global_load( + frag_ptr[frag_base_idx], (void *)&memory_pointer[0], guard); + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) const { + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c; + + int current_row = thread_start_row_ + s * ThreadMap::Delta::kStrided; + int p = current_row / ThreadBlockOutputShape::kW; + int q = current_row % ThreadBlockOutputShape::kW; + + int current_p = thread_start_p_ + p; + int current_q = thread_start_q_ + q; + + bool row_guard = (current_p) < params_.P && (current_q) < params_.Q && + (thread_start_n_ < params_.N) && current_row < ThreadMap::Shape::kStrided; + + int output_row_offset = + thread_start_n_ * params_.stride_n + current_p * params_.stride_p + current_q; + + uint8_t *byte_pointer = + reinterpret_cast(pointer_) + + LongIndex(output_row_offset) * LongIndex(params_.stride) + + LongIndex(thread_start_column_ + c * ThreadMap::Delta::kContiguous) * + sizeof(AccessType) / kElementsPerAccess; + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + bool guard = row_guard && mask_.predicates[c]; + + cutlass::arch::global_store( + frag_ptr[frag_base_idx], (void *)&memory_pointer[0], guard); + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) const { + + store_with_byte_offset(frag, 0); + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { + return thread_start_row_; + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { + return thread_start_column_; + } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { + return extent_row_; + } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { + return extent_column_; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorDirectConv &operator++() { + // do nothing + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE 
void get_mask(Mask &mask) const { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const &mask) { + mask_ = mask; + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h index e3eaf55e..73db5432 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h @@ -35,9 +35,12 @@ #pragma once #include "cutlass/cutlass.h" + #include "cutlass/layout/pitch_linear.h" #include "cutlass/layout/matrix.h" +#include "cutlass/conv/conv2d_problem_size.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -245,6 +248,87 @@ struct PredicatedTileIteratorParams { }; + +/////////////////////////////////////////////////////////////////////////////// + +// +// Parameters struct for PredicatedTileIteratorDirect2dConv +// + +struct PredicatedTileIteratorDirect2dConvParams{ + using Index = int32_t; + using LongIndex = int64_t; + + // + // Data members + // + FastDivmod pq_divmod; + FastDivmod q_divmod; + + LongIndex stride; + LongIndex stride_n; + LongIndex stride_p; + + int N; + int P; + int Q; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Status initialize(LongIndex stride_, + cutlass::conv::Conv2dProblemSize const &problem_size, + MatrixCoord threadblock_output_shape) { + stride = stride_; // The stride per row of output tensor (bytes) + stride_n = problem_size.P * problem_size.Q; + stride_p = problem_size.Q ; + + N = problem_size.N; + P = problem_size.P; + Q = problem_size.Q; + + // Fastdivmod for output O, P, Q + if(threadblock_output_shape.row() != 0 && threadblock_output_shape.column() !=0 ){ + int tiles_p = + (problem_size.P + (threadblock_output_shape.row() - 1)) / (threadblock_output_shape.row()); + int tiles_q = (problem_size.Q + (threadblock_output_shape.column() - 1)) / + (threadblock_output_shape.column()); + + pq_divmod = FastDivmod(tiles_p * tiles_q); + q_divmod = FastDivmod(tiles_q); + } + + return Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Status initialize( + Index stride_, + cutlass::conv::Conv2dProblemSize const &problem_size = cutlass::conv::Conv2dProblemSize(), + MatrixCoord threadblock_output_shape = MatrixCoord()) { + return initialize(LongIndex(stride_), problem_size, threadblock_output_shape); + } + + CUTLASS_HOST_DEVICE + PredicatedTileIteratorDirect2dConvParams() { initialize(LongIndex(0)); } + + CUTLASS_HOST_DEVICE + PredicatedTileIteratorDirect2dConvParams(Index stride, + cutlass::conv::Conv2dProblemSize const &problem_size, + MatrixCoord threadblock_output_shape) { + initialize(stride, problem_size, threadblock_output_shape); + } + + CUTLASS_HOST_DEVICE + PredicatedTileIteratorDirect2dConvParams(LongIndex stride, + cutlass::conv::Conv2dProblemSize const &problem_size, + MatrixCoord threadblock_output_shape) { + initialize(stride, problem_size, threadblock_output_shape); + } +}; + /////////////////////////////////////////////////////////////////////////////// // InterleavedPredicatedTileIterator /////////////////////////////////////////////////////////////////////////////// diff --git 
a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_v2.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_v2.h deleted file mode 100644 index 0934baf7..00000000 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_v2.h +++ /dev/null @@ -1,1023 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - The epilogue rearranges the result of a matrix product through shared memory to match canonical - tensor layouts in global memory. Epilogues support conversion and reduction operations. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/array.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/transform/pitch_linear_thread_map.h" -#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" -#include "cutlass/arch/arch.h" -#include "cutlass/arch/memory.h" -#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////// - -namespace epilogue { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator used to load and store output tile from global memory in epilogue. 
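For the direct-conv output iterator added earlier in this patch, PredicatedTileIteratorDirect2dConvParams precomputes FastDivmod objects over the per-image tile count so that set_tile_index can turn a linear threadblock tile index into (n, p, q) coordinates without runtime division in the hot path. The sketch below restates that decomposition with ordinary division so the mapping is easy to follow; the numbers in main() are made up for illustration and are not taken from this patch.

// Hedged sketch: the tile-index decomposition performed by
// PredicatedTileIteratorDirect2dConvParams::initialize() and set_tile_index(),
// written with plain '/' and '%' instead of FastDivmod. Illustration only.
#include <cstdio>

struct TileBase { int n, p, q; };

// Map a linear threadblock tile index onto the NPQ output: first the image n,
// then the tile row/column inside that image, then scale by the threadblock
// output shape to obtain the base (p, q) pixel of the tile.
TileBase decompose_tile_index(int tile_index,
                              int P, int Q,          // output height / width
                              int tb_h, int tb_w) {  // ThreadBlockOutputShape::kH / kW
  int tiles_p = (P + tb_h - 1) / tb_h;  // ceil-div, as in initialize()
  int tiles_q = (Q + tb_w - 1) / tb_w;

  int n        = tile_index / (tiles_p * tiles_q);  // role of pq_divmod
  int residual = tile_index % (tiles_p * tiles_q);
  int tile_p   = residual / tiles_q;                // role of q_divmod
  int tile_q   = residual % tiles_q;

  return {n, tile_p * tb_h, tile_q * tb_w};
}

int main() {
  // Example: P = Q = 10 with an 8x8 threadblock output tile gives 2x2 tiles per image.
  TileBase base = decompose_tile_index(/*tile_index=*/5, /*P=*/10, /*Q=*/10, /*tb_h=*/8, /*tb_w=*/8);
  std::printf("n=%d p=%d q=%d\n", base.n, base.p, base.q);  // prints n=1 p=0 q=8
  return 0;
}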
-/// -/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -/// -template < - typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) - typename Element_, ///< Element data type - bool UseCUDAStore = false -> -class PredicatedTileIteratorV2 { -public: - using ThreadMap = ThreadMap_; - using Shape = typename ThreadMap::Shape; - - using Element = Element_; - - using Layout = layout::RowMajor; - using TensorRef = TensorRef; - using ConstTensorRef = typename TensorRef::ConstTensorRef; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using TensorCoord = MatrixCoord; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - static int const kThreads = ThreadMap::kThreads; - static int const kIterations = ThreadMap::Count::kTile; - - static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); - static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); - static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); - static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); - - /// Fragment object - using Fragment = Array< - Element, - ThreadMap::Iterations::kColumn * - ThreadMap::Iterations::kRow * - ThreadMap::Iterations::kGroup * - ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; - - /// Memory access size - using AccessType = AlignedArray; - - // - // Parameters struct - // - - /// Uses a non-template class - struct Params : PredicatedTileIteratorParams { - using Base = PredicatedTileIteratorParams; - - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params(Layout const &layout): - PredicatedTileIteratorParams( - layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, - make_OutputTileThreadMapDesc() - ) - { } - - CUTLASS_HOST_DEVICE - Params(Base const &base) : - Base(base) { } - }; - - /// Mask object - struct Mask { - - static int const kCount = ThreadMap::Iterations::kColumn; - - /// Predicate state - bool predicates[kCount]; - - // - // Mask - // - CUTLASS_HOST_DEVICE - Mask() { - enable(); - } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_HOST_DEVICE void clear() { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = false; - } - } - - ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask - CUTLASS_DEVICE void enable() { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = true; - } - } - }; - -private: - - // - // Data members - // - - /// Parameters structure containing reference and precomputed state. 
- PredicatedTileIteratorParams params_; - - /// Byte-level pointer - uint8_t *byte_pointer_{nullptr}; - - /// Array of boolean values to contain steady-state predicates - Mask mask_; - - /// Extent of the matrix tile in rows - Index extent_row_; - - /// A thread's starting row position (assuming steady-state predicates have been computed) - Index thread_start_row_; - - /// Internal state counter - int state_[3]; - - // - // Static asserts about internal strides - // - - static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); - static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); - static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); - -private: - - // - // Methods - // - -public: - - // - // Methods - // - - /// Constructor - CUTLASS_DEVICE - PredicatedTileIteratorV2( - PredicatedTileIteratorParams const & params, - Element *pointer, - TensorCoord extent, - int thread_idx, - TensorCoord threadblock_offset = TensorCoord() - ): - params_(params) - { - - TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; - - extent_row_ = extent.row(); - thread_start_row_ = thread_offset.row(); - - // Initialize predicates - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { - - mask_.predicates[c] = ((thread_offset.column() - + ThreadMap::Delta::kColumn * c) < extent.column()); - } - - // Null pointer performs no accesses - if (!pointer) { - mask_.clear(); - } else { - // Initialize pointer - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.row()) * LongIndex(params_.stride) + - LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; - } - // Initialize internal state counter - state_[0] = state_[1] = state_[2] = 0; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - byte_pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const { - - uint8_t *byte_pointer = byte_pointer_; - AccessType *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow - + group * ThreadMap::Delta::kGroup - + cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - - bool guard = row_guard && mask_.predicates[column]; - - cutlass::arch::global_load< - AccessType, - sizeof(AccessType) - >( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + - column], - (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / - kElementsPerAccess], - guard); - } - - if (row + 1 < ThreadMap::Iterations::kRow) { - byte_pointer += params_.increment_row; - } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { - byte_pointer += params_.increment_group; - } - } - - if (cluster + 1 < 
ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } - } - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) const { - - load_with_byte_offset(frag, 0); - } - - /// Stores a fragment to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const { - uint8_t *byte_pointer = byte_pointer_; - AccessType const *frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow - + group * ThreadMap::Delta::kGroup - + cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - - bool guard = row_guard && mask_.predicates[column]; - - if (UseCUDAStore) { - if (guard) { - memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] = - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; - } - } else { - cutlass::arch::global_store( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], - guard); - } - } - - if (row + 1 < ThreadMap::Iterations::kRow) { - byte_pointer += params_.increment_row; - } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { - byte_pointer += params_.increment_group; - } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } - } - - - /// Stores a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) const { - - store_with_byte_offset(frag, 0); - } - - /// Need to get the thread start row from the tile iterator - CUTLASS_DEVICE - int32_t thread_start_row() const { - return thread_start_row_; - } - - /// Extent of the matrix in rows - CUTLASS_DEVICE - Index extent_row() const { - return extent_row_; - } - - /// Advances to the next position to load or store - CUTLASS_HOST_DEVICE - PredicatedTileIteratorV2 &operator++() { - - ++state_[0]; - byte_pointer_ += params_.advance_row; - thread_start_row_ += ThreadMap::Shape::kRow; - - if (state_[0] == ThreadMap::Count::kRow) { - - state_[0] = 0; - ++state_[1]; - byte_pointer_ += params_.advance_group; - - thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * - ThreadMap::Shape::kRow * ThreadMap::Count::kRow; - - if (state_[1] == ThreadMap::Count::kGroup) { - - state_[1] = 0; - ++state_[2]; - byte_pointer_ += params_.advance_cluster; - - thread_start_row_ += ThreadMap::Count::kGroup * - ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; - - if (state_[2] == ThreadMap::Count::kCluster) { - state_[2] = 0; - byte_pointer_ += params_.advance_tile; - } - } - } - - return *this; - } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_DEVICE void clear_mask() { - mask_.clear(); - } - - ///< Efficiently enables all accesses guarded by mask - CUTLASS_DEVICE void enable_mask() { - mask_.enable(); - } - - ///< Sets the mask - CUTLASS_DEVICE void 
get_mask(Mask &mask) const { - mask = mask_; - } - - ///< Sets the mask - CUTLASS_DEVICE void set_mask(Mask const &mask) { - mask_ = mask; - } - - CUTLASS_DEVICE bool enabled() { - return (byte_pointer_ != nullptr); - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator used to load output tile from global memory in epilogue. -/// -/// Satisfies: ReadableTileIterator | InterleavedPredicatedTileIterator | ForwardTileIterator -/// -template < - typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) - typename Element_, ///< Element data type - int InterleavedN ///< Number of Interleaved N -> -class InterleavedPredicatedTileIteratorV2 { -public: - using ThreadMap = ThreadMap_; - - using Element = Element_; - - using Layout = layout::ColumnMajorInterleaved; - using TensorRef = TensorRef; - using ConstTensorRef = typename TensorRef::ConstTensorRef; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using TensorCoord = layout::PitchLinearCoord; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - static int const kThreads = ThreadMap::kThreads; - static int const kIterations = ThreadMap::Iterations::kCount; - - /// Fragment object - using Fragment = Array; - - /// Memory access size - using AccessType = AlignedArray; - - /// Uses a non-template class - struct Params : InterleavedPredicatedTileIteratorParams { - using Base = InterleavedPredicatedTileIteratorParams; - - CUTLASS_HOST_DEVICE - Params() { } - - CUTLASS_HOST_DEVICE - Params(Layout const &layout): - Base( - layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, - make_InterleavedPredicatedTileIteratorDesc() - ) { } - - CUTLASS_HOST_DEVICE - Params(Base const &base) : - Base(base) { } - }; - - /// Mask object - struct Mask { - static int const kCount = (ThreadMap::Iterations::kContiguous < 8) - ? 8 - : ThreadMap::Iterations::kContiguous; - - /// Predicate state - bool predicates[kCount]; - - // - // Mask - // - CUTLASS_HOST_DEVICE - Mask() { - enable(); - } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_HOST_DEVICE void clear() { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = false; - } - } - - ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask - CUTLASS_DEVICE void enable() { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = true; - } - } - }; - -private: - - // - // Data members - // - - /// Parameters structure containing reference and precomputed state. 
- Params params_; - - /// Byte-level pointer - uint8_t *byte_pointer_; - - /// Array of boolean values to contain steady-state predicates - Mask mask_; - - /// Extent of the matrix tile in columns - Index extent_col_; - - /// A thread's starting column position (assuming steady-state predicates have - /// been computed) - Index thread_start_col_; - - /// Internal iteration counter - int iteration_contiguous_; - - int iteration_strided_; - -private: - - // - // Methods - // - -public: - - // - // Methods - // - - /// Constructor - CUTLASS_DEVICE - InterleavedPredicatedTileIteratorV2( - Params const & params, - Element *pointer, - TensorCoord extent, - int thread_idx, - TensorCoord threadblock_offset - ): - params_(params) { - TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + - TensorCoord(threadblock_offset.contiguous() * InterleavedN, - threadblock_offset.strided() / InterleavedN); - - extent_col_ = extent.strided() / InterleavedN; - thread_start_col_ = thread_offset.strided(); - - // Initialize predicates - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { - mask_.predicates[c] = - ((thread_offset.contiguous() + ThreadMap::Delta::kContiguous * c) < - (extent.contiguous() * InterleavedN)); - } - - // Initialize pointer - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.strided()) * LongIndex(params_.stride) + - LongIndex(thread_offset.contiguous()) * sizeof(AccessType) / kElementsPerAccess; - - // Initialize internal state counter - iteration_contiguous_ = iteration_strided_ = 0; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - byte_pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - - uint8_t *byte_pointer = byte_pointer_; - AccessType *frag_ptr = reinterpret_cast(&frag); - AccessType *memory_pointer = reinterpret_cast(byte_pointer); - - int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided; - - bool col_guard = ((thread_start_col_ + col_offset) < extent_col_); - - bool guard = col_guard && mask_.predicates[iteration_contiguous_]; - - cutlass::arch::global_load< - AccessType, - sizeof(AccessType) - >( - *frag_ptr, - (void *)memory_pointer, - guard); - } - - /// Stores a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - uint8_t *byte_pointer = byte_pointer_; - AccessType const *frag_ptr = reinterpret_cast(&frag); - AccessType *memory_pointer = reinterpret_cast(byte_pointer); - - int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided; - - bool col_guard = ((thread_start_col_ + col_offset) < extent_col_); - - bool guard = col_guard && mask_.predicates[iteration_contiguous_]; - - cutlass::arch::global_store( - *frag_ptr, (void *)memory_pointer, guard); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int iteration) { - iteration_contiguous_ = iteration % ThreadMap::Iterations::kContiguous; - iteration_strided_ = iteration / ThreadMap::Iterations::kContiguous; - } - - /// Advances to the next position to load or store - CUTLASS_HOST_DEVICE - InterleavedPredicatedTileIteratorV2 &operator++() { - - ++iteration_contiguous_; - byte_pointer_ += params_.advance_row; - - if (iteration_contiguous_ == ThreadMap::Iterations::kContiguous) { - - iteration_contiguous_ = 0; - ++iteration_strided_; - byte_pointer_ += params_.advance_column; - - if (iteration_strided_ 
== ThreadMap::Iterations::kStrided) { - iteration_strided_ = 0; - } - } - - return *this; - } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_DEVICE void clear_mask() { - mask_.clear(); - } - - ///< Efficiently enables all accesses guarded by mask - CUTLASS_DEVICE void enable_mask() { - mask_.enable(); - } - - ///< Sets the mask - CUTLASS_DEVICE void get_mask(Mask &mask) { - mask = mask_; - } - - ///< Sets the mask - CUTLASS_DEVICE void set_mask(Mask const &mask) { - mask_ = mask; - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator used to load output tile from global memory in epilogue. -/// -/// Satisfies: ReadableTileIterator | InterleavedMaskedTileIterator | ForwardTileIterator -/// -template < - typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) - typename Element_, ///< Element data type - int InterleavedN ///< Number of Interleaved N -> -class InterleavedConvPredicatedTileIteratorV2 { -public: - using ThreadMap = ThreadMap_; - - using Element = Element_; - - using Layout = layout::TensorNCxHWx; - using TensorRef = TensorRef; - using ConstTensorRef = typename TensorRef::ConstTensorRef; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using TensorCoord = Tensor4DCoord; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - static int const kThreads = ThreadMap::kThreads; - static int const kIterations = ThreadMap::Iterations::kCount; - - /// Fragment object - using Fragment = Array; - - /// Memory access size - using AccessType = AlignedArray; - - // - // Parameters struct - // - - struct Params { - - // - // Data members - // - - LongIndex stride_col; ///< stride in bytes between columns - LongIndex stride_row; ///< stride in bytes between rows - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Status initialize(typename Layout::Stride stride_) { - stride_col = stride_[1]; - stride_row = stride_[2]; - - return Status::kSuccess; - } - - CUTLASS_HOST_DEVICE - Params() { - initialize(cutlass::make_Coord(0, 0, 0)); - } - - CUTLASS_HOST_DEVICE - Params(Layout const &layout) { - - initialize(layout.stride()); - } - }; - - /// Mask object - struct Mask { - static int const kCount = - (ThreadMap::Iterations::kRow < 8) ? 8 : ThreadMap::Iterations::kRow; - - /// Predicate state - bool predicates[kCount]; - - // - // Mask - // - CUTLASS_HOST_DEVICE - Mask() { - enable(); - } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_HOST_DEVICE void clear() { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = false; - } - } - - ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask - CUTLASS_DEVICE void enable() { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = true; - } - } - }; - -private: - - // - // Data members - // - - /// Parameters structure containing reference and precomputed state. 
- Params params_; - - /// Byte-level pointer - uint8_t *byte_pointer_; - - /// Array of boolean values to contain steady-state predicates - Mask mask_; - - /// Extent of the matrix tile in columns - Index extent_col_; - - /// Extent of the matrix tile in rows - Index extent_row_; - - /// Extent of the matrix tile in pq - Index extent_pq_; - - /// A thread's starting row position (assuming steady-state predicates have - /// been computed) - Index thread_start_row_; - - /// A thread's starting column position (assuming steady-state predicates have - /// been computed) - Index thread_start_col_; - - /// Internal iteration counter - LongIndex iteration_row_; - LongIndex iteration_col_; - - uint32_t pq_mul_; - - uint32_t pq_shr_; - -private: - - // - // Methods - // - -public: - - // - // Methods - // - - /// Constructor - CUTLASS_DEVICE - InterleavedConvPredicatedTileIteratorV2( - Params const & params, - Element *pointer, - TensorCoord extent, - int thread_idx, - MatrixCoord threadblock_offset - ): - params_(params) { - MatrixCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; - - extent_col_ = extent.c(); - extent_pq_ = extent.h() * extent.w(); - extent_row_ = extent.n() * extent_pq_; - - find_divisor(pq_mul_, pq_shr_, extent_pq_); - - thread_start_row_ = thread_offset.row(); - thread_start_col_ = thread_offset.column(); - - // Initialize predicates - CUTLASS_PRAGMA_UNROLL - for (int r = 0; r < ThreadMap::Iterations::kRow; ++r) { - mask_.predicates[r] = - ((thread_offset.row() + ThreadMap::Delta::kRow * r) < extent_row_); - } - - // Initialize pointer - byte_pointer_ = reinterpret_cast(pointer) + - ((thread_start_col_ / InterleavedN) * params_.stride_col + - (thread_start_col_ % InterleavedN)) * - sizeof_bits::value / 8; - - // Initialize internal state counter - iteration_row_ = iteration_col_ = 0; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) { - byte_pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment &frag) { - - int col_offset = iteration_col_ * ThreadMap::Delta::kColumn; - bool col_guard = ((thread_start_col_ + col_offset) < extent_col_); - bool guard = col_guard && mask_.predicates[iteration_row_]; - - int n, pq_rem; - - fast_divmod(n, pq_rem, - thread_start_row_ + iteration_row_ * ThreadMap::Delta::kRow, - extent_pq_, pq_mul_, pq_shr_); - - uint8_t *byte_pointer = - byte_pointer_ + (n * params_.stride_row + pq_rem * InterleavedN) * - sizeof_bits::value / 8; - AccessType *frag_ptr = reinterpret_cast(&frag); - AccessType const *memory_pointer = - reinterpret_cast(byte_pointer); - - cutlass::arch::global_load< - AccessType, - sizeof(AccessType) - >( - *frag_ptr, - (void *)memory_pointer, - guard); - } - - /// Stores a fragment to memory - CUTLASS_DEVICE - void store(Fragment const &frag) { - - int col_offset = iteration_col_ * ThreadMap::Delta::kColumn; - bool col_guard = ((thread_start_col_ + col_offset) < extent_col_); - bool guard = col_guard && mask_.predicates[iteration_row_]; - - int n, pq_rem; - - fast_divmod(n, pq_rem, - thread_start_row_ + iteration_row_ * ThreadMap::Delta::kRow, - extent_pq_, pq_mul_, pq_shr_); - - uint8_t *byte_pointer = - byte_pointer_ + (n * params_.stride_row + pq_rem * InterleavedN) * - sizeof_bits::value / 8; - AccessType const *frag_ptr = reinterpret_cast(&frag); - AccessType *memory_pointer = reinterpret_cast(byte_pointer); - - cutlass::arch::global_store( - 
*frag_ptr, (void *)memory_pointer, guard); - } - - /// Overrides the internal iteration index - CUTLASS_HOST_DEVICE - void set_iteration_index(int iteration) { - iteration_row_ = iteration % ThreadMap::Iterations::kRow; - iteration_col_ = iteration / ThreadMap::Iterations::kRow; - } - - /// Advances to the next position to load or store - CUTLASS_HOST_DEVICE - InterleavedConvPredicatedTileIteratorV2 &operator++() { - - ++iteration_row_; - - if (iteration_row_ == ThreadMap::Iterations::kRow) { - - iteration_row_ = 0; - ++iteration_col_; - byte_pointer_ += params_.stride_col; - - if (iteration_col_ == ThreadMap::Iterations::kColumn) { - iteration_col_ = 0; - } - } - - return *this; - } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_DEVICE void clear_mask() { - mask_.clear(); - } - - ///< Efficiently enables all accesses guarded by mask - CUTLASS_DEVICE void enable_mask() { - mask_.enable(); - } - - ///< Sets the mask - CUTLASS_DEVICE void get_mask(Mask &mask) { - mask = mask_; - } - - ///< Sets the mask - CUTLASS_DEVICE void set_mask(Mask const &mask) { - mask_ = mask; - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/shared_load_iterator.h b/include/cutlass/epilogue/threadblock/shared_load_iterator.h index b01b92fd..608a3b44 100644 --- a/include/cutlass/epilogue/threadblock/shared_load_iterator.h +++ b/include/cutlass/epilogue/threadblock/shared_load_iterator.h @@ -201,6 +201,11 @@ public: } } + /// Loads a fragment from memory + CUTLASS_DEVICE + void set_smem_base_address(Index address) { + } + /// Loads a fragment CUTLASS_DEVICE void load(Fragment &frag) const { diff --git a/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h b/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h index d845d951..173ac103 100644 --- a/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h +++ b/include/cutlass/epilogue/threadblock/shared_load_iterator_mixed.h @@ -234,6 +234,10 @@ public: } } + /// Set base smem address + CUTLASS_DEVICE + void set_smem_base_address(Index address) {} + /// Loads a fragment CUTLASS_DEVICE void load(Fragment &frag) const { @@ -395,6 +399,10 @@ public: } } + /// Set base smem address + CUTLASS_DEVICE + void set_smem_base_address(Index address) {} + /// Loads a fragment CUTLASS_DEVICE void load(Fragment &frag) { @@ -556,6 +564,10 @@ public: } } + /// Set base smem address + CUTLASS_DEVICE + void set_smem_base_address(Index address) {} + /// Loads a fragment CUTLASS_DEVICE void load(Fragment &frag) { diff --git a/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h b/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h new file mode 100644 index 00000000..47d0ec2e --- /dev/null +++ b/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h @@ -0,0 +1,194 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + This assumes the shared memory tile is in a permuted layout which avoids bank conflicts on loading. + + When the fragment is loaded into registers, it matches the row-major thread map assumed by + the predicated tile iterator writing to global memory. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load output tile from shared memory in epilogue. +/// +/// Satisfies: ReadableTileIterator +/// +template ::value / 8> +class SharedLoadIteratorPitchLiner { + public: + using ThreadMap = ThreadMap_; + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kMinAlignment = + ThreadMap_::kElementsPerAccess * sizeof_bits::value / 8; + + static int const kAlignment = (MaxAlignment < kMinAlignment ? 
MaxAlignment : kMinAlignment); + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + /// Vector type used for SMEM loads + using LoadType = + AlignedArray::value, ThreadMap::kElementsPerAccess), + const_min(16, kAlignment)>; + + static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; + + private: + // + // Data members + // + + /// Byte-level pointer + uint8_t *byte_pointer_; + + /// Stride along adjacent rows + int stride_; + + /// Base address offset + Index base_smem_address_; + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorPitchLiner(TensorRef ref, int thread_idx) + : byte_pointer_(reinterpret_cast(ref.data())), + stride_((ref.stride(0) * sizeof_bits::value) / 8), + base_smem_address_(0) { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointer + // thread_offset.row() is contiguous dim + // thread_offset.column() is stride dim + byte_pointer_ += thread_offset.row() * sizeof(AccessType) / kElementsPerAccess+ + thread_offset.column() * stride_ ; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &offset) { + byte_pointer_ += + offset.row() * ThreadMap::StorageShape::kContiguous * sizeof(AccessType) / kElementsPerAccess + + offset.column() * ThreadMap::StorageShape::kStrided * stride_; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + uint8_t const *byte_pointer = + byte_pointer_ + s * ThreadMap::Delta::kStrided * stride_ + + c * ThreadMap::Delta::kContiguous * ThreadMap::kElementsPerAccess * + sizeof_bits::value / 8 + + pointer_offset * sizeof_bits::value / 8 + base_smem_address_; + + int frag_base_idx = s * ThreadMap::Iterations::kContiguous + c; + + LoadType *frag_ptr = reinterpret_cast(&frag); + + LoadType const *memory_pointer = reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) { + frag_ptr[frag_base_idx * kLoadsPerAccess + v] = memory_pointer[v]; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void set_smem_base_address(Index address) { base_smem_address_ = address; } + + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/warp/tile_iterator_simt.h b/include/cutlass/epilogue/warp/tile_iterator_simt.h index c2d80191..ecc75ea9 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_simt.h +++ b/include/cutlass/epilogue/warp/tile_iterator_simt.h @@ -240,12 +240,301 @@ public: void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } }; 
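
For orientation, here is a minimal calling-pattern sketch for the set_smem_base_address() hook added throughout these iterators (and implemented for real by SharedLoadIteratorPitchLiner above). It is not part of the patch: the function name load_stage_sketch and the SmemLoadIterator/Fragment template parameters are illustrative placeholders for any iterator exposing set_smem_base_address() and load().

// Illustrative sketch only (assumes CUTLASS headers for CUTLASS_DEVICE).
#include "cutlass/cutlass.h"

template <typename SmemLoadIterator, typename Fragment>
CUTLASS_DEVICE
void load_stage_sketch(SmemLoadIterator &smem_iterator,
                       Fragment &frag,
                       int stage_index,
                       int stage_size_in_bytes) {
  // Rebase the iterator onto the shared-memory buffer owned by this stage;
  // for SharedLoadIteratorPitchLiner the base address is a byte offset.
  smem_iterator.set_smem_base_address(stage_index * stage_size_in_bytes);

  // Then load the staged accumulator fragment as usual.
  smem_iterator.load(frag);
}
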
///////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Template for reading and writing tiles of accumulators to shared memory +template +class TileIteratorSimtDirectConv { + public: + + using WarpShape = WarpShape_; + using Operator = Operator_; + using Element = Element_; + using Layout = layout::RowMajor; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = SimtPolicy; + + /// Shape of the tile in memory + using Shape = MatrixShape; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array; + + /// This is the complete warp-level accumulator tile. + using AccumulatorTile = Array; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + /// Padding quantity + using Padding = MatrixShape<0, + 0 + >; + +private: + /// Storage type for accessing memory + using AccessType = AlignedArray< + Element, + Policy::kElementsPerAccess + >; + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointer_; + + /// Internal layout object + Layout layout_; + + /// Base smem offset; + Index base_smem_address_; + + public: + /// Default constructor + CUTLASS_HOST_DEVICE + TileIteratorSimtDirectConv() : pointer_(nullptr) {} + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorSimtDirectConv( + TensorRef const &ref, + unsigned lane_id + ): + pointer_(reinterpret_cast(ref.data())), + layout_(ref.stride()[0] / AccessType::kElements) { + + auto lane_layout = Policy::MmaSimtPolicy::get_lane_layout(); + MatrixCoord lane_offset = lane_layout.inverse(lane_id); + + pointer_ += layout_({ + lane_offset.row(), + lane_offset.column() * Policy::kElementsPerAccess / int(AccessType::kElements) + }); + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorSimtDirectConv & add_pointer_offset(Index pointer_offset) { + pointer_ += pointer_offset / AccessType::kElements; + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorSimtDirectConv & add_tile_offset(TensorCoord const &tile_offset) { + + pointer_ += layout_({ + tile_offset.row() * Shape::kRow, + (tile_offset.column() * Shape::kColumn / int(AccessType::kElements)) + }); + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorSimtDirectConv & operator+=(TensorCoord const &tile_offset) { + + add_tile_offset(tile_offset); + + return *this; + } + + /// Store + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + // original vector stores + AccessType const *frag_ptr = reinterpret_cast(&frag); + AccessType * load_pointer_ = reinterpret_cast(reinterpret_cast(pointer_) + base_smem_address_); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { + load_pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)] = frag_ptr[n]; + } + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Load + CUTLASS_HOST_DEVICE + 
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::kAccessesPerIteration; ++n) { + frag_ptr[n] = pointer_[n * Policy::MmaSimtPolicy::WarpShape::kColumn + pointer_offset / int(AccessType::kElements)]; + } + } + + /// Load + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address){ + base_smem_address_ = address; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Template for reading and writing tiles of accumulators to shared memory +template +class TileIteratorSimtDirect2dConv { + public: + using WarpShape = WarpShape_; + using ThreadOutputShape = ThreadOutputShape_; + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + using Operator = Operator_; + using Element = Element_; + using Layout = layout::RowMajor; + using MmaSimtPolicy = MmaSimtPolicy_; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + // Thread-level shape of a fragment + using ThreadShape = MatrixShape; + + static_assert(!(ThreadShape::kColumn % MmaSimtPolicy::LaneMmaShape::kN), + "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); + + using ThreadTileCount = MatrixShape; + + using Iterations = + MatrixShape; + + /// This is the complete warp-level accumulator tile. + using AccumulatorTile = typename Operator::FragmentC; + + /// This is the fragment size produced by one access of the iterator. 
+ using Fragment = AccumulatorTile; + + /// Padding quantity + using Padding = MatrixShape<0, 0>; + + private: + // Storage type for accessing memory + using AccessType = AlignedArray; + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointer_; + + /// Internal layout object + Layout layout_; + + /// Base smem offset; + Index base_smem_address_; + + public: + /// Default constructor + CUTLASS_HOST_DEVICE + TileIteratorSimtDirect2dConv() : pointer_(nullptr) {} + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorSimtDirect2dConv(TensorRef const &ref, unsigned thread_id, unsigned lane_id) + : pointer_(reinterpret_cast(ref.data())), + layout_(ref.stride()[0] / AccessType::kElements) { + + auto lane_layout = MmaSimtPolicy::get_lane_layout(); + + MatrixCoord lane_offset = lane_layout.inverse(lane_id); + + // Get base HW offset of current threads + const int threadgroup = thread_id / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); + const int base_p = (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH; + const int base_q = (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW; + + const int row_offset = base_p * ThreadBlockOutputShape::kW + base_q; + + pointer_ += layout_( + {row_offset, + lane_offset.column() * MmaSimtPolicy::LaneMmaShape::kN / int(AccessType::kElements)}); + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorSimtDirect2dConv &add_pointer_offset(Index pointer_offset) { + pointer_ += pointer_offset / AccessType::kElements; + return *this; + } + + /// Store + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + AccessType *storer_pointer_ = + reinterpret_cast(reinterpret_cast(pointer_) + base_smem_address_); + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < ThreadOutputShape::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < ThreadOutputShape::kW; ++w) { + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < Iterations::kColumn; ++col) { + int offset = (w + h * ThreadBlockOutputShape::kW) * + (ThreadBlockOutputShape::kC / AccessType::kElements) + + col; + storer_pointer_[offset + pointer_offset / int(AccessType::kElements)] = + frag_ptr[w + h * ThreadOutputShape::kW + col]; + } + } + } + } + + /// Store + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { base_smem_address_ = address; } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Template for reading and writing tiles of accumulators to shared memory template < typename WarpShape_, ///< shape of warp-level GEMM (concept: GemmShape) @@ -482,6 +771,10 @@ public: return add_tile_offset({1, 0}); } + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } }; diff --git a/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h b/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h index 24d94c75..2f14c475 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h +++ b/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h @@ -228,6 +228,11 @@ public: TileIteratorTensorOp & operator++() { return add_tile_offset({1, 0}); } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } }; 
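
As a sanity check on the store indexing used by TileIteratorSimtDirect2dConv above, the host-side sketch below enumerates the shared-memory access offsets one thread would produce, using the same (w + h * kW) * (kC / AccessElements) + col formula. It is not from the patch; all tile sizes here are small made-up placeholders standing in for ThreadOutputShape, ThreadBlockOutputShape, AccessType::kElements, and Iterations::kColumn.

// Illustrative host-side sketch; values are placeholders for the real shapes.
#include <cstdio>

int main() {
  int const kH = 2, kW = 2;            // ThreadOutputShape::kH / kW (assumed)
  int const kBlockW = 8, kBlockC = 64; // ThreadBlockOutputShape::kW / kC (assumed)
  int const kAccessElements = 4;       // AccessType::kElements (assumed)
  int const kIterationsColumn = 2;     // Iterations::kColumn (assumed)

  for (int h = 0; h < kH; ++h) {
    for (int w = 0; w < kW; ++w) {
      for (int col = 0; col < kIterationsColumn; ++col) {
        // Same indexing as store_with_pointer_offset() above.
        int offset = (w + h * kBlockW) * (kBlockC / kAccessElements) + col;
        printf("h=%d w=%d col=%d -> smem access offset %d\n", h, w, col, offset);
      }
    }
  }
  return 0;
}
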
///////////////////////////////////////////////////////////////////////////////////////////////// @@ -420,6 +425,11 @@ public: TileIteratorTensorOp & operator++() { return add_tile_offset({0, 1}); } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } }; @@ -645,6 +655,11 @@ public: TileIteratorTensorOpCanonical & operator++() { return add_tile_offset({1, 0}); } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h b/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h index eadb779a..68955341 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h +++ b/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h @@ -304,6 +304,11 @@ public: void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -506,6 +511,11 @@ public: void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -697,6 +707,11 @@ public: void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h b/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h index a5035b61..2c3d5124 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h +++ b/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h @@ -242,6 +242,11 @@ public: void load(Fragment const &frag) { load_with_pointer_offset(frag, 0); } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -419,6 +424,11 @@ public: void load(Fragment const &frag) { load_with_pointer_offset(frag, 0); } + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h b/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h index 630ff73d..4daad9dc 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h +++ b/include/cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h @@ -207,6 +207,12 @@ public: void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); } + + + /// Set smem base address + CUTLASS_HOST_DEVICE + void set_smem_base_address(Index address) { + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index aabfbdc3..7d1ecb39 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -280,6 +280,36 @@ 
struct FastDivmod { unsigned int multiplier; unsigned int shift_right; + /// Find quotient and remainder using device-side intrinsics + CUTLASS_HOST_DEVICE + void fast_divmod(int& quotient, int& remainder, int dividend) const { + +#if defined(__CUDA_ARCH__) + // Use IMUL.HI if divisor != 1, else simply copy the source. + quotient = (divisor != 1) ? __umulhi(dividend, multiplier) >> shift_right : dividend; +#else + quotient = int((divisor != 1) ? int(((int64_t)dividend * multiplier) >> 32) >> shift_right : dividend); +#endif + + // The remainder. + remainder = dividend - (quotient * divisor); + } + + /// For long int input + CUTLASS_HOST_DEVICE + void fast_divmod(int& quotient, int64_t& remainder, int64_t dividend) const { + +#if defined(__CUDA_ARCH__) + // Use IMUL.HI if divisor != 1, else simply copy the source. + quotient = (divisor != 1) ? __umulhi(dividend, multiplier) >> shift_right : dividend; +#else + quotient = int((divisor != 1) ? ((dividend * multiplier) >> 32) >> shift_right : dividend); +#endif + // The remainder. + remainder = dividend - (quotient * divisor); + } + + /// Construct the FastDivmod object, in host code ideally. /// /// This precomputes some values based on the divisor and is computationally expensive. @@ -288,17 +318,35 @@ struct FastDivmod { FastDivmod(): divisor(0), multiplier(0), shift_right(0) { } CUTLASS_HOST_DEVICE - FastDivmod(int divisor_): divisor(divisor_) { - find_divisor(multiplier, shift_right, divisor); + FastDivmod(int divisor): divisor(divisor) { + + if (divisor != 1) { + unsigned int p = 31 + find_log2(divisor); + unsigned m = unsigned(((1ull << p) + unsigned(divisor) - 1) / unsigned(divisor)); + + multiplier = m; + shift_right = p - 32; + } else { + multiplier = 0; + shift_right = 0; + } } /// Computes integer division and modulus using precomputed values. This is computationally /// inexpensive. CUTLASS_HOST_DEVICE void operator()(int "ient, int &remainder, int dividend) const { - fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right); + fast_divmod(quotient, remainder, dividend); } + /// Computes integer division using precomputed values. This is computationally + /// inexpensive. + CUTLASS_HOST_DEVICE + int div(int dividend) const { + int quotient, remainder; + fast_divmod(quotient, remainder, dividend); + return quotient; + } /// Computes integer division and modulus using precomputed values. This is computationally /// inexpensive. @@ -307,7 +355,7 @@ struct FastDivmod { CUTLASS_HOST_DEVICE int divmod(int &remainder, int dividend) const { int quotient; - fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right); + fast_divmod(quotient, remainder, dividend); return quotient; } @@ -315,7 +363,7 @@ struct FastDivmod { /// inexpensive. CUTLASS_HOST_DEVICE void operator()(int "ient, int64_t &remainder, int64_t dividend) const { - fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right); + fast_divmod(quotient, remainder, dividend); } /// Computes integer division and modulus using precomputed values. 
This is computationally @@ -323,9 +371,14 @@ struct FastDivmod { CUTLASS_HOST_DEVICE int divmod(int64_t &remainder, int64_t dividend) const { int quotient; - fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right); + fast_divmod(quotient, remainder, dividend); return quotient; } + + /// Returns the divisor when cast to integer + CUTLASS_HOST_DEVICE + operator int() const { return divisor; } + }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/float8.h b/include/cutlass/float8.h new file mode 100644 index 00000000..ac89c85c --- /dev/null +++ b/include/cutlass/float8.h @@ -0,0 +1,1213 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Defines a class for using IEEE half-precision floating-point types in host or + device code. +*/ +#pragma once + +#include + +#include "cutlass/cutlass.h" + +#if defined(__CUDACC_RTC__) + +#include "cutlass/floating_point_nvrtc.h" + +#else +// +// Standard Library headers belong here to avoid conflicts with NVRTC. 
+// +#include +#include +#include +#include +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +#if (__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)) + +#ifndef CUDA_PTX_FP8_CVT_ENABLED +#define CUDA_PTX_FP8_CVT_ENABLED 1 +#endif + +#endif +#endif + + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// FP8 Has 2 encodings possible : E4M3 and E5M2 +// +// E4M3 : 7 | 6 5 4 3 | 2 1 0 +// E5M2 : 7 | 6 5 4 3 2 | 1 0 +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class FloatEncoding { + E4M3, + E5M2 +}; + +template +struct alignas(1) float8_base { + + static constexpr bool IS_E4M3 = (T == FloatEncoding::E4M3); + static constexpr bool IS_E5M2 = (T == FloatEncoding::E5M2); + + // Number of Bits representing mantissa and exponents + static constexpr int FP32_NUM_BITS = 32; + static constexpr int FP32_NUM_EXPONENT_BITS = 8; + static constexpr int FP32_NUM_MANTISSA_BITS = 23; + static constexpr uint32_t FP32_NAN = 0x7fffffff; + static constexpr uint32_t FP32_INFINITY_MASK = 0x7f800000; + static constexpr int FP32_MAX_EXPONENT = 127; + static constexpr int FP32_MIN_EXPONENT = -126; + static constexpr int FP32_EXPONENT_BIAS = 127; + + static constexpr int FP16_NUM_BITS = 16; + static constexpr int FP16_NUM_EXPONENT_BITS = 5; + static constexpr int FP16_NUM_MANTISSA_BITS = 10; + static constexpr uint16_t FP16_NAN = 0x7fff; + static constexpr uint16_t FP16_INFINITY_MASK = 0x7c00; + static constexpr int FP16_MAX_EXPONENT = 15; + static constexpr int FP16_MIN_EXPONENT = -14; + static constexpr int FP16_EXPONENT_BIAS = 15; + + static constexpr int FP8_NUM_BITS = 8; + static constexpr int FP8_NUM_EXPONENT_BITS = IS_E4M3 ? 4 : 5; + static constexpr int FP8_NUM_MANTISSA_BITS = IS_E4M3 ? 3 : 2; + static constexpr uint8_t FP8_NAN = 0x7f; // Also F8_INF + static constexpr uint8_t FP8_INFINITY_MASK = IS_E4M3 ? 0x78 : 0x7c; + static constexpr int FP8_MAX_EXPONENT = IS_E4M3 ? 7 : 15; + static constexpr int FP8_MIN_EXPONENT = IS_E4M3 ? -6 : -14; + static constexpr int FP8_EXPONENT_BIAS = IS_E4M3 ? 7 : 15; + + static constexpr uint8_t FP8_EXPONENT_MASK = (1 << FP8_NUM_EXPONENT_BITS) - 1; + static constexpr uint8_t FP8_MANTISSA_MASK = (1 << FP8_NUM_MANTISSA_BITS) - 1; + + static constexpr uint8_t FP8_MAX_FLT = (IS_E4M3 ? 0x7e : 0x7b); + + // 256 in float + static constexpr uint32_t FP8_SAT_VAL_FP32 = 0x43800000; + + // + // Data members + // + + /// Data container + uint8_t storage; + + /// Ctors. 
+ CUTLASS_HOST_DEVICE + float8_base() : storage(0) { } + + /// Is finite implementation + CUTLASS_HOST_DEVICE + static bool isfinite(float flt) { + uint32_t s; + + #if defined(__CUDA_ARCH__) + s = reinterpret_cast(flt); + #else + std::memcpy(&s, &flt, sizeof(s)); + #endif + + return (s & 0x7f800000) < 0x7f800000; + } + + /// Is NaN implementation + CUTLASS_HOST_DEVICE + static bool isnan(float flt) { + uint32_t s; + + #if defined(__CUDA_ARCH__) + s = reinterpret_cast(flt); + #else + std::memcpy(&s, &flt, sizeof(s)); + #endif + + return (s & 0x7fffffff) > 0x7f800000; + } + + /// Is infinite implementation + CUTLASS_HOST_DEVICE + static bool isinf(float flt) { + uint32_t s; + + #if defined(__CUDA_ARCH__) + s = reinterpret_cast(flt); + #else + std::memcpy(&s, &flt, sizeof(s)); + #endif + + // Sign = 0 for +inf, 1 for -inf + // Exponent = all ones + // Mantissa = all zeros + return (s == 0x7f800000) || (s == 0xff800000); + } + + /// FP32 -> FP8 conversion - rounds to nearest even + CUTLASS_HOST_DEVICE + static uint8_t convert_float_to_fp8(float const& flt) { + + // software implementation rounds toward nearest even + uint32_t s; + + #if defined(__CUDA_ARCH__) + s = reinterpret_cast(flt); + #else + std::memcpy(&s, &flt, sizeof(s)); + #endif + + // Extract the bits in the FP32 type + uint8_t sign = uint8_t((s >> 24 & 0x80)); + int8_t exp = uint8_t(((s >> FP32_NUM_MANTISSA_BITS) & 0xff) - FP32_EXPONENT_BIAS); + int mantissa = s & 0x7fffff; + uint8_t u = 0; + + uint8_t const kF8_NaN = 0x7f; + + // NaN => NaN + if (isnan(flt)) { + return kF8_NaN; + } + + // Inf => MAX_FLT (satfinite) + if (isinf(flt)) { + return sign | FP8_MAX_FLT; + } + + // Special handling + if ( exp == -128 ) { + // int8 range is from -128 to 127 + // So 255(inf) - 127(bias) = 128 - will show up as -128 + + // satfinite + return (sign | FP8_MAX_FLT); + } + + int sticky_bit = 0; + + bool skip_sign = false; + bool may_be_nan = false; + + if ( (exp >= FP8_MIN_EXPONENT) && (exp <= FP8_MAX_EXPONENT) ) { + // normal fp32 to normal fp8 + exp = uint8_t(exp + uint8_t(FP8_EXPONENT_BIAS)); + u = uint8_t(((exp & FP8_EXPONENT_MASK) << FP8_NUM_MANTISSA_BITS)); + u = uint8_t(u | (mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS))); + } else if(exp < FP8_MIN_EXPONENT) { + // normal single-precision to subnormal float8-precision representation + int rshift = (FP8_MIN_EXPONENT - exp); + if (rshift < FP32_NUM_BITS) { + mantissa |= (1 << FP32_NUM_MANTISSA_BITS); + + sticky_bit = ((mantissa & ((1 << rshift) - 1)) != 0); + + mantissa = (mantissa >> rshift); + u = (uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS- FP8_NUM_MANTISSA_BITS)) & FP8_MANTISSA_MASK); + } else { + mantissa = 0; + u = 0; + } + // Exponent > FP8_MAX_EXPONENT - this is a special case done to match HW + // 0x4380_0000 to 0x43e0_0000 - maps from 256 to 448, and does not saturate / inf. 
+ } else { + if( exp == (FP8_MAX_EXPONENT + 1) ) { + uint8_t mantissa_tmp = uint8_t(mantissa >> (FP32_NUM_MANTISSA_BITS - FP8_NUM_MANTISSA_BITS)); + if( mantissa_tmp < FP8_MANTISSA_MASK) { + exp = uint8_t(exp + uint8_t(FP8_EXPONENT_BIAS)); + u = uint8_t(exp << FP8_NUM_MANTISSA_BITS) | mantissa_tmp; + may_be_nan = (mantissa_tmp == (FP8_MANTISSA_MASK-1)); + } else { + // satfinite + return (sign | FP8_MAX_FLT); + } + } else{ + // satfinite + return (sign | FP8_MAX_FLT); + } + } + + // round to nearest even + int NUM_BITS_SHIFT = FP32_NUM_MANTISSA_BITS - (FP8_NUM_MANTISSA_BITS + 1); + int round_bit = ((mantissa >> NUM_BITS_SHIFT) & 1); + sticky_bit |= ((mantissa & ((1 << NUM_BITS_SHIFT) - 1)) != 0); + + if ((round_bit && sticky_bit) || (round_bit && (u & 1))) { + u = uint8_t(u + 1); + if( may_be_nan ) { + skip_sign = true; + } + } + + if (u > FP8_MAX_FLT) { + // satfinite + u = (sign | FP8_MAX_FLT); + } + + if( ! skip_sign ) { + u |= sign; + } + + return u; + } + + + /// Converts a fp8 value stored as a uint8_t to a float + CUTLASS_HOST_DEVICE + static float convert_fp8_to_float(uint8_t const& x) { + + uint32_t constexpr kF32_NaN = 0x7fffffff; + + uint8_t const &f8 = x; + int sign = (f8 >> (FP8_NUM_BITS - 1)) & 1; + int exp = (f8 >> FP8_NUM_MANTISSA_BITS) & FP8_EXPONENT_MASK; + int mantissa = f8 & FP8_MANTISSA_MASK; + unsigned f = (sign << (FP32_NUM_BITS-1)); + + if (IS_E4M3 && exp == 15 && mantissa == 0x7) { + f = kF32_NaN; + } + else if (exp > 0 && (IS_E4M3 || exp < (FP8_MAX_EXPONENT + FP8_EXPONENT_BIAS + 1))) { + // normal + exp += (FP32_EXPONENT_BIAS - FP8_EXPONENT_BIAS); + f = f | + (exp << FP32_NUM_MANTISSA_BITS) | + (mantissa << (FP32_NUM_MANTISSA_BITS-FP8_NUM_MANTISSA_BITS)); + } else if (exp == 0) { + if (mantissa) { + // subnormal + exp += (FP32_EXPONENT_BIAS - FP8_EXPONENT_BIAS) + 1; + while ((mantissa & (1 << FP8_NUM_MANTISSA_BITS)) == 0) { + mantissa <<= 1; + exp--; + } + mantissa &= FP8_MANTISSA_MASK; + f = f | + (exp << FP32_NUM_MANTISSA_BITS) | + (mantissa << (FP32_NUM_MANTISSA_BITS-FP8_NUM_MANTISSA_BITS)); + } else { + // sign-preserving zero + } + } else { + if(mantissa == 0){ + // Sign-preserving infinity + f = (f | 0x7f800000); + } else { + // Canonical NaN + f = kF32_NaN; + } + } + + #if defined(__CUDA_ARCH__) + return reinterpret_cast(f); + #else + float flt; + std::memcpy(&flt, &f, sizeof(flt)); + return flt; + #endif + } +}; + + +// Forward declaration of float_e5m2_t to define float_e4m3_t <=> float_e5m2_t +// conversions in class float_e4m3_t +struct float_e5m2_t; + + +/////////////////////////////////////////////////////////////// +/// +/// floating-point 8 type : E4M3 +/// +/////////////////////////////////////////////////////////////// +struct alignas(1) float_e4m3_t : float8_base { + + using Base = float8_base; + + static constexpr int MAX_EXPONENT = Base::FP8_MAX_EXPONENT; + + // + // Static conversion operators + // + + /// Constructs from an uint8_t + CUTLASS_HOST_DEVICE + static float_e4m3_t bitcast(uint8_t x) { + float_e4m3_t f; + f.storage = x; + return f; + } + + /// FP32 -> FP8 conversion - rounds to nearest even + CUTLASS_HOST_DEVICE + static float_e4m3_t from_float(float const& flt) { + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint16_t tmp; + float y = float(); + asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(tmp) : "f"(y), "f"(flt)); + + return *reinterpret_cast(&tmp); + #else + return bitcast(Base::convert_float_to_fp8(flt)); + #endif + } + + /// FP16 -> E5M2 conversion - rounds to nearest even + CUTLASS_HOST_DEVICE + static float_e4m3_t 
from_half(half const& flt) { + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint16_t tmp = 0; + uint32_t bits = reinterpret_cast(flt); + asm volatile("cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;" : "=h"(tmp) : "r"(bits)); + + return *reinterpret_cast(&tmp); + #else + return bitcast(Base::convert_float_to_fp8(float(flt))); + #endif + } + + // E4M3 -> half + CUTLASS_HOST_DEVICE + static half to_half(float_e4m3_t const& x) { + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint16_t bits = x.storage; + uint32_t packed; + asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;\n" : "=r"(packed) : "h"(bits)); + + return reinterpret_cast(packed).x; + #else + return half(Base::convert_fp8_to_float(x.storage)); + #endif + } + + // E4M3 -> Float + CUTLASS_HOST_DEVICE + static float to_float(float_e4m3_t const& x) { + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint16_t bits = x.storage; + uint32_t packed; + asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;\n" : "=r"(packed) : "h"(bits)); + + return float(reinterpret_cast(packed).x); + #else + return Base::convert_fp8_to_float(x.storage); + #endif + } + + // + // Methods + // + + /// Default constructor + CUTLASS_HOST_DEVICE + float_e4m3_t() : Base() { } + + /// Reinterpret cast from CUDA's FP8 type + CUTLASS_HOST_DEVICE + float_e4m3_t(float_e4m3_t const& x) { + #if defined(__CUDA_ARCH__) + storage = reinterpret_cast(x); + #else + uint8_t raw = x.storage; + std::memcpy(&storage, &raw, sizeof(storage)); + #endif + } + + /// Floating point conversion + CUTLASS_HOST_DEVICE + explicit float_e4m3_t(float x) { + storage = from_float(x).storage; + } + + CUTLASS_HOST_DEVICE + explicit float_e4m3_t(half x) { + storage = from_half(x).storage; + } + + /// Floating point conversion + CUTLASS_HOST_DEVICE + explicit float_e4m3_t(double x): float_e4m3_t(float(x)) { + } + + /// Integer conversion + CUTLASS_HOST_DEVICE + explicit float_e4m3_t(int x): float_e4m3_t(float(x)) { + } + + /// E5M2 conversion. Defined after float_e5m2_t is defined. 
+ CUTLASS_HOST_DEVICE + explicit float_e4m3_t(float_e5m2_t x); + + /// Assignment + CUTLASS_HOST_DEVICE + float_e4m3_t & operator=(float_e4m3_t const &x) { + #if defined(__CUDA_ARCH__) + storage = reinterpret_cast(x); + #else + uint8_t raw = x.storage; + std::memcpy(&storage, &raw, sizeof(storage)); + #endif + return *this; + } + + /// Converts to float + CUTLASS_HOST_DEVICE + operator float() const { + return to_float(*this); + } + + /// Converts to half + CUTLASS_HOST_DEVICE + operator half() const { + return to_half(*this); + } + + /// Converts to float + CUTLASS_HOST_DEVICE + explicit operator double() const { + return double(to_float(*this)); + } + + /// Converts to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + #if defined(__CUDA_ARCH__) + return __half2int_rn(to_half(*this)); + #else + return int(to_float(*this)); + #endif + } + + /// Casts to bool + CUTLASS_HOST_DEVICE + explicit operator bool() const { + #if defined(__CUDA_ARCH__) + return bool(__half2int_rn(to_half(*this))); + #else + return bool(int(to_float(*this))); + #endif + } + + /// Accesses raw internal state + CUTLASS_HOST_DEVICE + uint8_t& raw() { + return storage; + } + + /// Accesses raw internal state + CUTLASS_HOST_DEVICE + uint8_t raw() const { + return storage; + } + + /// Returns the sign bit + CUTLASS_HOST_DEVICE + bool signbit() const { + return ((storage & (1 << (Base::FP8_NUM_BITS - 1))) != 0); + } + + /// Returns the biased exponent + CUTLASS_HOST_DEVICE + int exponent_biased() const { + return int((storage >> FP8_NUM_MANTISSA_BITS) & Base::FP8_EXPONENT_MASK); + } + + /// Returns the unbiased exponent + CUTLASS_HOST_DEVICE + int exponent() const { + return exponent_biased() - 15; + } + + /// Returns the mantissa + CUTLASS_HOST_DEVICE + int mantissa() const { + return int(storage & Base::FP8_MANTISSA_MASK); + } +}; + +/////////////////////////////////////////////////////////////// +/// +/// floating-point 8 type : E5M2 +/// +/////////////////////////////////////////////////////////////// +struct alignas(1) float_e5m2_t : float8_base { + + using Base = float8_base; + + static constexpr int MAX_EXPONENT = Base::FP8_MAX_EXPONENT; + + // + // Static conversion operators + // + + /// Constructs from an uint8_t + CUTLASS_HOST_DEVICE + static float_e5m2_t bitcast(uint8_t x) { + float_e5m2_t f; + f.storage = x; + return f; + } + + /// FP32 -> FP8 conversion - rounds to nearest even + CUTLASS_HOST_DEVICE + static float_e5m2_t from_float(float const& flt) { + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint16_t tmp; + float y = float(); + asm volatile("cvt.rn.satfinite.e5m2x2.f32 %0, %1, %2;" : "=h"(tmp) : "f"(y), "f"(flt)); + + return *reinterpret_cast(&tmp); + #else + return bitcast(Base::convert_float_to_fp8(flt)); + #endif + } + + /// FP16 -> E5M2 conversion - rounds to nearest even + CUTLASS_HOST_DEVICE + static float_e5m2_t from_half(half const& flt) { + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint16_t tmp = 0; + uint32_t bits = reinterpret_cast(flt); + asm volatile("cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;" : "=h"(tmp) : "r"(bits)); + + return *reinterpret_cast(&tmp); + #else + return bitcast(Base::convert_float_to_fp8(float(flt))); + #endif + } + + // E5M2 -> half + CUTLASS_HOST_DEVICE + static half to_half(float_e5m2_t const& x) { + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint16_t bits = x.storage; + uint32_t packed; + asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;\n" : "=r"(packed) : "h"(bits)); + + return reinterpret_cast(packed).x; + #else + return half(Base::convert_fp8_to_float(x.storage)); + #endif + 
} + + // E5M2 -> Float + CUTLASS_HOST_DEVICE + static float to_float(float_e5m2_t const& x) { + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint16_t bits = x.storage; + uint32_t packed; + asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;\n" : "=r"(packed) : "h"(bits)); + + return float(reinterpret_cast(packed).x); + #else + return Base::convert_fp8_to_float(x.storage); + #endif + } + + // + // Methods + // + + /// Default constructor + CUTLASS_HOST_DEVICE + float_e5m2_t() : Base() { } + + /// Reinterpret cast from CUDA's FP8 type + CUTLASS_HOST_DEVICE + float_e5m2_t(float_e5m2_t const& x) { + #if defined(__CUDA_ARCH__) + storage = reinterpret_cast(x); + #else + uint8_t raw = x.storage; + std::memcpy(&storage, &raw, sizeof(storage)); + #endif + } + + /// Floating point conversion + CUTLASS_HOST_DEVICE + explicit float_e5m2_t(float x) { + storage = from_float(x).storage; + } + + CUTLASS_HOST_DEVICE + explicit float_e5m2_t(half x) { + storage = from_half(x).storage; + } + + /// Floating point conversion + CUTLASS_HOST_DEVICE + explicit float_e5m2_t(double x): float_e5m2_t(float(x)) { + } + + /// Integer conversion + CUTLASS_HOST_DEVICE + explicit float_e5m2_t(int x): float_e5m2_t(float(x)) { + } + + /// E4M3 conversion + CUTLASS_HOST_DEVICE + explicit float_e5m2_t(float_e4m3_t x); + + /// Assignment + CUTLASS_HOST_DEVICE + float_e5m2_t & operator=(float_e5m2_t const &x) { + #if defined(__CUDA_ARCH__) + storage = reinterpret_cast(x); + #else + uint8_t raw = x.storage; + std::memcpy(&storage, &raw, sizeof(storage)); + #endif + return *this; + } + + /// Converts to float + CUTLASS_HOST_DEVICE + operator float() const { + return to_float(*this); + } + + /// Converts to half + CUTLASS_HOST_DEVICE + operator half() const { + return to_half(*this); + } + + /// Converts to float + CUTLASS_HOST_DEVICE + explicit operator double() const { + return double(to_float(*this)); + } + + /// Converts to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + #if defined(__CUDA_ARCH__) + return __half2int_rn(to_half(*this)); + #else + return int(to_float(*this)); + #endif + } + + /// Casts to bool + CUTLASS_HOST_DEVICE + explicit operator bool() const { + #if defined(__CUDA_ARCH__) + return bool(__half2int_rn(to_half(*this))); + #else + return bool(int(to_float(*this))); + #endif + } + + /// Accesses raw internal state + CUTLASS_HOST_DEVICE + uint8_t& raw() { + return storage; + } + + /// Accesses raw internal state + CUTLASS_HOST_DEVICE + uint8_t raw() const { + return storage; + } + + /// Returns the sign bit + CUTLASS_HOST_DEVICE + bool signbit() const { + return ((storage & (1 << (Base::FP8_NUM_BITS - 1))) != 0); + } + + /// Returns the biased exponent + CUTLASS_HOST_DEVICE + int exponent_biased() const { + return int((storage >> FP8_NUM_MANTISSA_BITS) & Base::FP8_EXPONENT_MASK); + } + + /// Returns the unbiased exponent + CUTLASS_HOST_DEVICE + int exponent() const { + return exponent_biased() - 15; + } + + /// Returns the mantissa + CUTLASS_HOST_DEVICE + int mantissa() const { + return int(storage & Base::FP8_MANTISSA_MASK); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Arithmetic operators +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_HOST_DEVICE +bool operator==(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { + return float(lhs) == float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator!=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { + return float(lhs) != 
float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator<(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { + return float(lhs) < float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator<=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { + return float(lhs) <= float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator>(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { + return float(lhs) > float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator>=(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { + return float(lhs) >= float(rhs); +} + +CUTLASS_HOST_DEVICE +float_e4m3_t operator+(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { + return float_e4m3_t(float(lhs) + float(rhs)); +} + +CUTLASS_HOST_DEVICE +float_e4m3_t operator-(float_e4m3_t const& lhs) { + return float_e4m3_t(-float(lhs)); +} + +CUTLASS_HOST_DEVICE +float_e4m3_t operator-(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { + return float_e4m3_t(float(lhs) - float(rhs)); +} + +CUTLASS_HOST_DEVICE +float_e4m3_t operator*(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { + return float_e4m3_t(float(lhs) * float(rhs)); +} + +CUTLASS_HOST_DEVICE +float_e4m3_t operator/(float_e4m3_t const& lhs, float_e4m3_t const& rhs) { + return float_e4m3_t(float(lhs) / float(rhs)); +} + +CUTLASS_HOST_DEVICE +float_e4m3_t& operator+=(float_e4m3_t & lhs, float_e4m3_t const& rhs) { + lhs = float_e4m3_t(float(lhs) + float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +float_e4m3_t& operator-=(float_e4m3_t & lhs, float_e4m3_t const& rhs) { + lhs = float_e4m3_t(float(lhs) - float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +float_e4m3_t& operator*=(float_e4m3_t & lhs, float_e4m3_t const& rhs) { + lhs = float_e4m3_t(float(lhs) * float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +float_e4m3_t& operator/=(float_e4m3_t & lhs, float_e4m3_t const& rhs) { + lhs = float_e4m3_t(float(lhs) / float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +float_e4m3_t& operator++(float_e4m3_t & lhs) { + float tmp(lhs); + ++tmp; + lhs = float_e4m3_t(tmp); + return lhs; +} + +CUTLASS_HOST_DEVICE +float_e4m3_t& operator--(float_e4m3_t & lhs) { + float tmp(lhs); + --tmp; + lhs = float_e4m3_t(tmp); + return lhs; +} + +CUTLASS_HOST_DEVICE +float_e4m3_t operator++(float_e4m3_t & lhs, int) { + float_e4m3_t ret(lhs); + float tmp(lhs); + tmp++; + lhs = float_e4m3_t(tmp); + return ret; +} + +CUTLASS_HOST_DEVICE +float_e4m3_t operator--(float_e4m3_t & lhs, int) { + float_e4m3_t ret(lhs); + float tmp(lhs); + tmp--; + lhs = float_e4m3_t(tmp); + return ret; +} + +CUTLASS_HOST_DEVICE +bool operator==(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { + return float(lhs) == float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator!=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { + return float(lhs) != float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator<(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { + return float(lhs) < float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator<=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { + return float(lhs) <= float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator>(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { + return float(lhs) > float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator>=(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { + return float(lhs) >= float(rhs); +} + +CUTLASS_HOST_DEVICE +float_e5m2_t operator+(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { + return float_e5m2_t(float(lhs) + float(rhs)); +} + +CUTLASS_HOST_DEVICE +float_e5m2_t operator-(float_e5m2_t const& lhs) { + return float_e5m2_t(-float(lhs)); +} + 
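
A short host-side usage sketch for the FP8 types and operators defined above follows. It is not part of the patch; the exact printed values depend on E4M3/E5M2 rounding, and the numbers chosen are only examples. It relies solely on names introduced in this file: explicit construction from float, the arithmetic operators, the explicit E4M3 <=> E5M2 converting constructors, and implicit conversion back to float.

// Illustrative sketch; assumes this patch's cutlass/float8.h is on the include path.
#include <cstdio>
#include "cutlass/float8.h"

int main() {
  cutlass::float_e4m3_t a(1.5f);    // explicit float -> E4M3 (round-to-nearest-even)
  cutlass::float_e4m3_t b(0.25f);

  cutlass::float_e4m3_t c = a + b;  // computed in float, rounded back to E4M3

  cutlass::float_e5m2_t d(c);       // explicit E4M3 -> E5M2 conversion

  printf("a=%f b=%f a+b=%f (as e5m2: %f)\n",
         float(a), float(b), float(c), float(d));
  return 0;
}
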
+CUTLASS_HOST_DEVICE +float_e5m2_t operator-(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { + return float_e5m2_t(float(lhs) - float(rhs)); +} + +CUTLASS_HOST_DEVICE +float_e5m2_t operator*(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { + return float_e5m2_t(float(lhs) * float(rhs)); +} + +CUTLASS_HOST_DEVICE +float_e5m2_t operator/(float_e5m2_t const& lhs, float_e5m2_t const& rhs) { + return float_e5m2_t(float(lhs) / float(rhs)); +} + +CUTLASS_HOST_DEVICE +float_e5m2_t& operator+=(float_e5m2_t & lhs, float_e5m2_t const& rhs) { + lhs = float_e5m2_t(float(lhs) + float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +float_e5m2_t& operator-=(float_e5m2_t & lhs, float_e5m2_t const& rhs) { + lhs = float_e5m2_t(float(lhs) - float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +float_e5m2_t& operator*=(float_e5m2_t & lhs, float_e5m2_t const& rhs) { + lhs = float_e5m2_t(float(lhs) * float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +float_e5m2_t& operator/=(float_e5m2_t & lhs, float_e5m2_t const& rhs) { + lhs = float_e5m2_t(float(lhs) / float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +float_e5m2_t& operator++(float_e5m2_t & lhs) { + float tmp(lhs); + ++tmp; + lhs = float_e5m2_t(tmp); + return lhs; +} + +CUTLASS_HOST_DEVICE +float_e5m2_t& operator--(float_e5m2_t & lhs) { + float tmp(lhs); + --tmp; + lhs = float_e5m2_t(tmp); + return lhs; +} + +CUTLASS_HOST_DEVICE +float_e5m2_t operator++(float_e5m2_t & lhs, int) { + float_e5m2_t ret(lhs); + float tmp(lhs); + tmp++; + lhs = float_e5m2_t(tmp); + return ret; +} + +CUTLASS_HOST_DEVICE +float_e5m2_t operator--(float_e5m2_t & lhs, int) { + float_e5m2_t ret(lhs); + float tmp(lhs); + tmp--; + lhs = float_e5m2_t(tmp); + return ret; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// float_e4m3_t <=> float_e5m2_t conversions +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// float_e4m3_t <= float_e5m2_t +CUTLASS_HOST_DEVICE +float_e4m3_t::float_e4m3_t(float_e5m2_t x) { + storage = from_float(float_e5m2_t::to_float(x)).storage; +} + +/// float_e5m2_t <= float_e4m3_t +CUTLASS_HOST_DEVICE +float_e5m2_t::float_e5m2_t(float_e4m3_t x) { + storage = from_float(float_e4m3_t::to_float(x)).storage; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Standard Library operations and definitions +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if !defined(__CUDACC_RTC__) +namespace std { + +/// Numeric limits common to all float8 types +template +struct float8_base_numeric_limits { +private: + using F8Type = T; +public: + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; + static std::float_denorm_style const has_denorm = std::denorm_present; + static bool const has_denorm_loss = true; + static std::float_round_style const round_style = std::round_to_nearest; + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = F8Type::FP8_NUM_MANTISSA_BITS; + + /// Least positive value + static F8Type min() { return 
F8Type::bitcast(0x01); } + + /// Maximum finite value + static F8Type max() { return F8Type::bitcast(F8Type::FP8_MAX_FLT); } + + /// Returns maximum rounding error + static F8Type round_error() { return F8Type(0.5f); } + + /// Returns positive infinity value + static F8Type infinity() { return F8Type::bitcast(F8Type::FP8_INFINITY_MASK); } + + /// Returns quiet NaN value + static F8Type quiet_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); } + + /// Returns signaling NaN value + static F8Type signaling_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); } + + /// Returns smallest positive subnormal value + static F8Type denorm_min() { return F8Type::bitcast(0x01); } +}; + +/// Numeric limits for float_e4m3_t +template <> +struct numeric_limits : + public float8_base_numeric_limits { + static bool const has_infinity = false; + + /// Minimum finite value + static cutlass::float_e4m3_t lowest() { return cutlass::float_e4m3_t::bitcast(0xfe); } + + /// Returns smallest finite value + static cutlass::float_e4m3_t epsilon() { return cutlass::float_e4m3_t::bitcast(0x20); } +}; + +/// Numeric limits for float_e5m2_t +template <> +struct numeric_limits : + public float8_base_numeric_limits { + static bool const has_infinity = true; + + /// Minimum finite value + static cutlass::float_e5m2_t lowest() { return cutlass::float_e5m2_t::bitcast(0xfb); } + + /// Returns smallest finite value + static cutlass::float_e5m2_t epsilon() { return cutlass::float_e5m2_t::bitcast(0x34); } +}; + +} // namespace std +#endif + +namespace platform { + +/// Numeric limits common to all float8 types +template +struct float8_base_numeric_limits { +private: + using F8Type = T; +public: + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; +#if !defined(__CUDACC_RTC__) + static std::float_denorm_style const has_denorm = std::denorm_present; +#endif + static bool const has_denorm_loss = true; +#if !defined(__CUDACC_RTC__) + static std::float_round_style const round_style = std::round_to_nearest; +#endif + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = F8Type::FP8_NUM_MANTISSA_BITS; + + /// Least positive value + static F8Type min() { return F8Type::bitcast(0x01); } + + /// Maximum finite value + static F8Type max() { return F8Type::bitcast(F8Type::FP8_MAX_FLT); } + + /// Returns maximum rounding error + static F8Type round_error() { return F8Type(0.5f); } + + /// Returns positive infinity value + static F8Type infinity() { return F8Type::bitcast(F8Type::FP8_INFINITY_MASK); } + + /// Returns quiet NaN value + static F8Type quiet_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); } + + /// Returns signaling NaN value + static F8Type signaling_NaN() { return F8Type::bitcast(F8Type::FP8_NAN); } + + /// Returns smallest positive subnormal value + static F8Type denorm_min() { return F8Type::bitcast(0x01); } +}; + +/// std::numeric_limits +template +struct numeric_limits; + +/// Numeric limits for float_e4m3_t +template <> +struct numeric_limits : + public float8_base_numeric_limits { + static bool const has_infinity = false; + + /// Minimum finite value + static cutlass::float_e4m3_t lowest() { return cutlass::float_e4m3_t::bitcast(0xfe); } + + /// Returns smallest finite value + static cutlass::float_e4m3_t epsilon() { return 
cutlass::float_e4m3_t::bitcast(0x20); } +}; + +/// Numeric limits for float_e5m2_t +template <> +struct numeric_limits : + public float8_base_numeric_limits { + static bool const has_infinity = true; + + /// Minimum finite value + static cutlass::float_e5m2_t lowest() { return cutlass::float_e5m2_t::bitcast(0xfb); } + + /// Returns smallest finite value + static cutlass::float_e5m2_t epsilon() { return cutlass::float_e5m2_t::bitcast(0x34); } +}; + +} // namespace platform + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// User-defined literals +// + +CUTLASS_HOST_DEVICE +cutlass::float_e4m3_t operator "" _fe4m3(long double x) { + return cutlass::float_e4m3_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::float_e4m3_t operator "" _fe4m3(unsigned long long int x) { + return cutlass::float_e4m3_t(int(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::float_e5m2_t operator "" _fe5m2(long double x) { + return cutlass::float_e5m2_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::float_e5m2_t operator "" _fe5m2(unsigned long long int x) { + return cutlass::float_e5m2_t(int(x)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/floating_point_nvrtc.h b/include/cutlass/floating_point_nvrtc.h new file mode 100644 index 00000000..9109e6a8 --- /dev/null +++ b/include/cutlass/floating_point_nvrtc.h @@ -0,0 +1,65 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
+ \file + \brief Defines categories for floating point numbers for use in NVRTC-compiled code +*/ + +#pragma once + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// All floating-point numbers can be put in one of these categories. +enum { + FP_NAN = +# define FP_NAN 0 + FP_NAN, + FP_INFINITE = +# define FP_INFINITE 1 + FP_INFINITE, + FP_ZERO = +# define FP_ZERO 2 + FP_ZERO, + FP_SUBNORMAL = +# define FP_SUBNORMAL 3 + FP_SUBNORMAL, + FP_NORMAL = +# define FP_NORMAL 4 + FP_NORMAL +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 582a8e30..d352fe0c 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -29,7 +29,7 @@ * **************************************************************************************************/ /*! \file - \brief Define basic numeric operators with specializations for Array. SIMD-ize where possible. + \brief Define basic numeric operators This is inspired by the Standard Library's header. */ @@ -38,11 +38,12 @@ #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" -#include "cutlass/complex.h" -#include "cutlass/quaternion.h" -#include "cutlass/array.h" #include "cutlass/half.h" +#if defined(CUTLASS_ARCH_WMMA_ENABLED) +#include +#endif // defined(CUTLASS_ARCH_WMMA_ENABLED) + namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -82,15 +83,6 @@ struct multiplies { } }; -template -struct multiplies> { - CUTLASS_HOST_DEVICE - Quaternion operator()(Quaternion lhs, Quaternion const &rhs) const { - lhs = lhs * rhs; - return lhs; - } -}; - /// Squares with optional conversion template struct square { @@ -115,37 +107,6 @@ struct magnitude_squared { } }; -/// Squares with optional conversion -template -struct magnitude_squared, Output> { - CUTLASS_HOST_DEVICE - Output operator()(complex lhs) const { - multiplies mul_op; - - Output y_r = Output(lhs.real()); - Output y_i = Output(lhs.imag()); - - return mul_op(y_r, y_r) + mul_op(y_i, y_i); - } -}; - -/// Squares with optional conversion -template -struct magnitude_squared, Output> { - CUTLASS_HOST_DEVICE - Output operator()(Quaternion lhs) const { - multiplies mul_op; - - Output y_w = Output(lhs.w()); - Output y_x = Output(lhs.x()); - Output y_y = Output(lhs.y()); - Output y_z = Output(lhs.z()); - - return mul_op(y_w, y_w) + mul_op(y_x, y_x) + mul_op(y_y, y_y) + \ - mul_op(y_z, y_z); - } -}; - /// Computes the square of a difference with optional conversion template struct square_difference { @@ -170,20 +131,7 @@ struct magnitude_squared_difference { } }; -/// Computes the square of a difference with optional conversion -template -struct magnitude_squared_difference, Output> { - CUTLASS_HOST_DEVICE - Output operator()(complex lhs, complex rhs) const { - multiplies mul_op; - - Output y_r = Output(lhs.real()) - Output(rhs.real()); - Output y_i = Output(lhs.imag()) - Output(rhs.imag()); - - return mul_op(y_r, y_r) + mul_op(y_i, y_i); - } -}; - +/// Divides template struct divides { CUTLASS_HOST_DEVICE @@ -193,7 +141,7 @@ struct divides { } }; - +/// Negate template struct negate { CUTLASS_HOST_DEVICE @@ -378,1992 +326,128 @@ struct bit_xor { } }; 
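Stepping back to the float8 additions for a moment: the user-defined literals (_fe4m3, _fe5m2) and the paired std::/platform numeric_limits specializations above are the host-facing surface of the new FP8 types. The literal suffixes build constants directly, while numeric_limits exposes the encodings' extrema; the std:: version is compiled out under NVRTC, which is what the platform duplicate covers. A short sketch under the assumption that all of this is reachable through cutlass/float8.h:

    // Hedged sketch; the include path is assumed, the queried values follow from the definitions above.
    #include <iostream>
    #include <limits>
    #include "cutlass/float8.h"

    int main() {
      // Literal suffixes defined above.
      cutlass::float_e4m3_t x = 1.5_fe4m3;
      cutlass::float_e5m2_t y = 3_fe5m2;

      std::cout << "e4m3 max: " << float(std::numeric_limits<cutlass::float_e4m3_t>::max()) << "\n"
                << "e5m2 max: " << float(std::numeric_limits<cutlass::float_e5m2_t>::max()) << "\n"
                << "x * y:    " << float(x) * float(y) << std::endl;
      return 0;
    }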
-///////////////////////////////////////////////////////////////////////////////////////////////// -// Partial specializations for Arrays -template -struct bit_and> { - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, Array const &b) const { - using ArrayType = Array; - using Storage = typename ArrayType::Storage; - ArrayType result; - Storage *result_data = result.raw_data(); - Storage const *a_data = a.raw_data(); - Storage const *b_data = b.raw_data(); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ArrayType::kStorageElements; ++i) { - result_data[i] = (a_data[i] & b_data[i]); - } - - return result; - } -}; - -// Partial specializations for Arrays -template -struct bit_or> { - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, Array const &b) const { - using ArrayType = Array; - using Storage = typename ArrayType::Storage; - ArrayType result; - - Storage *result_data = result.raw_data(); - Storage const *a_data = a.raw_data(); - Storage const *b_data = b.raw_data(); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ArrayType::kStorageElements; ++i) { - result_data[i] = (a_data[i] | b_data[i]); - } - - return result; - } -}; - -// Partial specializations for Arrays -template -struct bit_not> { - CUTLASS_HOST_DEVICE - Array operator()(Array const &a) const { - using ArrayType = Array; - using Storage = typename ArrayType::Storage; - ArrayType result; - - Storage *result_data = result.raw_data(); - Storage const *a_data = a.raw_data(); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ArrayType::kStorageElements; ++i) { - result_data[i] = (~a_data[i]); - } - - return result; - } -}; - -// Partial specializations for Arrays -template -struct bit_xor> { - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, Array const &b) const { - using ArrayType = Array; - using Storage = typename ArrayType::Storage; - ArrayType result; - - Storage *result_data = result.raw_data(); - Storage const *a_data = a.raw_data(); - Storage const *b_data = b.raw_data(); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ArrayType::kStorageElements; ++i) { - result_data[i] = (a_data[i] ^ b_data[i]); - } - - return result; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////////////////////////// +/// Reduces value into the data pointed to by ptr template -struct conjugate> { - CUTLASS_HOST_DEVICE - complex operator()(complex const &a) const { - return conj(a); +struct red +{ + CUTLASS_DEVICE + void operator()(T *ptr, const T &data) + { + atomicAdd(ptr, data); } }; -template -struct conjugate > { - CUTLASS_HOST_DEVICE - Array operator()(Array const &a) const { - conjugate conj_op; +/// Reduces value into the data pointed to by ptr (double specialization) +template<> +struct red +{ + CUTLASS_DEVICE + void operator()(double *ptr, const double &data) + { +#if !defined(__CUDA_ARCH__) +#elif (__CUDA_ARCH__ >= 600) - Array ca; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - ca[i] = conj_op(a[i]); - } - return ca; + atomicAdd(ptr, data); + +#else + + // Use CAS loop + unsigned long long int* ptr_int = reinterpret_cast(ptr); + unsigned long long int old_int = *ptr_int; + unsigned long long int assumed_int; + + do { + double update = data + __longlong_as_double(old_int); + assumed_int = old_int; + old_int = atomicCAS(ptr_int, assumed_int, __double_as_longlong(update)); + } while (assumed_int != old_int); + +#endif // (__CUDA_ARCH__ >= 600) } }; 
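The red functor added above gives device code a uniform atomic-accumulate interface: the primary template defers to atomicAdd, the double specialization falls back to an atomicCAS loop below sm_60, and the half2 specialization that follows issues a packed f16x2 red instruction on sm_60 and newer. A minimal sketch of how it might be used from a kernel; the kernel and buffer names are illustrative only, not part of the patch.

    // Hedged sketch: sums n doubles into *out using cutlass::red<double> from functional.h.
    #include <cstdio>
    #include "cutlass/functional.h"

    __global__ void accumulate(double const *in, double *out, int n) {
      cutlass::red<double> reduce;                    // picks atomicAdd or the CAS fallback
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        reduce(out, in[i]);                           // atomically adds in[i] into *out
      }
    }

    int main() {
      int const n = 1024;
      double *in = nullptr, *out = nullptr;
      cudaMallocManaged(&in, n * sizeof(double));
      cudaMallocManaged(&out, sizeof(double));
      for (int i = 0; i < n; ++i) { in[i] = 1.0; }
      *out = 0.0;

      accumulate<<<(n + 255) / 256, 256>>>(in, out, n);
      cudaDeviceSynchronize();

      printf("sum = %f\n", *out);                     // expected: 1024.0
      cudaFree(in);
      cudaFree(out);
      return 0;
    }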
-///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specialization for complex to target four scalar fused multiply-adds. -// -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Fused multiply-add -template -struct multiply_add, complex, complex> { - CUTLASS_HOST_DEVICE - complex operator()( - complex const &a, - complex const &b, - complex const &c) const { +/// Reduces value into the data pointed to by ptr (half2 specialization) +template<> +struct red +{ + CUTLASS_DEVICE + void operator()(half2 *ptr, const half2 &data) + { +#if !defined(__CUDA_ARCH__) +#elif (__CUDA_ARCH__ >= 600) - T real = c.real(); - T imag = c.imag(); + // Vector-2 atomic reduction requires .target sm_60 or higher + uint32_t word = reinterpret_cast(data); + asm volatile ("red.gpu.global.add.noftz.f16x2 [%0], %1;\n" : : "l"(ptr), "r"(word)); - real += a.real() * b.real(); - real += -a.imag() * b.imag(); - imag += a.real() * b.imag(); - imag += a.imag () * b.real(); +#else - return complex{ - real, - imag - }; + // Use CAS loop + uint32_t *ptr_int = reinterpret_cast(ptr); + uint32_t old_int = *ptr_int; + uint32_t assumed_int; + + do + { + half2 old = reinterpret_cast(old_int); + + half hi = __hadd(__high2half(old), __high2half(data)); + half lo = __hadd(__low2half(old), __low2half(data)); + half2 update = __halves2half2(hi, lo); + uint32_t update_int = reinterpret_cast(update); + + assumed_int = old_int; + old_int = atomicCAS(ptr_int, assumed_int, update_int); + + } while (assumed_int != old_int); + +#endif // (__CUDA_ARCH__ >= 600) } }; -/// Fused multiply-add -template -struct multiply_add, T, complex> { - CUTLASS_HOST_DEVICE - complex operator()( - complex const &a, - T const &b, - complex const &c) const { - - T real = c.real(); - T imag = c.imag(); - - real += a.real() * b; - imag += a.imag () * b; - - return complex{ - real, - imag - }; - } -}; - -/// Fused multiply-add -template -struct multiply_add, complex> { - CUTLASS_HOST_DEVICE - complex operator()( - T const &a, - complex const &b, - complex const &c) const { - - T real = c.real(); - T imag = c.imag(); - - real += a * b.real(); - imag += a * b.imag(); - - return complex{ - real, - imag - }; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for Array -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct absolute_value_op< Array > { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs) const { - - Array result; - absolute_value_op scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i]); - } - - return result; - } -}; - -template -struct plus> { - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - plus scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - plus scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( T const &scalar, Array const &rhs) const { - - Array result; - plus scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = 
scalar_op(scalar, rhs[i]); - } - - return result; - } -}; -template -struct minus> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - minus scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - minus scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( T const &scalar, Array const &rhs) const { - - Array result; - minus scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, rhs[i]); - } - - return result; - } -}; - -template -struct multiplies> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - multiplies scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - multiplies scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( T const &scalar, Array const &rhs) const { - - Array result; - multiplies scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, rhs[i]); - } - - return result; - } -}; - -template -struct divides> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - divides scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - divides scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( T const &scalar, Array const &rhs) const { - - Array result; - divides scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, rhs[i]); - } - - return result; - } -}; - -template -struct maximum> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( T const &scalar, Array const &rhs) const { - - Array result; - maximum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, rhs[i]); - } - - return result; - } -}; - -template -struct minimum> { - - CUTLASS_HOST_DEVICE - static T scalar_op(T const &lhs, T const &rhs) { - return (rhs < lhs ? 
rhs : lhs); - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, Array const &rhs) const { - - Array result; - minimum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], rhs[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs, T const &scalar) const { - - Array result; - minimum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i], scalar); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( T const &scalar, Array const &rhs) const { - - Array result; - minimum scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, rhs[i]); - } - - return result; - } -}; - -template -struct negate> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &lhs) const { - - Array result; - negate scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(lhs[i]); - } - - return result; - } -}; - -/// Fused multiply-add -template -struct multiply_add, Array, Array> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, Array const &b, Array const &c) const { - - Array result; - multiply_add scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(a[i], b[i], c[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, T const &scalar, Array const &c) const { - - Array result; - multiply_add scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(a[i], scalar, c[i]); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(T const &scalar, Array const &b, Array const &c) const { - - Array result; - multiply_add scalar_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = scalar_op(scalar, b[i], c[i]); - } - - return result; - } -}; - -/// Fused multiply-add-relu0 -template -struct multiply_add_relu0, Array, Array> { - - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, Array const &b, Array const &c) const { - - Array result; - multiply_add scalar_op; - maximum mx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mx(scalar_op(a[i], b[i], c[i]), T(0)); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const &a, T const &scalar, Array const &c) const { - - Array result; - multiply_add scalar_op; - maximum mx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mx(scalar_op(a[i], scalar, c[i]), T(0)); - } - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(T const &scalar, Array const &b, Array const &c) const { - - Array result; - multiply_add scalar_op; - maximum mx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mx(scalar_op(scalar, b[i], c[i]), T(0)); - } - - return result; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for Array targeting SIMD instructions in device code. 
-// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct plus> { - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hadd2(lhs_ptr[i], rhs_ptr[i]); - } - - if (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hadd(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs[i] + rhs[i]; - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(half_t const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hadd2(lhs_pair, rhs_ptr[i]); - } - - if (N % 2) { - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hadd(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs + rhs[i]; - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, half_t const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hadd2(lhs_ptr[i], rhs_pair); - } - - if (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half d_residual = __hadd(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs[i] + rhs; - } - #endif - - return result; - } -}; - -template -struct minus> { - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hsub2(lhs_ptr[i], rhs_ptr[i]); - } - - if (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hsub(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - 
CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs[i] - rhs[i]; - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(half_t const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hsub2(lhs_pair, rhs_ptr[i]); - } - - if (N % 2) { - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hsub(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs - rhs[i]; - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, half_t const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hsub2(lhs_ptr[i], rhs_pair); - } - - if (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half d_residual = __hsub(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs[i] - rhs; - } - #endif - - return result; - } -}; - -template -struct multiplies> { - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmul2(lhs_ptr[i], rhs_ptr[i]); - } - - if (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - __half d_residual = __hmul(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs[i] * rhs[i]; - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(half_t const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmul2(lhs_pair, rhs_ptr[i]); - } - - if (N % 2) { - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - - __half d_residual = __hmul( - reinterpret_cast<__half const &>(lhs), - b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - 
result[i] = lhs * rhs[i]; - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, half_t const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmul2(lhs_ptr[i], rhs_pair); - } - - if (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - - __half d_residual = __hmul( - a_residual_ptr[N - 1], - reinterpret_cast<__half const &>(rhs)); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs[i] * rhs; - } - #endif - - return result; - } -}; - -template -struct divides> { - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __h2div(lhs_ptr[i], rhs_ptr[i]); - } - - if (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - - __half d_residual = __hdiv( - a_residual_ptr[N - 1], - b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs[i] / rhs[i]; - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(half_t const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __h2div(lhs_pair, rhs_ptr[i]); - } - - if (N % 2) { - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - - __half d_residual = __hdiv( - reinterpret_cast<__half const &>(lhs), - b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs / rhs[i]; - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, half_t const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __h2div(lhs_ptr[i], rhs_pair); - } - - if (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - - __half d_residual = __hdiv( - a_residual_ptr[N - 1], - reinterpret_cast<__half const &>(rhs)); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = lhs[i] / rhs; - } - #endif - - 
return result; - } -}; - -template -struct negate> { - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *source_ptr = reinterpret_cast<__half2 const *>(&lhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hneg2(source_ptr[i]); - } - - if (N % 2) { - half_t x = lhs[N - 1]; - __half lhs_val = -reinterpret_cast<__half const &>(x); - result[N - 1] = reinterpret_cast(lhs_val); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = -lhs[i]; - } - #endif - - return result; - } -}; - -/// Fused multiply-add -template -struct multiply_add, Array, Array> { - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - Array const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); - __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); - __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_ptr[i]); - } - - if (N % 2) { - - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); - __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); - - __half d_residual = __hfma( - a_residual_ptr[N - 1], - b_residual_ptr[N - 1], - c_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a[i], b[i], c[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - half_t const &a, - Array const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a)); - __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); - __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hfma2(a_pair, b_ptr[i], c_ptr[i]); - } - - if (N % 2) { - - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); - __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); - __half d_residual = __hfma( - reinterpret_cast<__half const &>(a), - b_residual_ptr[N - 1], - c_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a, b[i], c[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - half_t const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); - __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); - __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_ptr[i]); - } - - if (N % 2) { - - __half const 
*a_residual_ptr = reinterpret_cast<__half const *>(&a); - __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); - - __half d_residual = __hfma( - a_residual_ptr[N - 1], - reinterpret_cast<__half const &>(b), - c_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a[i], b, c[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - Array const &b, - half_t const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); - __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); - __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_pair); - } - - if (N % 2) { - - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); - - __half d_residual = __hfma( - a_residual_ptr[N - 1], - b_residual_ptr[N - 1], - reinterpret_cast<__half const &>(c)); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a[i], b[i], c); - } - #endif - - return result; - } -}; - -/// Fused multiply-add-relu0 -template -struct multiply_add_relu0, Array, Array> { - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - Array const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); - __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); - __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_ptr[i]); - } - - if (N % 2) { - - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); - __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); - - __half d_residual = __hfma_relu( - a_residual_ptr[N - 1], - b_residual_ptr[N - 1], - c_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - multiply_add op; - maximum mx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mx(op(a[i], b[i], c[i]), (half_t)0); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - half_t const &a, - Array const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a)); - __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); - __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hfma2_relu(a_pair, b_ptr[i], c_ptr[i]); - } - - if (N % 2) { - - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); - __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); - __half d_residual = __hfma_relu( - reinterpret_cast<__half 
const &>(a), - b_residual_ptr[N - 1], - c_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - multiply_add op; - maximum mx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mx(op(a, b[i], c[i]), half_t(0)); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - half_t const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); - __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); - __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hfma2_relu(a_ptr[i], b_pair, c_ptr[i]); - } - - if (N % 2) { - - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); - __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); - - __half d_residual = __hfma_relu( - a_residual_ptr[N - 1], - reinterpret_cast<__half const &>(b), - c_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - multiply_add op; - maximum mx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mx(op(a[i], b, c[i]), half_t(0)); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - Array const &b, - half_t const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); - __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); - __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_pair); - } - - if (N % 2) { - - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); - - __half d_residual = __hfma_relu( - a_residual_ptr[N - 1], - b_residual_ptr[N - 1], - reinterpret_cast<__half const &>(c)); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - multiply_add op; - maximum mx; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = mx(op(a[i], b[i], c), half_t(0)); - } - #endif - - return result; - } -}; - -template -struct minimum> { - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmin2(lhs_ptr[i], rhs_ptr[i]); - } - - if (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - - __half d_residual = __hmin( - a_residual_ptr[N - 1], - b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = (rhs[i] < lhs[i] ? 
rhs[i] : lhs[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(half_t const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmin2(lhs_pair, rhs_ptr[i]); - } - - if (N % 2) { - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - - __half d_residual = __hmin( - reinterpret_cast<__half const &>(lhs), - b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = (rhs[i] < lhs ? rhs[i] : lhs); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, half_t const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmin2(lhs_ptr[i], rhs_pair); - } - - if (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - - __half d_residual = __hmin( - a_residual_ptr[N - 1], - reinterpret_cast<__half const &>(rhs)); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = (rhs < lhs[i] ? rhs : lhs[i]); - } - #endif - - return result; - } -}; - -template -struct maximum> { - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]); - } - - if (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - - __half d_residual = __hmax( - a_residual_ptr[N - 1], - b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = (lhs[i] < rhs[i] ? 
rhs[i] : lhs[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(half_t const & lhs, Array const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); - __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmax2(lhs_pair, rhs_ptr[i]); - } - - if (N % 2) { - __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); - - __half d_residual = __hmax( - reinterpret_cast<__half const &>(lhs), - b_residual_ptr[N - 1]); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = (lhs < rhs[i] ? rhs[i] : lhs); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()(Array const & lhs, half_t const &rhs) const { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); - __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); - __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - result_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair); - } - - if (N % 2) { - __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); - - __half d_residual = __hmax( - a_residual_ptr[N - 1], - reinterpret_cast<__half const &>(rhs)); - - result[N - 1] = reinterpret_cast(d_residual); - } - - #else - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = (lhs[i] < rhs ? rhs : lhs[i]); - } - #endif - - return result; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Fused multiply-add -template -struct multiply_add, Array, Array> { - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - Array const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - unsigned *result_ptr = reinterpret_cast(&result); - unsigned const *a_ptr = reinterpret_cast(&a); - unsigned const *b_ptr = reinterpret_cast(&b); - unsigned const *c_ptr = reinterpret_cast(&c); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" - : "=r"(result_ptr[i]) - : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_ptr[i]) - ); - } - - if (N % 2) { - - uint16_t *result_ptr = reinterpret_cast(&result); - uint16_t const *a_residual_ptr = reinterpret_cast(&a); - uint16_t const *b_residual_ptr = reinterpret_cast(&b); - uint16_t const *c_residual_ptr = reinterpret_cast(&c); - - asm ("fma.rn.bf16 %0, %1, %2, %3;\n" - : "=h"(result_ptr[N - 1]) - : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) - ); - } - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a[i], b[i], c[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - bfloat16_t const &a, - Array const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - unsigned *result_ptr = reinterpret_cast(&result); - - unsigned const *b_ptr = reinterpret_cast(&b); - unsigned const *c_ptr = reinterpret_cast(&c); - - unsigned a_packed = static_cast(a.raw()); - a_packed = (a_packed | (a_packed << 16)); - - 
CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" - : "=r"(result_ptr[i]) - : "r"(a_packed), "r"(b_ptr[i]), "r"(c_ptr[i]) - ); - } - - if (N % 2) { - - uint16_t *result_ptr = reinterpret_cast(&result); - uint16_t const *a_residual_ptr = reinterpret_cast(&a); - uint16_t const *b_residual_ptr = reinterpret_cast(&b); - uint16_t const *c_residual_ptr = reinterpret_cast(&c); - - asm ("fma.rn.bf16 %0, %1, %2, %3;\n" - : "=h"(result_ptr[N - 1]) - : "h"(a_residual_ptr[0]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) - ); - } - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a, b[i], c[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - bfloat16_t const &b, - Array const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - unsigned *result_ptr = reinterpret_cast(&result); - - unsigned const *a_ptr = reinterpret_cast(&a); - unsigned const *c_ptr = reinterpret_cast(&c); - - unsigned b_packed = static_cast(b.raw()); - b_packed = (b_packed | (b_packed << 16)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" - : "=r"(result_ptr[i]) - : "r"(a_ptr[i]), "r"(b_packed), "r"(c_ptr[i]) - ); - } - - if (N % 2) { - - uint16_t *result_ptr = reinterpret_cast(&result); - uint16_t const *a_residual_ptr = reinterpret_cast(&a); - uint16_t const *b_residual_ptr = reinterpret_cast(&b); - uint16_t const *c_residual_ptr = reinterpret_cast(&c); - - asm ("fma.rn.bf16 %0, %1, %2, %3;\n" - : "=h"(result_ptr[N - 1]) - : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[0]), "h"(c_residual_ptr[N - 1]) - ); - } - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a[i], b, c[i]); - } - #endif - - return result; - } - - CUTLASS_HOST_DEVICE - Array operator()( - Array const &a, - Array const &b, - bfloat16_t const &c) const { - - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - - unsigned *result_ptr = reinterpret_cast(&result); - - unsigned const *a_ptr = reinterpret_cast(&a); - unsigned const *b_ptr = reinterpret_cast(&b); - - unsigned c_packed = static_cast(c.raw()); - c_packed = (c_packed | (c_packed << 16)); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 2; ++i) { - asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" - : "=r"(result_ptr[i]) - : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_packed) - ); - } - - if (N % 2) { - - uint16_t *result_ptr = reinterpret_cast(&result); - uint16_t const *a_residual_ptr = reinterpret_cast(&a); - uint16_t const *b_residual_ptr = reinterpret_cast(&b); - uint16_t const *c_residual_ptr = reinterpret_cast(&c); - - asm ("fma.rn.bf16 %0, %1, %2, %3;\n" - : "=h"(result_ptr[N - 1]) - : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[0]) - ); - } - - #else - - multiply_add op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - result[i] = op(a[i], b[i], c); - } - #endif - - return result; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - - -template -CUTLASS_HOST_DEVICE -Array operator+(Array const &lhs, Array const &rhs) { - plus> op; - return op(lhs, rhs); -} - -template -CUTLASS_HOST_DEVICE -Array operator-(Array const &lhs, Array const &rhs) { - minus> op; - return op(lhs, rhs); -} - -template -CUTLASS_HOST_DEVICE -Array operator-(Array const &lhs) { - negate> op; - return op(lhs); -} - 
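The long removal in this hunk takes the elementwise Array operator overloads, the half_t/bfloat16_t SIMD specializations, and the fma helpers out of functional.h; presumably the functionality is being relocated rather than dropped, with elementwise Array arithmetic still available through the Array header (cutlass/array.h is an assumption of this note, not something the hunk states). A short host-side sketch under that assumption:

    // Hedged sketch: relies on cutlass/array.h providing the elementwise operators after this move.
    #include <iostream>
    #include "cutlass/array.h"

    int main() {
      cutlass::Array<float, 4> a, b;
      a.fill(2.0f);
      b.fill(3.0f);

      cutlass::Array<float, 4> c = a * b + a;   // elementwise multiply, then elementwise add
      for (int i = 0; i < 4; ++i) {
        std::cout << c[i] << " ";               // prints: 8 8 8 8
      }
      std::cout << std::endl;
      return 0;
    }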
-template -CUTLASS_HOST_DEVICE -Array operator*(Array const &lhs, Array const &rhs) { - multiplies> op; - return op(lhs, rhs); -} - -template -CUTLASS_HOST_DEVICE -Array operator*(T lhs, Array const &rhs) { - multiplies> op; - return op(lhs, rhs); -} - -template -CUTLASS_HOST_DEVICE -Array operator*(Array const &lhs, T rhs) { - multiplies> op; - return op(lhs, rhs); -} - -template -CUTLASS_HOST_DEVICE -Array operator/(Array const &lhs, Array const &rhs) { - divides> op; - return op(lhs, rhs); -} - -template -CUTLASS_HOST_DEVICE -Array fma(Array const &a, Array const &b, Array const &c) { - multiply_add> op; - return op(a, b, c); -} - -template -CUTLASS_HOST_DEVICE -Array fma(T a, Array const &b, Array const &c) { - multiply_add> op; - return op(a, b, c); -} - -template -CUTLASS_HOST_DEVICE -Array fma(Array const &a, T b, Array const &c) { - multiply_add> op; - return op(a, b, c); -} - -template -CUTLASS_HOST_DEVICE -Array fma(Array const &a, Array const &b, T c) { - multiply_add> op; - return op(a, b, c); -} - ///////////////////////////////////////////////////////////////////////////////////////////////// // -// Partial specializations for Quaternion fused multiply-add +// Partial specializations for nvcuda::wmma::fragment // ///////////////////////////////////////////////////////////////////////////////////////////////// -template -struct multiply_add, Quaternion, Quaternion> { +#if defined(CUTLASS_ARCH_WMMA_ENABLED) + +template +struct plus> +{ + using Fragment = nvcuda::wmma::fragment; + using ElementType = typename Fragment::element_type; + CUTLASS_HOST_DEVICE - Quaternion operator()( - Quaternion const &a, - Quaternion const &b, - Quaternion const &c) const { + Fragment operator()(Fragment const &lhs, Fragment const &rhs) const + { + Fragment result; + plus scalar_op; - T x = c.x(); - T y = c.y(); - T z = c.z(); - T w = c.w(); + ElementType *result_elts = reinterpret_cast(&result); + const ElementType *lhs_elts = reinterpret_cast(&lhs); + const ElementType *rhs_elts = reinterpret_cast(&rhs); - x += a.w() * b.x(); - x += b.w() * a.x(); - x += a.y() * b.z(); - x += -a.z() * b.y(), - - y += a.w() * b.y(); - y += b.w() * a.y(); - y += a.z() * b.x(); - y += -a.x() * b.z(); - - z += a.w() * b.z(); - z += b.w() * a.z(); - z += a.x() * b.y(); - z += -a.y() * b.x(); - - w += a.w() * b.w(); - w += -a.x() * b.x(); - w += -a.y() * b.y(); - w += -a.z() * b.z(); - - return cutlass::make_Quaternion(x, y, z, w); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Fragment::num_elements; i++) { + result_elts[i] = scalar_op(lhs_elts[i], rhs_elts[i]); + } + return result; } }; +#endif // defined(CUTLASS_ARCH_WMMA_ENABLED) + + ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/base_grouped.h b/include/cutlass/gemm/device/base_grouped.h index a1ef08eb..d14ee604 100644 --- a/include/cutlass/gemm/device/base_grouped.h +++ b/include/cutlass/gemm/device/base_grouped.h @@ -211,7 +211,7 @@ public: /// Computes the maximum number of active blocks per multiprocessor static int maximum_active_blocks(int smem_capacity = -1) { - CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()"); + CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()"); int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); @@ -349,7 +349,7 @@ public: /// Initializes GEMM state from arguments. 
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace " + CUTLASS_TRACE_HOST("BaseGrouped::initialize() - workspace " << workspace << ", stream: " << (stream ? "non-null" : "null")); // Workspace diff --git a/include/cutlass/gemm/device/default_gemm_configuration.h b/include/cutlass/gemm/device/default_gemm_configuration.h index 55592d2b..303a5cd7 100644 --- a/include/cutlass/gemm/device/default_gemm_configuration.h +++ b/include/cutlass/gemm/device/default_gemm_configuration.h @@ -765,6 +765,52 @@ struct DefaultGemmConfiguration< //////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultGemmConfiguration { + + static int const kAlignmentA = 1; + static int const kAlignmentB = 1; + + using ThreadblockShape = GemmShape<128, 256, 64>; + using WarpShape = GemmShape<64, 64, 64>; + using InstructionShape = GemmShape<16, 8, 4>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombination< + ElementC, 128 / sizeof_bits::value, ElementAccumulator, + ElementAccumulator>; + + using Operator = arch::OpMultiplyAdd; +}; + +template <> +struct DefaultGemmConfiguration< + arch::OpClassTensorOp, + arch::Sm90, + complex, + complex, + complex, + complex + > { + + static int const kAlignmentA = 1; + static int const kAlignmentB = 1; + + using ThreadblockShape = GemmShape<64, 64, 16>; + using WarpShape = GemmShape<32, 32, 16>; + using InstructionShape = GemmShape<16, 8, 4>; + static int const kStages = 3; + + using EpilogueOutputOp = epilogue::thread::LinearCombination< + complex, 1, complex, + complex>; + + using Operator = arch::OpMultiplyAddComplex; +}; + } // namespace device } // namespace gemm } // namespace cutlass diff --git a/include/cutlass/gemm/device/ell_gemm.h b/include/cutlass/gemm/device/ell_gemm.h new file mode 100644 index 00000000..834be95a --- /dev/null +++ b/include/cutlass/gemm/device/ell_gemm.h @@ -0,0 +1,848 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a Block-Ell sparse gemm kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/ell_gemm.h" + +#include "cutlass/gemm/kernel/default_ell_gemm.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! Blocked-Ell sparse gemm device-level operator. This is an interface to efficient CUTLASS + Blocked-Ell kernels that may be invoked from host code. + + The contributions of this class are: + + 1. At compile time, it maps data types and high-level structural parameters onto + specific CUTLASS components. + + 2. At runtime, it maps logical arguments to Blocked-Ell problems to kernel parameters. + + 3. At runtime, it launches kernels on the device. + + Example of a CUTLASS EllGemm operator is as follows: + + // + // Instantiate the CUTLASS EllGemm operator. + // + + cutlass::gemm::device::EllGemm< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::ColumnMajor, + cutlass::half_t, + cutlass::layout::ColumnMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + cutlass::half_t, 128 / cutlass::sizeof_bits::value, + float, float>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 4, // Stages + 128 / cutlass::sizeof_bits::value, // Alignment A + 128 / cutlass::sizeof_bits::value // Alignment B + > ellgemm_op; + + // + // Launch the EllGemm operation on the device + // + + Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format: + a_rows - Rows in the sparse matrix. + a_cols - Colums in the sparse matrix. + BlockedEllA - Packed matrix (ellValue matrix) that stores non-zero values in + consecutive blocks, whose size is (a_rows * a_ell_num_columns) + ell_idx - Blocked-ELL Column indices (ellColInd) matrix, whose size is + (a_rows / a_ell_blocksize) * (a_ell_num_columns / a_ell_blocksize) + a_ell_blocksize - Size of the ELL-Blocks. 
+ a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns) + B - Input dense matrix whose size is (a_cols * n) + C/D - Output dense matrix whose size is (a_rows * n) + + cutlass::Status status = ellgemm_op({ + {a_rows, n, a_cols}, // GemmCoord problem_size + {BlockedEllA, lda}, // TensorRef ref_BlockedEllA + {B, ldb}, // TensorRef ref_B, + {C, ldc}, // TensorRef ref_C, + {D, ldd}, // TensorRef ref_D, + ell_idx, // Blocked-ELL Column indices or ellColInd matrix (const int*) + a_ell_num_columns, // Columns in the Blocked-Ellpack (ellValue) matrix (int) + a_ell_blocksize, // Size of the ELL-Blocks (int) + a_ell_base, // Base index of ellColInd (int) - Zero or One + {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params + }); + + A simplified view of the template is listed below. + + template < + /// Element type for A matrix operand + typename ElementA, + + /// Layout type for A matrix operand + typename LayoutA, + + /// Element type for B matrix operand + typename ElementB, + + /// Layout type for B matrix operand + typename LayoutB, + + /// Element type for C and D matrix operands + typename ElementC, + + /// Layout type for C and D matrix operands + typename LayoutC, + + /// Element type for internal accumulation + typename ElementAccumulator, + + /// Operator class tag + typename OperatorClass, + + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag, + + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + + /// Epilogue output operator + typename EpilogueOutputOp, + + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + + /// Number of stages used in the pipelined mainloop + int Stages + + /// Access granularity of A matrix in units of elements + int AlignmentA, + + /// Access granularity of B matrix in units of elements + int AlignmentB, + + /// Supports split-K with serial reduction + bool SplitKSerial, + + /// Operation performed by GEMM + typename Operator, + + /// Sparse matrix is A or not + bool IsASparse + > + class EllGemm; +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassTensorOp, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm80, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + 
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Sparse matrix is A or not + bool IsASparse = true + > +class EllGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + static bool const kIsASparse = IsASparse; + + /// Define the kernel + using GemmKernel = typename kernel::DefaultEllGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator, + kIsASparse + >::GemmKernel; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + const int* ell_idx; + int ell_ncol; + int ell_blocksize; + int ell_base_idx; + typename EpilogueOutputOp::Params epilogue; + int split_k_slices; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0), split_k_slices(1) { + + } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_C_, + TensorRef ref_D_, + const int* ell_idx_, + int ell_ncol_, + int ell_blocksize_, + int ell_base_idx_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params(), + int split_k_slices = 1 + ): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + 
ref_C(ref_C_), + ref_D(ref_D_), + ell_idx(ell_idx_), + ell_ncol(ell_ncol_), + ell_blocksize(ell_blocksize_), + ell_base_idx(ell_base_idx_), + epilogue(epilogue_), + split_k_slices(split_k_slices) { + + } + }; + +private: + + /// Kernel parameters object + typename GemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + EllGemm() { } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + + Status status = GemmKernel::can_implement( + args.problem_size, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {args.ell_blocksize, + ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + tiled_shape.m() *= (args.ell_blocksize + ThreadblockShape::kM - 1 ) / ThreadblockShape::kM; + + if (kSplitKSerial && args.split_k_slices > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + Status set(Arguments const &args, cutlass::gemm::GemmCoord const &grid_shape, void *workspace){ + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.ell_idx, + args.ell_ncol, + args.ell_blocksize, + args.ell_base_idx, + args.epilogue, + static_cast(workspace) + }; + return Status::kSuccess; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {args.ell_blocksize, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + grid_shape.m() *= (args.ell_blocksize + ThreadblockShape::kM - 1 ) / ThreadblockShape::kM; + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + return set(args, grid_shape, workspace); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A.reset(args.ref_A.non_const_ref().data()); + params_.ref_B.reset(args.ref_B.non_const_ref().data()); + params_.ref_C.reset(args.ref_C.non_const_ref().data()); + params_.ref_D.reset(args.ref_D.data()); + params_.output_op = args.epilogue; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Parital specialization for column-major output exchanges problem size and operand. +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// If true, kernel supports split-K as a serial reduction + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator_, + /// Sparse matrix is A or not + bool IsASparse> +class EllGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; 
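// -----------------------------------------------------------------------------------------------
// Editorial note, illustrative only (not part of this patch): this column-major-C specialization
// relies on the usual CUTLASS transpose identity. Because a column-major D is the row-major view
// of D^T, computing
//
//   D = alpha * (A * B) + beta * C          with D column-major
//
// is the same as computing
//
//   D^T = alpha * (B^T * A^T) + beta * C^T  with D^T row-major,
//
// which is why UnderlyingOperator below swaps the A/B operands, transposes their layouts, and
// exchanges the M and N extents of the problem (see to_underlying_arguments()).
// -----------------------------------------------------------------------------------------------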
+ static bool const kSplitKSerial = SplitKSerial; + static bool const kIsASparse = false; + + using UnderlyingOperator = EllGemm< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA, + SplitKSerial, + Operator, + kIsASparse + >; + + using UnderlyingArguments = typename UnderlyingOperator::Arguments; + using GemmKernel = typename UnderlyingOperator::GemmKernel; + static int const kAlignmentC = UnderlyingOperator::kAlignmentC; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + const int* ell_idx; + int ell_ncol; + int ell_blocksize; + int ell_base_idx; + typename EpilogueOutputOp::Params epilogue; + int split_k_slices; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_C_, + TensorRef ref_D_, + const int* ell_idx_, + int ell_ncol_, + int ell_blocksize_, + int ell_base_idx_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params(), + int split_k_slices = 1 + ): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_C(ref_C_), + ref_D(ref_D_), + ell_idx(ell_idx_), + ell_ncol(ell_ncol_), + ell_blocksize(ell_blocksize_), + ell_base_idx(ell_base_idx_), + epilogue(epilogue_), + split_k_slices(split_k_slices) { } + }; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + EllGemm() { } + + /// Helper to construct a transposed equivalent for the underying GEMM operator + static UnderlyingArguments to_underlying_arguments(Arguments const &args) { + return UnderlyingArguments( + {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, + {args.ref_B.data(), args.ref_B.stride(0)}, + {args.ref_A.data(), args.ref_A.stride(0)}, + {args.ref_C.data(), args.ref_C.stride(0)}, + {args.ref_D.data(), args.ref_D.stride(0)}, + args.ell_idx, + args.ell_ncol, + args.ell_blocksize, + args.ell_base_idx, + args.epilogue, + args.split_k_slices + ); + } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, args.ell_blocksize, ThreadblockShape::kK}, + args.split_k_slices); + + tiled_shape.n() *= (args.ell_blocksize + ThreadblockShape::kN - 1 ) / ThreadblockShape::kN; + + if (kSplitKSerial && args.split_k_slices > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + Status set(Arguments const &args, cutlass::gemm::GemmCoord const &grid_shape, void *workspace){ + // Initialize the Params structure + return underlying_operator_.set(to_underlying_arguments(args), grid_shape, workspace); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()}, + {ThreadblockShape::kM, args.ell_blocksize, ThreadblockShape::kK}, + args.split_k_slices); + + grid_shape.n() *= (args.ell_blocksize + ThreadblockShape::kN - 1 ) / ThreadblockShape::kN; + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + set(args, grid_shape, workspace); + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm_batched.h b/include/cutlass/gemm/device/gemm_batched.h index fe265b5a..3f140627 100644 --- a/include/cutlass/gemm/device/gemm_batched.h +++ b/include/cutlass/gemm/device/gemm_batched.h @@ -29,7 +29,7 @@ * **************************************************************************************************/ /*! \file - \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. 
+ \brief Template for a pipelined batched GEMM kernel. */ #pragma once 
diff --git a/include/cutlass/gemm/device/gemm_universal.h b/include/cutlass/gemm/device/gemm_universal.h index f2a09cd2..7a7aa869 100644 --- a/include/cutlass/gemm/device/gemm_universal.h +++ b/include/cutlass/gemm/device/gemm_universal.h @@ -58,6 +58,11 @@ namespace device { ///////////////////////////////////////////////////////////////////////////////////////////////// /*! + GemmUniversal is a stateful, reusable GEMM handle. Once initialized for a given GEMM computation + (problem geometry and data references), it can be reused across different GEMM problems having the + same geometry. (Once initialized, details regarding problem geometry and references to workspace memory + cannot be updated.) + The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and batched array variants. */ 
diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 69f5ba5e..743a02de 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -109,7 +109,6 @@ public: using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; - using UnderlyingOperator = GemmUniversalBase; using Arguments = typename UnderlyingOperator::Arguments; @@ -160,10 +159,11 @@ public: return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); } - /// Lightweight update given a subset of arguments - Status update(Arguments const &args, void *workspace = nullptr) { + /// Lightweight update given a subset of arguments. Problem geometry is assumed to + /// remain the same. + Status update(Arguments const &args) { - return underlying_operator_.update(to_underlying_arguments(args), workspace); + return underlying_operator_.update(to_underlying_arguments(args)); } /// Runs the kernel using initialized state. 
diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h index f338950c..14b640fa 100644 --- a/include/cutlass/gemm/device/gemm_universal_base.h +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -28,15 +28,15 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ -/*! +/*! \file - \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and - batched array variants. + \brief The universal GEMM accommodates stream-K, batched strided, and batched array variants. 
*/ + #pragma once -//#include +#include #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" @@ -44,7 +44,6 @@ #include "cutlass/device_kernel.h" #include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/threadblock_swizzle.h" #include "cutlass/gemm/kernel/gemm_universal.h" #include "cutlass/gemm/kernel/default_gemm_universal.h" @@ -52,7 +51,7 @@ #include "cutlass/trace.h" -//////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { namespace gemm { @@ -67,7 +66,7 @@ public: using GemmKernel = GemmKernel_; using ThreadblockShape = typename GemmKernel::Mma::Shape; - + using ElementA = typename GemmKernel::ElementA; using LayoutA = typename GemmKernel::LayoutA; using TensorRefA = TensorRef; @@ -83,7 +82,8 @@ public: using TensorRefC = TensorRef; using TensorRefD = TensorRef; - using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; + /// Numerical accumulation element type + using ElementAccumulator = typename GemmKernel::Mma::ElementC; using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; @@ -94,316 +94,285 @@ public: protected: - /// Kernel parameters object - typename GemmKernel::Params params_; + // + // Device properties (uniform across all instances of the current thread) + // -protected: + // Device ordinal + thread_local static int device_ordinal_; - /// Private helper to obtain the grid dimensions with fix-up for split-K - static void get_grid_shape_(gemm::GemmCoord &grid_tiled_shape, int &gemm_k_size, Arguments const &args) { + /// Device SM count + thread_local static int device_sms_; - // Determine grid shape - ThreadblockSwizzle threadblock_swizzle; + /// Kernel SM occupancy (in thread blocks) + thread_local static int sm_occupancy_; - grid_tiled_shape = threadblock_swizzle.get_tiled_shape( - args.problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, - args.batch_count); - - gemm_k_size = args.problem_size.k(); - if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + /// Initialize static thread-local members for the thread's current device, + /// if necessary. + static Status init_device_props() + { + CUTLASS_TRACE_HOST("GemmUniversalBase::init_device_props()"); - int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + cudaError_t cudart_result; - gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); - - if (gemm_k_size) { - grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); - } - } - } - -public: - - /// Constructs the GEMM. - GemmUniversalBase() { } - - /// Determines whether the GEMM can execute the given problem. 
- static Status can_implement(Arguments const &args) { - - // Determine grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; - - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - - ThreadblockSwizzle threadblock_swizzle; - dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); - - uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); - - if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) { - - return Status::kErrorInvalidProblem; - } - - return GemmKernel::can_implement(args); - } - - /// Gets the workspace size - static size_t get_workspace_size(Arguments const &args) { - - CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()"); - - size_t workspace_bytes = 0; - - // Determine grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; - - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - - if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { - - // Split-K parallel always requires a temporary workspace - workspace_bytes = - sizeof(ElementC) * - size_t(args.batch_stride_D) * - size_t(grid_tiled_shape.k()); - } - else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) { - - // Serial split-K only requires a temporary workspace if the number of partitions along the - // GEMM K dimension is greater than one. - workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + // Get current device ordinal + int current_ordinal; + cudart_result = cudaGetDevice(¤t_ordinal); + if (cudart_result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(cudart_result)); + return Status::kErrorInternal; } - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + // Done if matches the current static member + if (current_ordinal == device_ordinal_) { + // Already initialized + return Status::kSuccess; + } - workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); - - return workspace_bytes; - } + // Update SM count member + cudart_result = cudaDeviceGetAttribute (&device_sms_, cudaDevAttrMultiProcessorCount, current_ordinal); + if (cudart_result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(cudart_result)); + return Status::kErrorInternal; + } - /// Computes the grid shape - static dim3 get_grid_shape(Arguments const &args) { - - CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()"); - - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; - - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); - - CUTLASS_TRACE_HOST( - " grid_tiled_shape: " << grid_tiled_shape << "\n" - << " result = {" << result << "}"); - - return result; - } - - /// Computes the maximum number of active blocks per multiprocessor - static int maximum_active_blocks(int smem_capacity = -1) { - - CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()"); - - int max_active_blocks = -1; + // Update the kernel function's shared memory configuration for the current device int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) + { + // Requires more than 48KB: configure for extended, dynamic shared memory - CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); - - if (smem_size <= (48 << 10)) { - - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, 
- Kernel, - GemmKernel::kThreadCount, + cudart_result = cudaFuncSetAttribute( + Kernel2, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result == cudaSuccess) { - CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); - return max_active_blocks; - } - } - else { - - // Query assuming zero shared memory then compute occupancy limit based on SMEM - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, - Kernel, - GemmKernel::kThreadCount, - 0); - - if (result != cudaSuccess) { - - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " - << cudaGetErrorString(result)); - - return -1; + if (cudart_result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result)); + return Status::kErrorInternal; } - if (smem_capacity < 0) { - int device_idx = 0; - result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - return -1; - } - - cudaDeviceProp properties; - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - return -1; - } - - smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); - } - - int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); - - CUTLASS_TRACE_HOST(" occupancy: " << occupancy); - - return occupancy; - } - - CUTLASS_TRACE_HOST(" returning internal error"); - - return -1; - } - - /// Initializes GEMM state from arguments. - Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - - CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null")); - - size_t workspace_bytes = get_workspace_size(args); - - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - - if (workspace_bytes) { - - if (!workspace) { - CUTLASS_TRACE_HOST(" error: device workspace must not be null"); - - return Status::kErrorWorkspaceNull; - } - - if (args.mode == GemmUniversalMode::kGemm) { - CUTLASS_TRACE_HOST(" clearing device workspace"); - cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); - - if (result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); - - return Status::kErrorInternal; - } - } - } - - // Get CUDA grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; - - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - - // Initialize the Params structure - params_ = typename GemmKernel::Params( - args, - grid_tiled_shape, - gemm_k_size, - static_cast(workspace) - ); - - // Specify shared memory capacity for kernel. 
- int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - if (smem_size >= (48 << 10)) { - cudaError_t result = cudaFuncSetAttribute(Kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - - if (result != cudaSuccess) { + cudart_result = cudaFuncSetAttribute( + Kernel2, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); // 100% shared memory + if (cudart_result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result)); return Status::kErrorInternal; } } - return Status::kSuccess; - } - - /// Lightweight update given a subset of arguments - Status update(Arguments const &args, void *workspace = nullptr) { - - CUTLASS_TRACE_HOST("GemmUniversalBase()::update() - workspace: " << workspace); - - size_t workspace_bytes = get_workspace_size(args); - - if (workspace_bytes && !workspace) { - return Status::kErrorWorkspaceNull; + // Update SM occupancy member + cudart_result = cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + &sm_occupancy_, + Kernel2, + GemmKernel::kThreadCount, + int(sizeof(typename GemmKernel::SharedStorage)), + cudaOccupancyDisableCachingOverride); + if (cudart_result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result)); + return Status::kErrorInternal; } - - params_.update(args, workspace); - + + // Update device ordinal member on success + device_ordinal_ = current_ordinal; + + CUTLASS_TRACE_HOST(" " + "device_ordinal: (" << device_ordinal_ << "), " + "device_sms: (" << device_sms_ << "), " + "sm_occupancy: (" << sm_occupancy_ << ")"); + return Status::kSuccess; } + +protected: + + // + // Instance data members + // + + /// Kernel parameters + typename GemmKernel::Params params_; + + + /// Initialize params member + Status init_params(Arguments const &args) + { + // Initialize static device properties, if necessary + Status result = init_device_props(); + if (result != Status::kSuccess) { + return result; + } + + // Initialize params member + params_ = typename GemmKernel::Params(args, device_sms_, sm_occupancy_); + return Status::kSuccess; + } + +public: + + //--------------------------------------------------------------------------------------------- + // Stateless API + //--------------------------------------------------------------------------------------------- + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) + { + CUTLASS_TRACE_HOST("GemmUniversalBase::can_implement()"); + + // Initialize static kernel and device properties, if necessary. 
+ Status result = init_device_props(); + if (result != Status::kSuccess) { + return result; + } + + dim3 grid = get_grid_shape(args); + + if (!(grid.y <= std::numeric_limits::max() && + grid.z <= std::numeric_limits::max())) + { + return Status::kErrorInvalidProblem; + } + + return GemmKernel::can_implement(args); + } + + + /// Returns the workspace size (in bytes) needed for the problem + /// geometry expressed by these arguments + static size_t get_workspace_size(Arguments const &args) + { + CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()"); + + // Initialize parameters from args + GemmUniversalBase base; + if (base.init_params(args) != Status::kSuccess) { + return 0; + } + + // Get size from parameters + size_t workspace_bytes = base.params_.get_workspace_size(); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + return workspace_bytes; + } + + + /// Returns the grid extents in thread blocks to launch + static dim3 get_grid_shape(Arguments const &args) + { + CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()"); + + // Initialize parameters from args + GemmUniversalBase base; + if (base.init_params(args) != Status::kSuccess) { + return dim3(0,0,0); + } + + // Get dims from parameters + dim3 grid_dims = base.params_.get_grid_dims(); + + CUTLASS_TRACE_HOST( + " tiled_shape: " << base.params_.get_tiled_shape() << "\n" + << " grid_dims: {" << grid_dims << "}"); + + return grid_dims; + } + + + /// Returns the maximum number of active thread blocks per multiprocessor + static int maximum_active_blocks() + { + CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()"); + + // Initialize static device properties, if necessary + if (init_device_props() != Status::kSuccess) { + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_); + return sm_occupancy_; + } + + + //--------------------------------------------------------------------------------------------- + // Stateful API + //--------------------------------------------------------------------------------------------- + + /// Initializes GEMM state from arguments and workspace memory + Status initialize( + Arguments const &args, + void *workspace, + cudaStream_t stream = nullptr) + { + CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize parameters from args + Status result = init_params(args); + if (result != Status::kSuccess) { + return result; + } + + // Assign and prepare workspace memory + return params_.init_workspace(workspace, stream); + } + + + /// Lightweight update given a subset of arguments. Problem geometry is assumed to + /// remain the same. + Status update(Arguments const &args) + { + CUTLASS_TRACE_HOST("GemmUniversalBase()::update()"); + params_.update(args); + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
- Status run(cudaStream_t stream = nullptr) { + Status run(cudaStream_t stream = nullptr) + { CUTLASS_TRACE_HOST("GemmUniversalBase::run()"); - // // Configure grid and block dimensions - // - - ThreadblockSwizzle threadblock_swizzle; - - dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); - dim3 block(GemmKernel::kThreadCount, 1, 1); - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + dim3 block(GemmKernel::kThreadCount, 1, 1); + dim3 grid = params_.get_grid_dims(); - // // Launch kernel - // + CUTLASS_TRACE_HOST(" " + "grid: (" << grid << "), " + "block: (" << block << "), " + "SMEM: (" << smem_size << ")"); - CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block - << "), SMEM: " << smem_size << " bytes"); + Kernel2<<>>(params_); - // Launch - cutlass::Kernel<<>>(params_); - - // // Query for errors - // cudaError_t result = cudaGetLastError(); - if (result != cudaSuccess) { CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); return Status::kErrorInternal; } - + return Status::kSuccess; } + /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr) { + Status operator()(cudaStream_t stream = nullptr) + { return run(stream); } + /// Runs the kernel using initialized state. Status operator()( Arguments const &args, void *workspace = nullptr, - cudaStream_t stream = nullptr) { - + cudaStream_t stream = nullptr) + { Status status = initialize(args, workspace, stream); - + if (status == Status::kSuccess) { status = run(stream); } @@ -412,6 +381,24 @@ public: } }; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Static initializers +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Device ordinal +template +thread_local int GemmUniversalBase::device_ordinal_ = -1; + +/// Device SM count +template +thread_local int GemmUniversalBase::device_sms_ = -1; + +/// Kernel SM occupancy (in thread blocks) +template +thread_local int GemmUniversalBase::sm_occupancy_ = -1; + + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace device diff --git a/include/cutlass/gemm/device/gemm_universal_with_broadcast.h b/include/cutlass/gemm/device/gemm_universal_with_broadcast.h index a0cd5a7d..ad52dd82 100644 --- a/include/cutlass/gemm/device/gemm_universal_with_broadcast.h +++ b/include/cutlass/gemm/device/gemm_universal_with_broadcast.h @@ -28,8 +28,10 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ + /*! \file - \brief + \brief Template for a GEMM kernel that can broadcast bias vector in the + epigloue. 
*/ #pragma once @@ -45,7 +47,7 @@ #include "cutlass/gemm/kernel/gemm_universal.h" #include "cutlass/gemm/kernel/default_gemm_universal.h" -#include "cutlass/gemm/kernel/default_gemm_with_broadcast_v2.h" +#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h" #include "cutlass/gemm/device/default_gemm_configuration.h" #include "cutlass/gemm/device/gemm_universal_base.h" @@ -97,7 +99,7 @@ template < /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' typename EpilogueOutputOp_ = cutlass::epilogue::thread::LinearCombinationBiasElementwise< ElementC_, ElementAccumulator_, ElementAccumulator_, - ElementC_, ElementC_, 16 / sizeof(ElementC_)>, + ElementC_, ElementC_, 128 / cutlass::sizeof_bits::value>, /// Threadblock-level swizzling operator typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, /// Number of stages used in the pipelined mainloop @@ -123,7 +125,7 @@ template < > class GemmUniversalWithBroadcast : public GemmUniversalBase< - typename kernel::DefaultGemmWithBroadcastV2< + typename kernel::DefaultGemmWithBroadcast< ElementA_, LayoutA_, TransformA, @@ -166,7 +168,7 @@ class GemmUniversalWithBroadcast : static ComplexTransform const kTransformB = TransformB; using Base = GemmUniversalBase< - typename kernel::DefaultGemmWithBroadcastV2< + typename kernel::DefaultGemmWithBroadcast< ElementA_, LayoutA_, TransformA, diff --git a/include/cutlass/gemm/device/gemm_with_k_reduction.h b/include/cutlass/gemm/device/gemm_with_k_reduction.h index 254a7f96..1f99b3f0 100644 --- a/include/cutlass/gemm/device/gemm_with_k_reduction.h +++ b/include/cutlass/gemm/device/gemm_with_k_reduction.h @@ -29,7 +29,8 @@ * **************************************************************************************************/ /*! \file - \brief + \brief Template for a GEMM kernel that can reduce one of the input matrix + into a vector along the K dimension. */ #pragma once diff --git a/include/cutlass/gemm/kernel/default_ell_gemm.h b/include/cutlass/gemm/kernel/default_ell_gemm.h new file mode 100644 index 00000000..c5b676bc --- /dev/null +++ b/include/cutlass/gemm/kernel/default_ell_gemm.h @@ -0,0 +1,837 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Default kernel-level Blocked-Ell sparse gemm operators. + This operator combines threadblock-scoped ELL MMA + with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm.h" +#include "cutlass/gemm/kernel/gemm_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) +#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" +#endif //CUTLASS_ARCH_WMMA_ENABLED + +#include "cutlass/gemm/kernel/ell_gemm.h" +#include "cutlass/gemm/threadblock/default_ell_mma.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to 
support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Sparse matrix is A or not + bool IsASparse> +struct DefaultEllGemm; + +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Sparse matrix is A or not + bool IsASparse +> +struct DefaultEllGemm { + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount>::Epilogue; + + /// Define the kernel-level GEMM operator. 
+ using GemmKernel = kernel::EllGemm; +}; +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Turing Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// If true, kernel is configured to support serial reduction in the epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Sparse matrix is A or not + bool IsASparse +> +struct DefaultEllGemm< + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementC, layout::RowMajor, + ElementAccumulator, + arch::OpClassTensorOp, + arch::Sm75, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + SplitKSerial, + Operator, + IsASparse +> { + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + arch::OpClassTensorOp, + arch::Sm75, + ThreadblockShape, + WarpShape, + InstructionShape, + 2, + Operator + >::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + typename Mma::Operator, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + /// Define the kernel-level GEMM operator. 
+ using GemmKernel = kernel::EllGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout +template < + /// Element type for A matrix operand + typename ElementA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Number of Interleaved k + int InterleavedK, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Sparse matrix is A or not + bool IsASparse> +struct DefaultEllGemm< + ElementA, layout::ColumnMajorInterleaved, kAlignmentA, + ElementB, layout::RowMajorInterleaved, kAlignmentB, ElementC, + layout::ColumnMajorInterleaved, int32_t, + arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, + InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, + SplitKSerial, Operator, IsASparse> { + using LayoutA = layout::ColumnMajorInterleaved; + using LayoutB = layout::RowMajorInterleaved; + using LayoutC = layout::ColumnMajorInterleaved; + + using ElementAccumulator = int32_t; + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, Stages, Operator, + true>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock:: + DefaultInterleavedEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + 64 / sizeof_bits::value, InterleavedK>::Epilogue; + + /// Define the kernel-level GEMM operator. 
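The interleaved specializations above target integer tensor-op kernels whose operands are stored K-interleaved. A small sketch of the layout aliases involved follows; the interleave factor 32 and the int8_t element type are illustrative assumptions, and the epilogue access width mirrors the 64 / sizeof_bits expression used above (whose template argument is presumably ElementC).

#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

// Illustrative K-interleaved int8 layouts with int32_t accumulation.
using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>;
using LayoutB = cutlass::layout::RowMajorInterleaved<32>;
using LayoutC = cutlass::layout::ColumnMajorInterleaved<32>;

// Epilogue vector width: 64 bits per access / 8 bits per int8_t = 8 elements.
constexpr int kElementsPerAccess = 64 / cutlass::sizeof_bits<int8_t>::value;
static_assert(kElementsPerAccess == 8, "64-bit stores of int8_t move 8 elements");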
+ using GemmKernel = kernel::EllGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Turing Integer Matrix Multiply Interleaved layout +template < + /// Element type for A matrix operand + typename ElementA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of Interleaved k + int InterleavedK, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Sparse matrix is A or not + bool IsASparse> +struct DefaultEllGemm, + kAlignmentA, ElementB, + layout::RowMajorInterleaved, kAlignmentB, + ElementC, layout::ColumnMajorInterleaved, + int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, + WarpShape, InstructionShape, EpilogueOutputOp, + ThreadblockSwizzle, 2, SplitKSerial, Operator, IsASparse> { + using LayoutA = layout::ColumnMajorInterleaved; + using LayoutB = layout::RowMajorInterleaved; + using LayoutC = layout::ColumnMajorInterleaved; + + using ElementAccumulator = int32_t; + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, + arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, + InstructionShape, 2, Operator, true>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock:: + DefaultInterleavedEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + 64 / sizeof_bits::value, InterleavedK>::Epilogue; + + /// Define the kernel-level GEMM operator. 
+ using GemmKernel = kernel::EllGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + + +/// Partial specialization for Volta architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// If true, kernel is configured to support serial reduction in the epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Sparse matrix is A or not + bool IsASparse +> +struct DefaultEllGemm< + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementC, layout::RowMajor, + ElementAccumulator, + arch::OpClassTensorOp, + arch::Sm70, + ThreadblockShape, + WarpShape, + GemmShape<8, 8, 4>, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + SplitKSerial, + Operator, + IsASparse +> { + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + arch::OpClassTensorOp, + arch::Sm70, + ThreadblockShape, + WarpShape, + GemmShape<8, 8, 4>, + 2, + Operator + >::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< + ThreadblockShape, + typename Mma::Operator, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + /// Define the kernel-level GEMM operator. 
+ using GemmKernel = kernel::EllGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for SIMT +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// If true, kernel is configured to support serial reduction in the epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Sparse matrix is A or not + bool IsASparse + > +struct DefaultEllGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + layout::RowMajor, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + GemmShape<1, 1, 1>, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + SplitKSerial, + Operator, + IsASparse> { + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + arch::OpClassSimt, + arch::Sm50, + ThreadblockShape, + WarpShape, + GemmShape<1, 1, 1>, + 2, + Operator>::ThreadblockMma; + + static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; + static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + typename Mma::Operator, + EpilogueOutputOp, + kEpilogueElementsPerAccess + >::Epilogue; + + /// Define the kernel-level GEMM operator. 
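The static_assert in the SIMT specializations requires the epilogue to operate on scalars, so the epilogue output operator must be instantiated with an access width of 1. A minimal example of a conforming operator, with float chosen purely for illustration:

#include "cutlass/epilogue/thread/linear_combination.h"

// kCount == 1 satisfies "simt epilogue must operate on scalars" above.
using SimtEpilogueOp =
    cutlass::epilogue::thread::LinearCombination<float, 1, float, float>;
static_assert(SimtEpilogueOp::kCount == 1,
              "SIMT epilogues store one element per access");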
+ using GemmKernel = kernel::EllGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages + int Stages, + /// If true, kernel is configured to support serial reduction in the epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Sparse matrix is A or not + bool IsASparse + > +struct DefaultEllGemm, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + SplitKSerial, + Operator, + IsASparse> { + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassSimt, arch::Sm80, + ThreadblockShape, WarpShape, GemmShape<1, 1, 1>, Stages, + Operator>::ThreadblockMma; + + static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; + static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + typename Mma::Operator, + EpilogueOutputOp, + kEpilogueElementsPerAccess + >::Epilogue; + + /// Define the kernel-level GEMM operator. 
+ using GemmKernel = kernel::EllGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for SIMT DP4A + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Layout type for C matrix operand + typename LayoutC, + /// Element type for C and D matrix operands + typename ElementC, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Sparse matrix is A or not + bool IsASparse + > +struct DefaultEllGemm, + EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, + Operator, IsASparse> { + using InstructionShape = GemmShape<1, 1, 4>; + using ElementA = int8_t; + using ElementB = int8_t; + + using OperatorClass = arch::OpClassSimt; + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultEllMma::ThreadblockMma; + + static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount; + static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars"); + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + typename Mma::Operator, + EpilogueOutputOp, + kEpilogueElementsPerAccess + >::Epilogue; + + /// Define the kernel-level GEMM operator. 
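The GemmShape<1, 1, 4> instruction shape in the DP4A specialization above corresponds to the sm_61+ dp4a instruction, a four-wide int8 dot product accumulated into a 32-bit integer. A standalone CUDA illustration of that primitive (not CUTLASS code):

// Each thread multiplies four packed int8 lanes of A and B and accumulates
// the result into a 32-bit integer, the scalar building block of this path.
__global__ void dp4a_example(int const *a_packed, int const *b_packed, int *c, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    c[i] = __dp4a(a_packed[i], b_packed[i], c[i]);  // requires sm_61 or newer
  }
}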
+ using GemmKernel = kernel::EllGemm; +}; + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) +//////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Wmma Gemm Kernel +template < + ///< Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Sparse matrix is A or not + bool IsASparse + > +struct DefaultEllGemm< + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassWmmaTensorOp, + ArchTag, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + SplitKSerial, + Operator, + IsASparse> { + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultEllMma< + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementAccumulator, LayoutC, + arch::OpClassWmmaTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + Stages, + Operator>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp< + ThreadblockShape, + typename Mma::Operator, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + /// Define the kernel-level GEMM operator. 
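The WMMA specialization guarded by CUTLASS_ARCH_WMMA_ENABLED ultimately maps the warp-level MMA onto the nvcuda::wmma API. A bare-bones sketch of that API for a single 16x16x16 half-precision tile, given only as an illustration of what OpClassWmmaTensorOp lowers to:

#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

// One warp computes a 16x16x16 fragment product: C = A * B.
__global__ void wmma_tile(half const *A, half const *B, float *C) {
  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

  wmma::fill_fragment(c_frag, 0.0f);
  wmma::load_matrix_sync(a_frag, A, 16);   // leading dimension 16
  wmma::load_matrix_sync(b_frag, B, 16);
  wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
  wmma::store_matrix_sync(C, c_frag, 16, wmma::mem_row_major);
}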
+ using GemmKernel = kernel::EllGemm; +}; +//////////////////////////////////////////////////////////////////////////////// +#endif //CUTLASS_ARCH_WMMA_ENABLED + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/include/cutlass/gemm/kernel/default_gemm.h b/include/cutlass/gemm/kernel/default_gemm.h index 38efae7e..629bccab 100644 --- a/include/cutlass/gemm/kernel/default_gemm.h +++ b/include/cutlass/gemm/kernel/default_gemm.h @@ -135,6 +135,77 @@ template < struct DefaultGemm; //////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Hopper Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Scatter result D by using an index array + bool ScatterD, + /// Permute result D + typename PermuteDLayout +> +struct DefaultGemm { + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator, false, SharedMemoryClear, GatherA, GatherB>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount, ScatterD, PermuteDLayout>::Epilogue; + + /// Define the kernel-level GEMM operator. 
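The Sm90 specialization added above mirrors the Ampere path: it reuses the multistage tensor-op mainloop and DefaultEpilogueTensorOp, only with arch::Sm90 as the tuning tag. At the device level one would typically reach it by naming arch::Sm90 as the architecture tag. A hedged sketch; the types, tiles, and stage count are illustrative, and the leading template-parameter order of device::GemmUniversal is assumed here, so verify it against the headers.

#include "cutlass/arch/arch.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

using Sm90Gemm = cutlass::gemm::device::GemmUniversal<
    cutlass::half_t, cutlass::layout::RowMajor,       // A
    cutlass::half_t, cutlass::layout::ColumnMajor,    // B
    cutlass::half_t, cutlass::layout::RowMajor,       // C / D
    float,                                            // accumulator
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
    cutlass::gemm::GemmShape<128, 128, 32>,           // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,             // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,              // mma instruction
    cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4>;                                               // stages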
+ using GemmKernel = kernel::Gemm; +}; + //////////////////////////////////////////////////////////////////////////////// /// Partial specialization for Ampere Architecture diff --git a/include/cutlass/gemm/kernel/default_gemm_complex.h b/include/cutlass/gemm/kernel/default_gemm_complex.h index 1fad1412..e26fcb45 100644 --- a/include/cutlass/gemm/kernel/default_gemm_complex.h +++ b/include/cutlass/gemm/kernel/default_gemm_complex.h @@ -119,6 +119,66 @@ struct DefaultGemmComplex; //////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Hopper Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Multiply-add operator + // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) + typename Operator, + /// If true, kernel is configured to support serial reduction in the epilogue + bool SplitKSerial + > +struct DefaultGemmComplex< + ElementA, LayoutA, ElementB, LayoutB, ElementC, + layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, + arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> { + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< + ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, + layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, ThreadblockShape, + WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< + ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, + EpilogueOutputOp::kCount, Operator>::Epilogue; + + /// Define the kernel-level GEMM operator. 
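The Operator parameter of DefaultGemmComplex distinguishes the two complex multiply-add strategies named in its comment, arch::OpMultiplyAddComplex and arch::OpMultiplyGaussianComplex. Per scalar update they differ only in how one complex product is expanded; a plain C++ illustration, with std::complex used purely for exposition:

#include <complex>

// Straightforward expansion: 4 real multiplies per complex multiply-accumulate.
std::complex<float> cmla_4mul(std::complex<float> a, std::complex<float> b,
                              std::complex<float> c) {
  float re = a.real() * b.real() - a.imag() * b.imag() + c.real();
  float im = a.real() * b.imag() + a.imag() * b.real() + c.imag();
  return {re, im};
}

// Gauss-style expansion: 3 real multiplies at the cost of extra additions.
std::complex<float> cmla_gauss(std::complex<float> a, std::complex<float> b,
                               std::complex<float> c) {
  float k1 = b.real() * (a.real() + a.imag());
  float k2 = a.real() * (b.imag() - b.real());
  float k3 = a.imag() * (b.real() + b.imag());
  return {k1 - k3 + c.real(), k1 + k2 + c.imag()};
}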
+ using GemmKernel = kernel::Gemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + /// Partial specialization for Ampere Architecture template < /// Element type for A matrix operand diff --git a/include/cutlass/gemm/kernel/default_gemm_universal.h b/include/cutlass/gemm/kernel/default_gemm_universal.h index be9634e9..051e71a6 100644 --- a/include/cutlass/gemm/kernel/default_gemm_universal.h +++ b/include/cutlass/gemm/kernel/default_gemm_universal.h @@ -49,6 +49,7 @@ #include "cutlass/numeric_types.h" #include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/kernel/gemm_universal_streamk.h" #include "cutlass/gemm/kernel/default_gemm.h" #include "cutlass/gemm/kernel/default_gemm_complex.h" @@ -227,12 +228,26 @@ struct DefaultGemmUniversal< PermuteDLayout >::GemmKernel; - /// Define the kernel in terms of the default kernel - using GemmKernel = kernel::GemmUniversal< - typename DefaultGemmKernel::Mma, - typename DefaultGemmKernel::Epilogue, - ThreadblockSwizzle - >; + /// Universal kernel without StreamkFeature member type + template + class SelectBase : + public kernel::GemmUniversal< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, + SwizzleT> + {}; + + /// Universal kernel with StreamkFeature member type + template + class SelectBase : + public kernel::GemmUniversalStreamk< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, + SwizzleT> + {}; + + /// Select kernel by ThreadblockSwizzle's support for StreamkFeature + using GemmKernel = SelectBase; }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -336,12 +351,26 @@ struct DefaultGemmUniversal< false >::GemmKernel; - /// Define the kernel in terms of the default kernel - using GemmKernel = kernel::GemmUniversal< - typename DefaultGemmKernel::Mma, - typename DefaultGemmKernel::Epilogue, - ThreadblockSwizzle - >; + /// Universal kernel without StreamkFeature member type + template + class SelectBase : + public kernel::GemmUniversal< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, + SwizzleT> + {}; + + /// Universal kernel with StreamkFeature member type + template + class SelectBase : + public kernel::GemmUniversalStreamk< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, + SwizzleT> + {}; + + /// Select kernel by ThreadblockSwizzle's support for StreamkFeature + using GemmKernel = SelectBase; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm_with_broadcast_v2.h b/include/cutlass/gemm/kernel/default_gemm_with_broadcast_v2.h deleted file mode 100644 index e1f1d8c3..00000000 --- a/include/cutlass/gemm/kernel/default_gemm_with_broadcast_v2.h +++ /dev/null @@ -1,242 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. 
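Stepping back to the change in default_gemm_universal.h above: the SelectBase pair picks kernel::GemmUniversalStreamk whenever the threadblock swizzle exposes a StreamkFeature member type, and falls back to kernel::GemmUniversal otherwise. A self-contained sketch of that detection idiom; all names here are stand-ins for illustration, not the library's types.

#include <type_traits>

struct GemmUniversalKernel {};         // stands in for kernel::GemmUniversal<...>
struct GemmUniversalStreamkKernel {};  // stands in for kernel::GemmUniversalStreamk<...>

struct BasicSwizzle {};                                  // no StreamkFeature member
struct StreamkSwizzle { using StreamkFeature = void; };  // advertises stream-K support

template <typename SwizzleT, typename Enable = void>
struct SelectBase : GemmUniversalKernel {};              // default choice

template <typename SwizzleT>
struct SelectBase<SwizzleT, typename SwizzleT::StreamkFeature>
    : GemmUniversalStreamkKernel {};                     // chosen only if the member type exists (as void)

static_assert(std::is_base_of<GemmUniversalKernel,
                              SelectBase<BasicSwizzle>>::value,
              "plain swizzle selects the universal kernel");
static_assert(std::is_base_of<GemmUniversalStreamkKernel,
                              SelectBase<StreamkSwizzle>>::value,
              "stream-K swizzle selects the stream-K kernel");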
Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Defines a GEMM with Reduction based on an existing UniversalGemm kernel. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h" -#include "cutlass/gemm/kernel/default_gemm_universal.h" - -#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast_v2.h" -#include "cutlass/epilogue/threadblock/epilogue_with_broadcast_v2.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Complex elementwise transformation on A operand - ComplexTransform TransformA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Complex elementwise transformation on B operand - ComplexTransform TransformB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC_, - /// Layout type for C and D matrix operands - typename LayoutC_, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Warp-level tile size (concept: GemmShape) - typename InstructionShape, - /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' - typename EpilogueOutputOp, - /// Threadblock-level swizzling operator - typename ThreadblockSwizzle, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Operation performed by GEMM - typename Operator, - /// - typename Enable = void -> -struct DefaultGemmWithBroadcastV2 { - - using GemmBase = 
typename DefaultGemmUniversal< - ElementA_, LayoutA_, TransformA, kAlignmentA, - ElementB_, LayoutB_, TransformB, kAlignmentB, - ElementC_, LayoutC_, ElementAccumulator, - OperatorClass, - ArchTag, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - Operator - >::GemmKernel; - - // Replace epilogue - using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOpV2< - typename GemmBase::Epilogue::Shape, - typename GemmBase::Epilogue::WarpMmaOperator, - GemmBase::Epilogue::kPartitionsK, - ElementC_, - typename EpilogueOutputOp::ElementT, - ElementC_, - EpilogueOutputOp, - GemmBase::Epilogue::kElementsPerAccess - >::Epilogue; - - // Compose the GEMM kernel - using GemmKernel = GemmWithFusedEpilogueV2< - typename GemmBase::Mma, - Epilogue, - ThreadblockSwizzle - >; -}; - - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Parital specialization: ArchTag = cutlass::arch::Sm70 -/// -/// -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Complex elementwise transformation on A operand - ComplexTransform TransformA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Complex elementwise transformation on B operand - ComplexTransform TransformB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC_, - /// Layout type for C and D matrix operands - typename LayoutC_, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Warp-level tile size (concept: GemmShape) - typename InstructionShape, - /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' - typename EpilogueOutputOp, - /// Threadblock-level swizzling operator - typename ThreadblockSwizzle, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Operation performed by GEMM - typename Operator, - /// - typename Enable -> -struct DefaultGemmWithBroadcastV2< - ElementA_, LayoutA_, TransformA, kAlignmentA, - ElementB_, LayoutB_, TransformB, kAlignmentB, - ElementC_, LayoutC_, - ElementAccumulator, - OperatorClass, - cutlass::arch::Sm70, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - Operator, - Enable - > { - - using GemmBase = typename DefaultGemmUniversal< - ElementA_, LayoutA_, TransformA, kAlignmentA, - ElementB_, LayoutB_, TransformB, kAlignmentB, - ElementC_, LayoutC_, ElementAccumulator, - OperatorClass, - cutlass::arch::Sm70, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOutputOp, - ThreadblockSwizzle, - Stages, - Operator - >::GemmKernel; - - // Replace epilogue - using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithBroadcastVoltaTensorOpV2< - typename GemmBase::Epilogue::Shape, - typename GemmBase::Epilogue::WarpMmaOperator, - GemmBase::Epilogue::kPartitionsK, - ElementC_, - typename EpilogueOutputOp::ElementT, - ElementC_, - EpilogueOutputOp, - GemmBase::Epilogue::kElementsPerAccess - >::Epilogue; - 
- // Compose the GEMM kernel - using GemmKernel = GemmWithFusedEpilogueV2< - typename GemmBase::Mma, - Epilogue, - ThreadblockSwizzle - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_rank_2k.h b/include/cutlass/gemm/kernel/default_rank_2k.h index d421d85d..7dcc93cc 100644 --- a/include/cutlass/gemm/kernel/default_rank_2k.h +++ b/include/cutlass/gemm/kernel/default_rank_2k.h @@ -120,6 +120,84 @@ template < BlasMode BlasMode_ = BlasMode::kSymmetric> struct DefaultRank2K; +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Hopper Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Fill Mode for C (kLower or kUpper) + FillMode FillModeC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator> +struct DefaultRank2K< + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementC,layout::RowMajor, FillModeC, + ElementAccumulator, arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, + Operator> { + /// Define the threadblock-scoped matrix multiply-accumulate (A x BT) + using Mma1 = typename cutlass::gemm::threadblock::DefaultMma< + ElementA, LayoutA, + kAlignmentA, + ElementB, typename layout::LayoutTranspose::type, + kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator>::ThreadblockMma; + + /// Define the threadblock-scoped matrix multiply-accumulate (B x AT) + using Mma2 = typename cutlass::gemm::threadblock::DefaultMma< + ElementB, LayoutB, + kAlignmentB, + ElementA, typename layout::LayoutTranspose::type, + kAlignmentA, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpBlas3< + ThreadblockShape, typename Mma1::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount, BlasMode::kSymmetric>::Epilogue; + + /// 
Define the kernel-level Rank2K operator. + using Rank2Kkernel = kernel::Rank2KUniversal; +}; //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_rank_2k_complex.h b/include/cutlass/gemm/kernel/default_rank_2k_complex.h index 8255cef2..eb93ef58 100644 --- a/include/cutlass/gemm/kernel/default_rank_2k_complex.h +++ b/include/cutlass/gemm/kernel/default_rank_2k_complex.h @@ -163,6 +163,170 @@ template <> }; } + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Hopper Architecture complex datatype (symmetric) +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Element type for C and D matrix operands + typename ElementC, + /// Fill Mode for C (kLower or kUpper) + FillMode FillModeC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Operation performed by GEMM + typename Operator, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial> +struct DefaultRank2KComplex< + ElementA, LayoutA, ElementB, LayoutB, ElementC, + layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, + arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, ThreadblockSwizzle, Stages, + TransformA, TransformB, Operator, SplitKSerial, BlasMode::kSymmetric> { + + static BlasMode const kBlasMode = BlasMode::kSymmetric; + + /// Define the threadblock-scoped matrix multiply-accumulate (A x B^T) + using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< + ElementA, LayoutA, + ElementB, typename layout::LayoutTranspose::type, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, Stages, + TransformA, TransformB, Operator>::ThreadblockMma; + + /// Define the threadblock-scoped matrix multiply-accumulate (B x A^T) + using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< + ElementB, LayoutB, + ElementA, typename layout::LayoutTranspose::type, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, Stages, + TransformA, TransformB, Operator>::ThreadblockMma; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< + ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, + EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; + + /// Define the kernel-level Rank2K operator. 
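As the two mainloops above indicate (A x B^T followed by B x A^T), the Rank2K kernels implement the BLAS SYR2K update, touching only the triangle selected by FillModeC. A scalar reference for the lower-fill, row-major case, given only as an illustration of the computed quantity:

// C := alpha * (A * B^T + B * A^T) + beta * C, lower triangle only.
// A and B are n x k, C is n x n, all row-major.
void syr2k_reference(int n, int k, float alpha, float const *A, float const *B,
                     float beta, float *C) {
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j <= i; ++j) {           // FillMode::kLower
      float acc = 0.f;
      for (int p = 0; p < k; ++p) {
        acc += A[i * k + p] * B[j * k + p] + B[i * k + p] * A[j * k + p];
      }
      C[i * n + j] = alpha * acc + beta * C[i * n + j];
    }
  }
}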
+ using Rank2Kkernel = kernel::Rank2KUniversal; + +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Hopper Architecture complex datatype (hermitian) +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Element type for C and D matrix operands + typename ElementC, + /// Fill Mode for C (kLower or kUpper) + FillMode FillModeC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Operation performed by GEMM + typename Operator, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial> +struct DefaultRank2KComplex< + ElementA, LayoutA, ElementB, LayoutB, ElementC, + layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, + arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, ThreadblockSwizzle, Stages, + TransformA, TransformB, Operator, SplitKSerial, BlasMode::kHermitian> { + + static BlasMode const kBlasMode = BlasMode::kHermitian; + + // Complex transform for input A and B matrices (function on input layout) + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using TransposedComplexTransform = detail::Rank2KTransposedComplexTransform< + LayoutA, LayoutB, + TransformA, TransformB, + kBlasMode>; + + // Complex transform on operandA and operandB (function of blas3 computation) + static ComplexTransform const kTransformOperandA = TransposedComplexTransform::kTransformA; + static ComplexTransform const kTransformOperandB = TransposedComplexTransform::kTransformB; + + /// Define the threadblock-scoped matrix multiply-accumulate (A x B^H) + using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< + ElementA, LayoutA, + ElementB, typename layout::LayoutTranspose::type, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, Stages, + kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma; + + /// Define the threadblock-scoped matrix multiply-accumulate (B x A^H) + using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< + ElementB, LayoutB, + ElementA, typename layout::LayoutTranspose::type, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, Stages, + kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< + ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, + EpilogueOutputOp::kCount, Operator, 
kBlasMode>::Epilogue; + + /// Define the kernel-level Rank2K operator. + using Rank2Kkernel = kernel::Rank2KUniversal; + +}; + //////////////////////////////////////////////////////////////////////////////// /// Partial specialization for Ampere Architecture complex datatype (symmetric) diff --git a/include/cutlass/gemm/kernel/default_rank_k.h b/include/cutlass/gemm/kernel/default_rank_k.h index 80960f2a..c42f122c 100644 --- a/include/cutlass/gemm/kernel/default_rank_k.h +++ b/include/cutlass/gemm/kernel/default_rank_k.h @@ -114,6 +114,68 @@ template < BlasMode BlasMode_ = BlasMode::kSymmetric> struct DefaultRankK; +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Hopper Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for C and D matrix operands + typename ElementC, + /// Fill Mode for C (kLower or kUpper) + FillMode FillModeC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator> +struct DefaultRankK< + ElementA, LayoutA, kAlignmentA, + ElementC,layout::RowMajor, FillModeC, + ElementAccumulator, arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, + Operator> { + /// Define the threadblock-scoped matrix multiply-accumulate (A x AT) + using Mma = typename cutlass::gemm::threadblock::DefaultMma< + ElementA, LayoutA, + kAlignmentA, + ElementA, typename layout::LayoutTranspose::type, + kAlignmentA, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator>::ThreadblockMma; + + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpBlas3< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount, BlasMode::kSymmetric>::Epilogue; + + /// Define the kernel-level Rank2 operator. 
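The RankK variant drops the second operand entirely: the single mainloop above computes A x A^T, i.e. the BLAS SYRK update. A scalar reference for the lower-fill, row-major case, again purely as an illustration:

// C := alpha * A * A^T + beta * C, lower triangle only.  A is n x k, row-major.
void syrk_reference(int n, int k, float alpha, float const *A, float beta, float *C) {
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j <= i; ++j) {           // FillMode::kLower
      float acc = 0.f;
      for (int p = 0; p < k; ++p) {
        acc += A[i * k + p] * A[j * k + p];
      }
      C[i * n + j] = alpha * acc + beta * C[i * n + j];
    }
  }
}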
+ using RankKkernel = kernel::RankKUniversal; +}; //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_rank_k_complex.h b/include/cutlass/gemm/kernel/default_rank_k_complex.h index 21d8607b..9d1a0345 100644 --- a/include/cutlass/gemm/kernel/default_rank_k_complex.h +++ b/include/cutlass/gemm/kernel/default_rank_k_complex.h @@ -155,6 +155,140 @@ template <> } //////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Hopper Architecture complex datatype (symmetric) +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Element type for C and D matrix operands + typename ElementC, + /// Fill Mode for C (kLower or kUpper) + FillMode FillModeC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Operation performed by GEMM + typename Operator, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial> +struct DefaultRankKComplex< + ElementA, LayoutA, ElementC, + layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, + arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, ThreadblockSwizzle, Stages, + TransformA, Operator, SplitKSerial, BlasMode::kSymmetric> { + + static BlasMode const kBlasMode = BlasMode::kSymmetric; + + /// Define the threadblock-scoped matrix multiply-accumulate (A x B^T) + using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< + ElementA, LayoutA, + ElementA, typename layout::LayoutTranspose::type, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, Stages, + TransformA, TransformA, Operator>::ThreadblockMma; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< + ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, + EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; + + /// Define the kernel-level RankK operator. 
+ using RankKkernel = kernel::RankKUniversal; + +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Hopper Architecture complex datatype (hermitian) +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Element type for C and D matrix operands + typename ElementC, + /// Fill Mode for C (kLower or kUpper) + FillMode FillModeC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Operation performed by GEMM + typename Operator, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial> +struct DefaultRankKComplex< + ElementA, LayoutA, ElementC, + layout::RowMajor, FillModeC, ElementAccumulator, arch::OpClassTensorOp, + arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, ThreadblockSwizzle, Stages, + TransformA, Operator, SplitKSerial, BlasMode::kHermitian> { + + static BlasMode const kBlasMode = BlasMode::kHermitian; + + // Complex transform for input A and B matrices (function on input layout) + static ComplexTransform const kTransformA = TransformA; + + using TransposedComplexTransform = detail::RankKTransposedComplexTransform< + LayoutA, + TransformA, + kBlasMode>; + + // Complex transform on operandA and operandB (function of blas3 computation) + static ComplexTransform const kTransformOperandA = TransposedComplexTransform::kTransformA; + static ComplexTransform const kTransformOperandB = TransposedComplexTransform::kTransformB; + + /// Define the threadblock-scoped matrix multiply-accumulate (A x A^H) + using Mma = typename cutlass::gemm::threadblock::DefaultMultistageMmaComplex< + ElementA, LayoutA, + ElementA, typename layout::LayoutTranspose::type, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, Stages, + kTransformOperandA, kTransformOperandB, Operator>::ThreadblockMma; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOpBlas3< + ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, + EpilogueOutputOp::kCount, Operator, kBlasMode>::Epilogue; + + /// Define the kernel-level RankK operator. 
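In the Hermitian specialization the transposed operand is also conjugated (the kTransformOperandA/B logic above), giving the HERK update C := alpha * A * A^H + beta * C, where alpha and beta are real-valued in the BLAS definition. A scalar reference sketch, with std::complex used for exposition:

#include <complex>

// C := alpha * A * conj(A)^T + beta * C, lower triangle only.  A is n x k, row-major.
void herk_reference(int n, int k, float alpha, std::complex<float> const *A,
                    float beta, std::complex<float> *C) {
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j <= i; ++j) {           // FillMode::kLower
      std::complex<float> acc(0.f, 0.f);
      for (int p = 0; p < k; ++p) {
        acc += A[i * k + p] * std::conj(A[j * k + p]);
      }
      C[i * n + j] = alpha * acc + beta * C[i * n + j];
    }
  }
}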
+ using RankKkernel = kernel::RankKUniversal; + +}; + +//////////////////////////////////////////////////////////////////////////////// + /// Partial specialization for Ampere Architecture complex datatype (symmetric) template < /// Element type for A matrix operand diff --git a/include/cutlass/gemm/kernel/default_symm.h b/include/cutlass/gemm/kernel/default_symm.h index 079bb09b..59555141 100755 --- a/include/cutlass/gemm/kernel/default_symm.h +++ b/include/cutlass/gemm/kernel/default_symm.h @@ -123,6 +123,101 @@ template < BlasMode BlasMode_ = BlasMode::kSymmetric> struct DefaultSymm; +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Hopper Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Side Mode for A (kLeft or kRight) + SideMode kSideModeA, + /// Fill Mode for A (kLower or kUpper) + FillMode kFillModeA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator> +struct DefaultSymm< + ElementA, LayoutA, kSideModeA, kFillModeA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementC,layout::RowMajor, + ElementAccumulator, arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, + Operator> { + + /// Define the threadblock-scoped triagular matrix multiply-accumulate + /// TRMM - with diagonal: alpha * A * B or alpha * B * A + static const DiagType kDiagTypeMma1 = DiagType::kNonUnit; + using Mma1 = typename cutlass::gemm::threadblock::DefaultTrmm< + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + kSideModeA, kFillModeA, kDiagTypeMma1, + ElementAccumulator, layout::RowMajor, + arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, + Stages, Operator>::ThreadblockMma; + + /// Define the threadblock-scoped triagular matrix multiply-accumulate + /// TRMM - withOUT diagonal: alpha * AT * B or alpha * B * AT + static const DiagType kDiagTypeMma2 = DiagType::kZero; + using LayoutAMma2 = typename platform::conditional< + (kSideModeA == SideMode::kLeft), + typename layout::LayoutTranspose::type, + LayoutA + >::type; + using LayoutBMma2 = typename platform::conditional< + (kSideModeA == SideMode::kLeft), + LayoutB, + typename layout::LayoutTranspose::type + >::type; + using Mma2 = typename cutlass::gemm::threadblock::DefaultTrmm< + ElementA, LayoutAMma2, kAlignmentA, + ElementB, LayoutBMma2, kAlignmentB, + kSideModeA, InvertFillMode::mode, kDiagTypeMma2, + 
ElementAccumulator, layout::RowMajor, + arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, + Stages, Operator>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename Mma1::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount>::Epilogue; + + /// Define the kernel-level SYMM/HEMM operator. + using SymmKernel = kernel::SymmUniversal; +}; //////////////////////////////////////////////////////////////////////////////// @@ -221,7 +316,6 @@ struct DefaultSymm< }; //////////////////////////////////////////////////////////////////////////////// - } // namespace kernel } // namespace gemm } // namespace cutlass diff --git a/include/cutlass/gemm/kernel/default_symm_complex.h b/include/cutlass/gemm/kernel/default_symm_complex.h index a8b5d218..3313be69 100755 --- a/include/cutlass/gemm/kernel/default_symm_complex.h +++ b/include/cutlass/gemm/kernel/default_symm_complex.h @@ -117,6 +117,199 @@ struct DefaultSymmComplex; //////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Hopper Architecture complex datatype (symmetric) +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Side Mode for A (kLeft or kRight) + SideMode kSideModeA, + /// Fill Mode for A (kLower or kUpper) + FillMode kFillModeA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial> +struct DefaultSymmComplex< + ElementA, LayoutA, kSideModeA, kFillModeA, ElementB, LayoutB, ElementC, + layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, + arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, ThreadblockSwizzle, Stages, + Operator, SplitKSerial, BlasMode::kSymmetric> { + + static BlasMode const kBlasMode = BlasMode::kSymmetric; + // Complex Transform don't appply to A or B for SYMM + static ComplexTransform const TransformA = ComplexTransform::kNone; + static ComplexTransform const TransformB = ComplexTransform::kNone; + + /// Define the threadblock-scoped triagular matrix multiply-accumulate + /// TRMM - with diagonal: alpha * A * B or alpha * B * A + static const DiagType kDiagTypeMma1 = DiagType::kNonUnit; + using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< + ElementA, LayoutA, + ElementB, LayoutB, + kSideModeA, kFillModeA, kDiagTypeMma1, + ElementAccumulator, layout::RowMajor, + arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, + Stages, TransformA, TransformB, Operator>::ThreadblockMma; 
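The SYMM/HEMM defaults split the symmetric operand across the two triangular mainloops defined around this point: Mma1 multiplies by the stored triangle including the diagonal, Mma2 by the transposed triangle with the diagonal zeroed (note InvertFillMode and DiagType::kZero). For SideMode::kLeft and FillMode::kLower that rests on the identity sketched below as a scalar reference, offered only as an illustration of the decomposition:

// For symmetric A stored in its lower triangle:  A = tril(A) + strict_tril(A)^T, so
//   D = alpha * tril(A) * B            (Mma1: diagonal included)
//     + alpha * strict_tril(A)^T * B   (Mma2: diagonal excluded, operand transposed)
//     + beta * C
// equals alpha * A * B + beta * C.
void symm_left_lower_reference(int n, int m, float alpha, float const *A /* n x n */,
                               float const *B /* n x m */, float beta, float *C /* n x m */) {
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < m; ++j) {
      float acc = 0.f;
      for (int p = 0; p <= i; ++p)    acc += A[i * n + p] * B[p * m + j];  // tril(A) * B
      for (int p = i + 1; p < n; ++p) acc += A[p * n + i] * B[p * m + j];  // strict_tril(A)^T * B
      C[i * m + j] = alpha * acc + beta * C[i * m + j];
    }
  }
}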
+ + /// Define the threadblock-scoped triagular matrix multiply-accumulate + /// TRMM - withOUT diagonal: alpha * AT * B or alpha * B * AT + static const DiagType kDiagTypeMma2 = DiagType::kZero; + using LayoutAMma2 = typename platform::conditional< + (kSideModeA == SideMode::kLeft), + typename layout::LayoutTranspose::type, + LayoutA + >::type; + using LayoutBMma2 = typename platform::conditional< + (kSideModeA == SideMode::kLeft), + LayoutB, + typename layout::LayoutTranspose::type + >::type; + using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< + ElementA, LayoutAMma2, + ElementB, LayoutBMma2, + kSideModeA, InvertFillMode::mode, kDiagTypeMma2, + ElementAccumulator, layout::RowMajor, + arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, + Stages, TransformA, TransformB, Operator>::ThreadblockMma; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< + ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, + EpilogueOutputOp::kCount, Operator>::Epilogue; + + /// Define the kernel-level Symm operator. + using SymmKernel = kernel::SymmUniversal; + +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Hopper Architecture complex datatype (hermitian) +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Side Mode for A (kLeft or kRight) + SideMode kSideModeA, + /// Fill Mode for A (kLower or kUpper) + FillMode kFillModeA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial> +struct DefaultSymmComplex< + ElementA, LayoutA, kSideModeA, kFillModeA, ElementB, LayoutB, ElementC, + layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, + arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, ThreadblockSwizzle, Stages, + Operator, SplitKSerial, BlasMode::kHermitian> { + + static BlasMode const kBlasMode = BlasMode::kHermitian; + + + /// Define the threadblock-scoped triagular matrix multiply-accumulate + /// TRMM - with diagonal: alpha * A * B or alpha * B * A + static const DiagType kDiagTypeMma1 = DiagType::kNonUnit; + static ComplexTransform const TransformAMma1 = ComplexTransform::kNone; + static ComplexTransform const TransformBMma1 = ComplexTransform::kNone; + using Mma1 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< + ElementA, LayoutA, + ElementB, LayoutB, + kSideModeA, kFillModeA, kDiagTypeMma1, + ElementAccumulator, layout::RowMajor, + arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, + Stages, TransformAMma1, TransformBMma1, 
Operator, BlasMode::kHermitian>::ThreadblockMma; + + /// Define the threadblock-scoped triagular matrix multiply-accumulate + /// TRMM - withOUT diagonal - with conjugate transpose: alpha * AT * B or alpha * B * AT + static const DiagType kDiagTypeMma2 = DiagType::kZero; + using LayoutAMma2 = typename platform::conditional< + (kSideModeA == SideMode::kLeft), + typename layout::LayoutTranspose::type, + LayoutA + >::type; + using LayoutBMma2 = typename platform::conditional< + (kSideModeA == SideMode::kLeft), + LayoutB, + typename layout::LayoutTranspose::type + >::type; + static ComplexTransform const TransformAMma2 = (kSideModeA == SideMode::kLeft) ? + ComplexTransform::kConjugate : ComplexTransform::kNone; + static ComplexTransform const TransformBMma2 = (kSideModeA == SideMode::kLeft) ? + ComplexTransform::kNone : ComplexTransform::kConjugate; + + using Mma2 = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< + ElementA, LayoutAMma2, + ElementB, LayoutBMma2, + kSideModeA, InvertFillMode::mode, kDiagTypeMma2, + ElementAccumulator, layout::RowMajor, + arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, + Stages, TransformAMma2, TransformBMma2, Operator>::ThreadblockMma; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< + ThreadblockShape, typename Mma1::Operator, 1, EpilogueOutputOp, + EpilogueOutputOp::kCount, Operator>::Epilogue; + + /// Define the kernel-level Symm operator. + using SymmKernel = kernel::SymmUniversal; + +}; + +//////////////////////////////////////////////////////////////////////////////// + /// Partial specialization for Ampere Architecture complex datatype (symmetric) template < /// Element type for A matrix operand @@ -310,7 +503,6 @@ struct DefaultSymmComplex< //////////////////////////////////////////////////////////////////////////////// - } // namespace kernel } // namespace gemm } // namespace cutlass diff --git a/include/cutlass/gemm/kernel/default_trmm.h b/include/cutlass/gemm/kernel/default_trmm.h index a3d6b9ef..3a8c08e1 100644 --- a/include/cutlass/gemm/kernel/default_trmm.h +++ b/include/cutlass/gemm/kernel/default_trmm.h @@ -124,6 +124,76 @@ struct DefaultTrmm; //////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Hopper Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Side Mode for the kernel + SideMode kSideMode, + /// Fill Mode for the triangular matrix + FillMode kFillMode, + /// Diag Type for the triangular matrix + DiagType kDiagType, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + 
/// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator> +struct DefaultTrmm { + + /// Define the threadblock-scoped triagular matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultTrmm< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + kSideMode, kFillMode, kDiagType, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount>::Epilogue; + + /// Define the kernel-level TRMM operator. + using TrmmKernel = kernel::TrmmUniversal; +}; + +//////////////////////////////////////////////////////////////////////////////// + /// Partial specialization for Ampere Architecture template < /// Element type for A matrix operand diff --git a/include/cutlass/gemm/kernel/default_trmm_complex.h b/include/cutlass/gemm/kernel/default_trmm_complex.h index 9dcb5a2f..8194e055 100644 --- a/include/cutlass/gemm/kernel/default_trmm_complex.h +++ b/include/cutlass/gemm/kernel/default_trmm_complex.h @@ -122,6 +122,74 @@ struct DefaultTrmmComplex; //////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Hopper Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Side Mode for the kernel + SideMode kSideMode, + /// Fill Mode for the triangular matrix + FillMode kFillMode, + /// Diag Type for the triangular matrix + DiagType kDiagType, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Multiply-add operator + // (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) + typename Operator, + /// If true, kernel is configured to support serial reduction in the epilogue + bool SplitKSerial + > +struct DefaultTrmmComplex< + ElementA, LayoutA, ElementB, LayoutB, + kSideMode, kFillMode, kDiagType, + ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, + arch::Sm90, ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, Operator, SplitKSerial> { + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultMultistageTrmmComplex< + ElementA, LayoutA, ElementB, LayoutB, + kSideMode, 
kFillMode, kDiagType, + ElementAccumulator,layout::RowMajor, arch::OpClassTensorOp, arch::Sm90, ThreadblockShape, + WarpShape, InstructionShape, Stages, TransformA, TransformB, Operator>::ThreadblockMma; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueComplexTensorOp< + ThreadblockShape, typename Mma::Operator, 1, EpilogueOutputOp, + EpilogueOutputOp::kCount, Operator>::Epilogue; + + /// Define the kernel-level TRMM operator. + using TrmmKernel = kernel::TrmmUniversal; +}; + +//////////////////////////////////////////////////////////////////////////////// + /// Partial specialization for Ampere Architecture template < /// Element type for A matrix operand diff --git a/include/cutlass/gemm/kernel/ell_gemm.h b/include/cutlass/gemm/kernel/ell_gemm.h new file mode 100644 index 00000000..87015e65 --- /dev/null +++ b/include/cutlass/gemm/kernel/ell_gemm.h @@ -0,0 +1,830 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Template for a Block-Ell sparse gemm kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/transform/threadblock/ell_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! 
Threadblock swizzling function + bool SplitKSerial, ///! If true, code supporting split-K via serial reduction is enabled. + bool IsASparse ///! If true, A is sparse matrix +> +struct EllGemm { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using OutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename OutputOp::Params output_op; + int *semaphore; + int gemm_k_iterations; + int gemm_k_size; + const int* ell_idx; + int ell_ncol; + int ell_blocksize; + int ell_base_idx; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { } + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, + const int* ell_idx, + int ell_ncol, + int ell_blocksize, + int ell_base_idx, + typename OutputOp::Params output_op = typename OutputOp::Params(), + int *workspace = nullptr + ): + problem_size(problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(ref_A.layout()), + ref_A(ref_A), + params_B(ref_B.layout()), + ref_B(ref_B), + params_C(ref_C.layout()), + ref_C(ref_C), + params_D(ref_D.layout()), + ref_D(ref_D), + output_op(output_op), + ell_idx(ell_idx), + ell_ncol(ell_ncol), + ell_blocksize(ell_blocksize), + ell_base_idx(ell_base_idx) + { + + int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + + semaphore = workspace; + } + }; + + /// Shared memory storage structure + struct SharedStorage { + union{ + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + typename cutlass::transform::threadblock::ell::SharedStorage ell; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + EllGemm() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D) { + + static int const kAlignmentA = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = (platform::is_same>::value) + ? 
32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(ref_A, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || + (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || + (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { + + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int tile_in_ell_block = (params.ell_blocksize + Mma::Shape::kM - 1 ) / Mma::Shape::kM; + int ell_block_offset_m = threadblock_tile_offset.m() / tile_in_ell_block; + int tile_offset_m = threadblock_tile_offset.m() % tile_in_ell_block; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
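+    // (threadIdx.x / 32 read directly cannot always be proven warp-uniform by the
+    // compiler; the shuffle below broadcasts lane 0's value so dependent code is
+    // treated as uniform.)
+    //
+    // For reference, the ELL tile mapping computed above groups threadblock row
+    // indices into blocks of ell_blocksize rows. For example, with ell_blocksize = 64
+    // and Mma::Shape::kM = 32, tile_in_ell_block = 2, so threadblock row 5 maps to
+    // ell_block_offset_m = 2 and tile_offset_m = 1.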
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // skip computation if matrix is 0 + if (params.ell_ncol > 0) { + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + ell_block_offset_m * params.ell_blocksize + + tile_offset_m * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size + }; + + cutlass::MatrixCoord tb_offset_B{ + threadblock_tile_offset.k() * params.gemm_k_size, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + int ell_idx_start = + (threadblock_tile_offset.m() / tile_in_ell_block) * + (params.ell_ncol / params.ell_blocksize); + const int* ell_idx_ptr = &(params.ell_idx[ell_idx_start]); + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min( + params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + problem_size_k = min(problem_size_k, params.ell_ncol); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = + (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + params.ref_B.data(), + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + // Define coef for ELL index depending on LayoutB + int ell_stride = iterator_B.get_stride(); + + typename cutlass::transform::threadblock::ell::Iterator ell_iterator( + shared_storage.ell, + ell_idx_ptr, + params.ell_blocksize, + params.ell_base_idx, + Mma::Shape::kK, + problem_size_k, + ell_stride, + thread_idx + ); + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // check if index computations can be skipped + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + constexpr bool is_double = (sizeof(Mma::IteratorA::Element) == 8); + constexpr bool is_multiple_alignment = + (kAlignmentA > 1) && (kAlignmentB > 1) && (kAlignmentC > 1); + const bool is_specialized_blocksize = + ((params.ell_blocksize) & (params.ell_blocksize-1)) == 0 + && params.ell_blocksize >= Mma::Shape::kK; + // Compute threadblock-scoped matrix multiply-add + if ((is_double || is_multiple_alignment) && is_specialized_blocksize) { + mma.operator()( + gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator); + } + else { + mma.operator()( + gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator); + } + } + } // if (params.ell_ncols > 0) + + // + // Epilogue + // + + OutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + ell_block_offset_m = threadblock_tile_offset.m() / tile_in_ell_block; + tile_offset_m = threadblock_tile_offset.m() % tile_in_ell_block; + + //assume identity swizzle + MatrixCoord threadblock_offset( + ell_block_offset_m * params.ell_blocksize + + tile_offset_m * 
Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + //avoid out of bounds + MatrixCoord threadblock_extent( + min(params.problem_size.m(), + ell_block_offset_m * params.ell_blocksize + + min((tile_offset_m + 1) * Mma::Shape::kM, params.ell_blocksize)), + min(params.problem_size.n(), + (threadblock_tile_offset.n()+1) * Mma::Shape::kN) + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + params.ref_C.data(), + threadblock_extent, + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + params.ref_D.data(), + threadblock_extent, + thread_idx, + threadblock_offset + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +// B is Sparse +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. 
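+  // (This specialization handles the case where B, rather than A, is the sparse
+  // operand: the ELL block mapping is applied along the N dimension and the ELL
+  // stride is taken from the A-operand iterator.)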
+> +struct EllGemm { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using OutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename OutputOp::Params output_op; + int *semaphore; + int gemm_k_iterations; + int gemm_k_size; + const int* ell_idx; + int ell_ncol; + int ell_blocksize; + int ell_base_idx; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { } + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, + const int* ell_idx, + int ell_ncol, + int ell_blocksize, + int ell_base_idx, + typename OutputOp::Params output_op = typename OutputOp::Params(), + int *workspace = nullptr + ): + problem_size(problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(ref_A.layout()), + ref_A(ref_A), + params_B(ref_B.layout()), + ref_B(ref_B), + params_C(ref_C.layout()), + ref_C(ref_C), + params_D(ref_D.layout()), + ref_D(ref_D), + output_op(output_op), + ell_idx(ell_idx), + ell_ncol(ell_ncol), + ell_blocksize(ell_blocksize), + ell_base_idx(ell_base_idx) + { + + int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + + semaphore = workspace; + } + }; + + /// Shared memory storage structure + struct SharedStorage { + union{ + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + typename cutlass::transform::threadblock::ell::SharedStorage ell; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + EllGemm() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D) { + + static int const kAlignmentA = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 
64 + : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(ref_A, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || + (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || + (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { + + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int tile_in_ell_block = (params.ell_blocksize + Mma::Shape::kN - 1 ) / Mma::Shape::kN; + int ell_block_offset_n = threadblock_tile_offset.n() / tile_in_ell_block; + int tile_offset_n = threadblock_tile_offset.n() % tile_in_ell_block; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // skip computation if matrix is 0 + if (params.ell_ncol > 0) { + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{ + threadblock_tile_offset.k() * params.gemm_k_size, + ell_block_offset_n * params.ell_blocksize + + tile_offset_n * Mma::Shape::kN, + }; + + int ell_idx_start = + (threadblock_tile_offset.n() / tile_in_ell_block) * + (params.ell_ncol / params.ell_blocksize); + const int* ell_idx_ptr = &(params.ell_idx[ell_idx_start]); + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min( + params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + problem_size_k = min(problem_size_k, params.ell_ncol); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = + (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + params.ref_B.data(), + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + // Define coef for ELL index depending on LayoutA + int ell_stride = iterator_A.get_stride(); + + typename cutlass::transform::threadblock::ell::Iterator ell_iterator( + shared_storage.ell, + ell_idx_ptr, + params.ell_blocksize, + params.ell_base_idx, + 
Mma::Shape::kK, + problem_size_k, + ell_stride, + thread_idx + ); + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // check if index computations can be skipped + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + constexpr bool is_double = (sizeof(Mma::IteratorA::Element) == 8); + constexpr bool is_multiple_alignment = + (kAlignmentA > 1) && (kAlignmentB > 1) && (kAlignmentC > 1); + const bool is_specialized_blocksize = + ((params.ell_blocksize) & (params.ell_blocksize-1)) == 0 + && params.ell_blocksize >= Mma::Shape::kK; + // Compute threadblock-scoped matrix multiply-add + if ((is_double || is_multiple_alignment) && is_specialized_blocksize) { + mma.operator()( + gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator); + } + else { + mma.operator()( + gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator); + } + } + } // if (params.ell_ncols > 0) + + // + // Epilogue + // + + OutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + ell_block_offset_n = threadblock_tile_offset.n() / tile_in_ell_block; + tile_offset_n = threadblock_tile_offset.n() % tile_in_ell_block; + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + ell_block_offset_n * params.ell_blocksize + + tile_offset_n * Mma::Shape::kN + ); + + //avoid out of bounds + MatrixCoord threadblock_extent( + min(params.problem_size.m(), + (threadblock_tile_offset.m()+1) * Mma::Shape::kM), + min(params.problem_size.n(), + ell_block_offset_n * params.ell_blocksize + + min((tile_offset_n + 1) * Mma::Shape::kN, params.ell_blocksize)) + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + params.ref_C.data(), + threadblock_extent, + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + params.ref_D.data(), + threadblock_extent, + thread_idx, + threadblock_offset + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. 
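+      // Partition k > 0 therefore accumulates on top of the partial result written by
+      // partition k-1, then passes the lock to partition k+1; the final partition
+      // resets the semaphore to zero so the workspace can be reused by later launches.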
+ if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + diff --git a/include/cutlass/gemm/kernel/gemm_grouped.h b/include/cutlass/gemm/kernel/gemm_grouped.h index c02d3ff9..c5f24917 100644 --- a/include/cutlass/gemm/kernel/gemm_grouped.h +++ b/include/cutlass/gemm/kernel/gemm_grouped.h @@ -315,13 +315,6 @@ public: static Status can_implement(Arguments const &args) { return Status::kSuccess; } - - static size_t get_extra_workspace_size( - Arguments const &args, - cutlass::gemm::GemmCoord const &grid_tiled_shape) { - - return 0; - } /// Executes one GEMM CUTLASS_DEVICE diff --git a/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h b/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h index 51fec120..b32d716a 100644 --- a/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h +++ b/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h @@ -50,11 +50,22 @@ namespace kernel { namespace detail { // Helper for correctly representing problem sizes in grouped kernels -template +template < + typename ThreadblockShape, + bool Transposed +> struct GemmGroupedProblemSizeHelper { static bool const kTransposed = Transposed; + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { + return cutlass::gemm::GemmCoord( + ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), + 1); + } + CUTLASS_HOST_DEVICE static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) { if (kTransposed) { @@ -77,7 +88,7 @@ template struct GemmGroupedProblemVisitor : public GroupedProblemVisitor< - detail::GemmGroupedProblemSizeHelper, + detail::GemmGroupedProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, @@ -85,7 +96,7 @@ struct GemmGroupedProblemVisitor : public GroupedProblemVisitor< static bool const kTransposed = Transposed; - using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; + using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; using Base = GroupedProblemVisitor; using Params = typename Base::Params; using SharedStorage = typename Base::SharedStorage; diff --git a/include/cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h b/include/cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h index 4b2d90bb..916adcd7 100644 --- a/include/cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h +++ b/include/cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h @@ -335,13 +335,6 @@ public: return Status::kSuccess; } - static size_t get_extra_workspace_size( - Arguments const &args, - cutlass::gemm::GemmCoord const &grid_tiled_shape) { - - return 0; - } - /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const 
¶ms, SharedStorage &shared_storage) { diff --git a/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h b/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h index f83f5f6b..96a79fa7 100644 --- a/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h +++ b/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h @@ -41,6 +41,7 @@ #include "cutlass/matrix_coord.h" #include "cutlass/complex.h" #include "cutlass/semaphore.h" +#include "cutlass/gemm/kernel/params_universal_base.h" #include "cutlass/layout/matrix.h" @@ -104,16 +105,12 @@ public: // /// Argument structure - struct Arguments { - + struct Arguments : UniversalArgumentsBase + { // // Data members // - GemmUniversalMode mode; - GemmCoord problem_size; - int batch_count; - typename EpilogueOutputOp::Params epilogue; void const * ptr_A; @@ -132,7 +129,6 @@ public: int64_t batch_stride_gamma; int64_t batch_stride_beta; int64_t batch_stride_C; - int64_t batch_stride_D; typename LayoutA::Stride stride_a; typename LayoutB::Stride stride_b; @@ -161,14 +157,13 @@ public: // Arguments(): - mode(GemmUniversalMode::kGemm), - batch_count(1), ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), ptr_var(nullptr), ptr_mean(nullptr), ptr_gamma(nullptr), ptr_beta(nullptr), ptr_gather_A_indices(nullptr), ptr_gather_B_indices(nullptr), - ptr_scatter_D_indices(nullptr) {} + ptr_scatter_D_indices(nullptr) + {} /// constructs an arguments structure Arguments( @@ -202,31 +197,27 @@ public: typename LayoutC::Stride stride_d, int const *ptr_gather_A_indices = nullptr, int const *ptr_gather_B_indices = nullptr, - int const *ptr_scatter_D_indices = nullptr - ): - mode(mode), - problem_size(problem_size), - batch_count(batch_count), + int const *ptr_scatter_D_indices = nullptr) + : + UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), epilogue(epilogue), ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), ptr_var(ptr_var), ptr_mean(ptr_mean), ptr_gamma(ptr_gamma), ptr_beta(ptr_beta), - batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), + batch_stride_var(batch_stride_var), batch_stride_mean(batch_stride_mean), + batch_stride_gamma(batch_stride_gamma), batch_stride_beta(batch_stride_beta), + lda(0), ldb(0), ldc(0), ldd(0), + ld_var(0), ld_mean(0), + ld_gamma(0), ld_beta(0), stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), stride_var(stride_var), stride_mean(stride_mean), stride_gamma(stride_gamma), stride_beta(stride_beta), ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), - ptr_scatter_D_indices(ptr_scatter_D_indices) { - lda = 0; - ldb = 0; - ldc = 0; - ldd = 0; - ld_var = 0; - ld_mean = 0; - ld_gamma = 0; - ld_beta = 0; + ptr_scatter_D_indices(ptr_scatter_D_indices) + { CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); - } + } /// constructs an arguments structure Arguments( @@ -260,23 +251,22 @@ public: typename LayoutC::Stride::LongIndex ldd, int const *ptr_gather_A_indices = nullptr, int const *ptr_gather_B_indices = nullptr, - int const *ptr_scatter_D_indices = nullptr - ): - mode(mode), - problem_size(problem_size), - batch_count(batch_count), + int const *ptr_scatter_D_indices = nullptr) + : + UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), epilogue(epilogue), ptr_A(ptr_A), 
ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), ptr_var(ptr_var), ptr_mean(ptr_mean), ptr_gamma(ptr_gamma), ptr_beta(ptr_beta), - batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_var(batch_stride_var), batch_stride_mean(batch_stride_mean), batch_stride_gamma(batch_stride_gamma), batch_stride_beta(batch_stride_beta), lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ld_var(ld_var), ld_mean(ld_mean), ld_gamma(ld_gamma), ld_beta(ld_beta), ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), - ptr_scatter_D_indices(ptr_scatter_D_indices) { + ptr_scatter_D_indices(ptr_scatter_D_indices) + { stride_a = make_Coord(lda); stride_b = make_Coord(ldb); stride_c = make_Coord(ldc); @@ -286,7 +276,7 @@ public: stride_gamma = make_Coord(ld_gamma); stride_beta = make_Coord(ld_beta); CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); - } + } /// Returns arguments for the transposed problem Arguments transposed_problem() const { @@ -303,17 +293,30 @@ public: } }; + // // Structure for precomputing values in host memory and passing to kernels // /// Parameters structure - struct Params { + struct Params : UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC> + { + using ParamsBase = UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC>; + + // + // Data members + // - cutlass::gemm::GemmCoord problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - typename Mma::IteratorA::Params params_A; typename Mma::IteratorB::Params params_B; typename Epilogue::OutputTileIterator::Params params_C; @@ -321,10 +324,6 @@ public: typename EpilogueOutputOp::Params output_op; - GemmUniversalMode mode; - int batch_count; - int gemm_k_size; - void * ptr_A; void * ptr_B; void * ptr_var; @@ -341,65 +340,30 @@ public: int64_t batch_stride_gamma; int64_t batch_stride_beta; int64_t batch_stride_C; - int64_t batch_stride_D; int * ptr_gather_A_indices; int * ptr_gather_B_indices; int * ptr_scatter_D_indices; - int *semaphore; - // - // Methods + // Host dispatch API // - CUTLASS_HOST_DEVICE - Params(): - swizzle_log_tile(0), - params_A(0), - params_B(0), - params_C(0), - params_D(0), - batch_count(0), - gemm_k_size(0), - mode(cutlass::gemm::GemmUniversalMode::kGemm), - ptr_A(nullptr), - ptr_B(nullptr), - ptr_var(nullptr), - ptr_mean(nullptr), - ptr_gamma(nullptr), - ptr_beta(nullptr), - ptr_C(nullptr), - ptr_D(nullptr), - batch_stride_A(0), - batch_stride_B(0), - batch_stride_var(0), - batch_stride_mean(0), - batch_stride_C(0), - batch_stride_D(0), - ptr_gather_A_indices(nullptr), - ptr_gather_B_indices(nullptr), - ptr_scatter_D_indices(nullptr), - semaphore(nullptr) { } + /// Default constructor + Params() = default; - CUTLASS_HOST_DEVICE + /// Constructor Params( - Arguments const &args, - cutlass::gemm::GemmCoord const & grid_tiled_shape, - int gemm_k_size, - void *workspace = nullptr - ): - problem_size(args.problem_size), - grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + Arguments const &args, /// GEMM application arguments + int device_sms, /// Number of SMs on the device + int sm_occupancy) /// Kernel SM occupancy (in thread blocks) + : + ParamsBase(args, device_sms, sm_occupancy), params_A(args.lda ? 
make_Coord_with_padding(args.lda) : args.stride_a), params_B(args.ldb ? make_Coord_with_padding(args.ldb) : args.stride_b), params_C(args.ldc ? make_Coord_with_padding(args.ldc) : args.stride_c), params_D(args.ldd ? make_Coord_with_padding(args.ldd) : args.stride_d), output_op(args.epilogue), - mode(args.mode), - batch_count(args.batch_count), - gemm_k_size(gemm_k_size), ptr_A(const_cast(args.ptr_A)), ptr_B(const_cast(args.ptr_B)), ptr_var(const_cast(args.ptr_var)), @@ -415,19 +379,15 @@ public: batch_stride_gamma(args.batch_stride_gamma), batch_stride_beta(args.batch_stride_beta), batch_stride_C(args.batch_stride_C), - batch_stride_D(args.batch_stride_D), ptr_gather_A_indices(const_cast(args.ptr_gather_A_indices)), ptr_gather_B_indices(const_cast(args.ptr_gather_B_indices)), - ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)), - semaphore(static_cast(workspace)) { - - } - - CUTLASS_HOST_DEVICE - void update( - Arguments const &args, - void *workspace = nullptr) { + ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)) + {} + /// Lightweight update given a subset of arguments. Problem geometry is assumed + /// to remain the same. + void update(Arguments const &args) + { ptr_A = const_cast(args.ptr_A); ptr_B = const_cast(args.ptr_B); ptr_var = const_cast(args.ptr_var); @@ -441,22 +401,13 @@ public: ptr_gather_B_indices = const_cast(args.ptr_gather_B_indices); ptr_scatter_D_indices = const_cast(args.ptr_scatter_D_indices); - batch_stride_A = args.batch_stride_A; - batch_stride_B = args.batch_stride_B; - batch_stride_var = args.batch_stride_var; - batch_stride_mean = args.batch_stride_mean; - batch_stride_gamma = args.batch_stride_gamma; - batch_stride_beta = args.batch_stride_beta; - batch_stride_C = args.batch_stride_C; - batch_stride_D = args.batch_stride_D; - output_op = args.epilogue; - semaphore = static_cast(workspace); CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); } }; + /// Shared memory storage structure union SharedStorage { typename Mma::SharedStorage main_loop; @@ -466,12 +417,9 @@ public: public: // - // Methods + // Host dispatch API // - CUTLASS_DEVICE - GemmLayernormMainloopFusion() { } - /// Determines whether kernel satisfies alignment static Status can_implement( cutlass::gemm::GemmCoord const & problem_size) { @@ -555,12 +503,23 @@ public: return can_implement(args.problem_size); } - static size_t get_extra_workspace_size(Arguments const &args, - cutlass::gemm::GemmCoord const &grid_tiled_shape) { +public: - return 0; + // + // Device-only API + // + + // Factory invocation + CUTLASS_DEVICE + static void invoke( + Params const ¶ms, + SharedStorage &shared_storage) + { + GemmLayernormMainloopFusion op; + op(params, shared_storage); } + /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage) { diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex.h b/include/cutlass/gemm/kernel/gemm_planar_complex.h index 7b85fdbd..adedd3e9 100644 --- a/include/cutlass/gemm/kernel/gemm_planar_complex.h +++ b/include/cutlass/gemm/kernel/gemm_planar_complex.h @@ -41,6 +41,7 @@ #include "cutlass/matrix_coord.h" #include "cutlass/complex.h" #include "cutlass/semaphore.h" +#include "cutlass/gemm/kernel/params_universal_base.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -105,16 +106,12 @@ public: // /// Argument structure - struct Arguments { - + struct Arguments : UniversalArgumentsBase + { // // Data members // - GemmUniversalMode mode; - GemmCoord 
problem_size; - int batch_count; - typename EpilogueOutputOp::Params epilogue; void const * ptr_A_real; @@ -144,17 +141,13 @@ public: int64_t batch_stride_B_imag; int64_t batch_stride_C; int64_t batch_stride_C_imag; - int64_t batch_stride_D; int64_t batch_stride_D_imag; - // // Methods // - Arguments(): - mode(GemmUniversalMode::kGemm), - batch_count(1), + Arguments() : ptr_A_real(nullptr), ptr_A_imag(nullptr), ptr_B_real(nullptr), @@ -163,7 +156,7 @@ public: ptr_C_imag(nullptr), ptr_D_real(nullptr), ptr_D_imag(nullptr) - { } + {} /// constructs an arguments structure Arguments( @@ -194,11 +187,9 @@ public: int64_t batch_stride_C = 0, int64_t batch_stride_C_imag = 0, int64_t batch_stride_D = 0, - int64_t batch_stride_D_imag = 0 - ): - mode(mode), - problem_size(problem_size), - batch_count(batch_count), + int64_t batch_stride_D_imag = 0) + : + UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), epilogue(epilogue), ptr_A_real(ptr_A_real), ptr_A_imag(ptr_A_imag), @@ -222,10 +213,8 @@ public: batch_stride_B_imag(batch_stride_B_imag), batch_stride_C(batch_stride_C), batch_stride_C_imag(batch_stride_C_imag), - batch_stride_D(batch_stride_D), - batch_stride_D_imag(batch_stride_D_imag) { - - } + batch_stride_D_imag(batch_stride_D_imag) + {} /// Returns arguments for the transposed problem Arguments transposed_problem() const { @@ -243,16 +232,30 @@ public: } }; + // // Structure for precomputing values in host memory and passing to kernels // /// Parameters structure - struct Params { - cutlass::gemm::GemmCoord problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - + struct Params : UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC> + { + using ParamsBase = UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC>; + + // + // Data members + // + typename Mma::IteratorA::Params params_A_real; typename Mma::IteratorA::Params params_A_imag; typename Mma::IteratorB::Params params_B_real; @@ -264,10 +267,6 @@ public: typename EpilogueOutputOp::Params output_op; - GemmUniversalMode mode; - int batch_count; - int gemm_k_size; - void * ptr_A_real; void * ptr_A_imag; void * ptr_B_real; @@ -278,54 +277,28 @@ public: void * ptr_D_imag; int64_t batch_stride_A; - int64_t batch_stride_A_imag; int64_t batch_stride_B; - int64_t batch_stride_B_imag; int64_t batch_stride_C; + + int64_t batch_stride_A_imag; + int64_t batch_stride_B_imag; int64_t batch_stride_C_imag; - int64_t batch_stride_D; int64_t batch_stride_D_imag; - int *semaphore; - // - // Methods + // Host dispatch API // - CUTLASS_HOST_DEVICE - Params(): - batch_count(0), - gemm_k_size(0), - swizzle_log_tile(0), - mode(cutlass::gemm::GemmUniversalMode::kGemm), - ptr_A_real(nullptr), - ptr_A_imag(nullptr), - ptr_B_real(nullptr), - ptr_B_imag(nullptr), - ptr_C_real(nullptr), - ptr_C_imag(nullptr), - ptr_D_real(nullptr), - ptr_D_imag(nullptr), - batch_stride_A(0), - batch_stride_A_imag(0), - batch_stride_B(0), - batch_stride_B_imag(0), - batch_stride_C(0), - batch_stride_C_imag(0), - batch_stride_D(0), - batch_stride_D_imag(0), - semaphore(nullptr) { } + /// Default constructor + Params() = default; - CUTLASS_HOST_DEVICE + /// Constructor Params( - Arguments const &args, - cutlass::gemm::GemmCoord const & grid_tiled_shape, - int gemm_k_size, - void *workspace = nullptr - ): - problem_size(args.problem_size), - grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + 
Arguments const &args, /// GEMM application arguments + int device_sms, /// Number of SMs on the device + int sm_occupancy) /// Kernel SM occupancy (in thread blocks) + : + ParamsBase(args, device_sms, sm_occupancy), params_A_real(args.lda_real), params_A_imag(args.lda_imag), params_B_real(args.ldb_real), @@ -335,9 +308,6 @@ public: params_D_real(args.ldd_real), params_D_imag(args.ldd_imag), output_op(args.epilogue), - mode(args.mode), - batch_count(args.batch_count), - gemm_k_size(gemm_k_size), ptr_A_real(const_cast(args.ptr_A_real)), ptr_A_imag(const_cast(args.ptr_A_imag)), ptr_B_real(const_cast(args.ptr_B_real)), @@ -347,21 +317,32 @@ public: ptr_D_real(args.ptr_D_real), ptr_D_imag(args.ptr_D_imag), batch_stride_A(args.batch_stride_A), - batch_stride_A_imag(args.batch_stride_A_imag), batch_stride_B(args.batch_stride_B), - batch_stride_B_imag(args.batch_stride_B_imag), batch_stride_C(args.batch_stride_C), + batch_stride_A_imag(args.batch_stride_A_imag), + batch_stride_B_imag(args.batch_stride_B_imag), batch_stride_C_imag(args.batch_stride_C_imag), - batch_stride_D(args.batch_stride_D), - batch_stride_D_imag(args.batch_stride_D_imag), - semaphore(static_cast(workspace)) { + batch_stride_D_imag(args.batch_stride_D_imag) + {} + /// Returns the workspace size (in bytes) needed for this problem geometry + size_t get_workspace_size() const + { + size_t workspace_bytes = ParamsBase::get_workspace_size(); + if (this->mode == GemmUniversalMode::kGemmSplitKParallel) + { + // Double the size returned by the base class because we need to + // accumulate two ElementC components + workspace_bytes *= 2; + } + + return workspace_bytes; } - void update( - Arguments const &args, - void *workspace = nullptr) { - + /// Lightweight update given a subset of arguments. Problem geometry is assumed + /// to remain the same. 
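+    /// (Members derived from the problem shape, such as the iterator params above,
+    /// are reused as-is rather than recomputed.)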
+ void update(Arguments const &args) + { ptr_A_real = const_cast(args.ptr_A_real); ptr_A_imag = const_cast(args.ptr_A_imag); @@ -374,21 +355,11 @@ public: ptr_D_real = const_cast(args.ptr_D_real); ptr_D_imag = const_cast(args.ptr_D_imag); - batch_stride_A = args.batch_stride_A; - batch_stride_A_imag = args.batch_stride_A_imag; - batch_stride_B = args.batch_stride_B; - batch_stride_B_imag = args.batch_stride_B_imag; - batch_stride_C = args.batch_stride_C; - batch_stride_C_imag = args.batch_stride_C_imag; - batch_stride_D = args.batch_stride_D; - batch_stride_D_imag = args.batch_stride_D_imag; - output_op = args.epilogue; - - semaphore = static_cast(workspace); } }; + /// Shared memory storage structure union SharedStorage { typename Mma::SharedStorage main_loop; @@ -398,15 +369,12 @@ public: public: // - // Methods + // Host dispatch API // - CUTLASS_DEVICE - GemmPlanarComplex() { } - /// Determines whether kernel satisfies alignment - static Status can_implement(Arguments const &args) { - + static Status can_implement(Arguments const &args) + { static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; @@ -440,12 +408,23 @@ public: return Status::kSuccess; } - static size_t get_extra_workspace_size(Arguments const &args, - cutlass::gemm::GemmCoord const &grid_tiled_shape) { +public: - return 0; + // + // Device-only API + // + + // Factory invocation + CUTLASS_DEVICE + static void invoke( + Params const ¶ms, + SharedStorage &shared_storage) + { + GemmPlanarComplex op; + op(params, shared_storage); } + /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage) { diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex_array.h b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h index 1360a2c0..79486c8d 100644 --- a/include/cutlass/gemm/kernel/gemm_planar_complex_array.h +++ b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h @@ -41,6 +41,7 @@ #include "cutlass/matrix_coord.h" #include "cutlass/complex.h" #include "cutlass/semaphore.h" +#include "cutlass/gemm/kernel/params_universal_base.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -105,16 +106,12 @@ public: // /// Argument structure - struct Arguments { - + struct Arguments : UniversalArgumentsBase + { // // Data members // - GemmUniversalMode mode; - GemmCoord problem_size; - int batch_count; - typename EpilogueOutputOp::Params epilogue; int const *ptr_M; @@ -142,15 +139,11 @@ public: typename LayoutC::Stride::Index ldd_real; typename LayoutC::Stride::Index ldd_imag; - int64_t batch_stride_D; // unused - // // Methods // Arguments(): - mode(GemmUniversalMode::kArray), - batch_count(1), ptr_M(nullptr), ptr_N(nullptr), ptr_K(nullptr), @@ -161,9 +154,8 @@ public: ptr_C_real(nullptr), ptr_C_imag(nullptr), ptr_D_real(nullptr), - ptr_D_imag(nullptr), - batch_stride_D(0) - { } + ptr_D_imag(nullptr) + {} /// constructs an arguments structure Arguments( @@ -188,11 +180,9 @@ public: typename LayoutC::Stride::Index ldc_real, typename LayoutC::Stride::Index ldc_imag, typename LayoutC::Stride::Index ldd_real, - typename LayoutC::Stride::Index ldd_imag - ): - mode(GemmUniversalMode::kArray), - problem_size(problem_size), - batch_count(batch_count), + typename LayoutC::Stride::Index ldd_imag) + : + UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), 
epilogue(epilogue), ptr_M(ptr_M), ptr_N(ptr_N), @@ -212,10 +202,8 @@ public: ldc_real(ldc_real), ldc_imag(ldc_imag), ldd_real(ldd_real), - ldd_imag(ldd_imag), - batch_stride_D(0) { - - } + ldd_imag(ldd_imag) + {} /// Returns arguments for the transposed problem Arguments transposed_problem() const { @@ -232,15 +220,30 @@ public: } }; + // // Structure for precomputing values in host memory and passing to kernels // /// Parameters structure - struct Params { - cutlass::gemm::GemmCoord problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; + struct Params : UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC> + { + using ParamsBase = UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC>; + + // + // Data members + // + typename Mma::IteratorA::Params params_A_real; typename Mma::IteratorA::Params params_A_imag; typename Mma::IteratorB::Params params_B_real; @@ -249,11 +252,9 @@ public: typename Epilogue::OutputTileIterator::Params params_C_imag; typename Epilogue::OutputTileIterator::Params params_D_real; typename Epilogue::OutputTileIterator::Params params_D_imag; - + typename EpilogueOutputOp::Params output_op; - int batch_count; - int const *ptr_M; int const *ptr_N; int const *ptr_K; @@ -268,35 +269,19 @@ public: void * const * ptr_D_imag; // - // Methods + // Host dispatch API // - CUTLASS_HOST_DEVICE - Params(): - batch_count(0), - swizzle_log_tile(0), - ptr_M(nullptr), - ptr_N(nullptr), - ptr_K(nullptr), - ptr_A_real(nullptr), - ptr_A_imag(nullptr), - ptr_B_real(nullptr), - ptr_B_imag(nullptr), - ptr_C_real(nullptr), - ptr_C_imag(nullptr), - ptr_D_real(nullptr), - ptr_D_imag(nullptr) { } + /// Default constructor + Params() = default; - CUTLASS_HOST_DEVICE + /// Constructor Params( - Arguments const &args, - cutlass::gemm::GemmCoord const & grid_tiled_shape, - int gemm_k_size = 0, // ignored - void *workspace = nullptr // ignored - ): - problem_size(args.problem_size), - grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + Arguments const &args, /// GEMM application arguments + int device_sms, /// Number of SMs on the device + int sm_occupancy) /// Kernel SM occupancy (in thread blocks) + : + ParamsBase(args, device_sms, sm_occupancy), ptr_M(args.ptr_M), ptr_N(args.ptr_N), ptr_K(args.ptr_K), @@ -309,7 +294,6 @@ public: params_D_real(args.ldd_real), params_D_imag(args.ldd_imag), output_op(args.epilogue), - batch_count(args.batch_count), ptr_A_real(args.ptr_A_real), ptr_A_imag(args.ptr_A_imag), ptr_B_real(args.ptr_B_real), @@ -317,14 +301,13 @@ public: ptr_C_real(args.ptr_C_real), ptr_C_imag(args.ptr_C_imag), ptr_D_real(args.ptr_D_real), - ptr_D_imag(args.ptr_D_imag) { - - } - - void update( - Arguments const &args, - void *workspace = nullptr) { + ptr_D_imag(args.ptr_D_imag) + {} + /// Lightweight update given a subset of arguments. Problem geometry is assumed + /// to remain the same. 
+ void update(Arguments const &args) + { ptr_M = args.ptr_M; ptr_N = args.ptr_N; ptr_K = args.ptr_K; @@ -345,6 +328,7 @@ public: } }; + /// Shared memory storage structure union SharedStorage { typename Mma::SharedStorage main_loop; @@ -354,12 +338,9 @@ public: public: // - // Methods + // Host dispatch API // - CUTLASS_DEVICE - GemmPlanarComplexArray() { } - /// Determines whether kernel satisfies alignment static Status can_implement(Arguments const &args) { @@ -396,12 +377,24 @@ public: return Status::kSuccess; } - static size_t get_extra_workspace_size(Arguments const &args, - cutlass::gemm::GemmCoord const &grid_tiled_shape) { - return 0; +public: + + // + // Device-only API + // + + // Factory invocation + CUTLASS_DEVICE + static void invoke( + Params const ¶ms, + SharedStorage &shared_storage) + { + GemmPlanarComplexArray op; + op(params, shared_storage); } - + + /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage) { diff --git a/include/cutlass/gemm/kernel/gemm_universal.h b/include/cutlass/gemm/kernel/gemm_universal.h index 982f765d..8fafd874 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.h +++ b/include/cutlass/gemm/kernel/gemm_universal.h @@ -37,12 +37,12 @@ #include "cutlass/cutlass.h" #include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" #include "cutlass/matrix_coord.h" #include "cutlass/complex.h" #include "cutlass/semaphore.h" - #include "cutlass/layout/matrix.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/params_universal_base.h" #include "cutlass/trace.h" @@ -55,7 +55,7 @@ namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate typename Epilogue_, ///! Epilogue typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function > @@ -101,16 +101,12 @@ public: // /// Argument structure - struct Arguments { - + struct Arguments : UniversalArgumentsBase + { // // Data members // - GemmUniversalMode mode; - GemmCoord problem_size; - int batch_count; - typename EpilogueOutputOp::Params epilogue; void const * ptr_A; @@ -121,7 +117,6 @@ public: int64_t batch_stride_A; int64_t batch_stride_B; int64_t batch_stride_C; - int64_t batch_stride_D; typename LayoutA::Stride stride_a; typename LayoutB::Stride stride_b; @@ -140,14 +135,13 @@ public: // // Methods // - - Arguments(): - mode(GemmUniversalMode::kGemm), - batch_count(1), + + Arguments(): ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), ptr_gather_A_indices(nullptr), ptr_gather_B_indices(nullptr), - ptr_scatter_D_indices(nullptr) {} + ptr_scatter_D_indices(nullptr) + {} /// constructs an arguments structure Arguments( @@ -169,23 +163,22 @@ public: typename LayoutC::Stride stride_d, int const *ptr_gather_A_indices = nullptr, int const *ptr_gather_B_indices = nullptr, - int const *ptr_scatter_D_indices = nullptr - ): - mode(mode), - problem_size(problem_size), - batch_count(batch_count), + int const *ptr_scatter_D_indices = nullptr) + : + UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), epilogue(epilogue), ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), - batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), - ptr_scatter_D_indices(ptr_scatter_D_indices) { + ptr_scatter_D_indices(ptr_scatter_D_indices) + { lda = 0; ldb = 0; ldc = 0; ldd = 0; CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); - } + } /// constructs an arguments structure Arguments( @@ -209,26 +202,26 @@ public: int const *ptr_gather_B_indices = nullptr, int const *ptr_scatter_D_indices = nullptr ): - mode(mode), - problem_size(problem_size), - batch_count(batch_count), - epilogue(epilogue), - ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), - batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), + UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), - ptr_scatter_D_indices(ptr_scatter_D_indices) { + ptr_scatter_D_indices(ptr_scatter_D_indices) + { stride_a = make_Coord(lda); stride_b = make_Coord(ldb); stride_c = make_Coord(ldc); stride_d = make_Coord(ldd); CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); - } + } /// Returns arguments for the transposed problem - Arguments transposed_problem() const { + Arguments transposed_problem() const + { Arguments args(*this); - + std::swap(args.problem_size.m(), args.problem_size.n()); std::swap(args.ptr_A, args.ptr_B); std::swap(args.lda, args.ldb); @@ -240,27 +233,36 @@ public: } }; + // // Structure for precomputing values in host memory and passing to kernels // /// 
Parameters structure - struct Params { + struct Params : UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC> + { + using ParamsBase = UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC>; + + // + // Data members + // - cutlass::gemm::GemmCoord problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - typename Mma::IteratorA::Params params_A; typename Mma::IteratorB::Params params_B; typename Epilogue::OutputTileIterator::Params params_C; typename Epilogue::OutputTileIterator::Params params_D; - - typename EpilogueOutputOp::Params output_op; - GemmUniversalMode mode; - int batch_count; - int gemm_k_size; + typename EpilogueOutputOp::Params output_op; void * ptr_A; void * ptr_B; @@ -270,59 +272,30 @@ public: int64_t batch_stride_A; int64_t batch_stride_B; int64_t batch_stride_C; - int64_t batch_stride_D; int * ptr_gather_A_indices; int * ptr_gather_B_indices; int * ptr_scatter_D_indices; - int *semaphore; - // - // Methods + // Host dispatch API // - CUTLASS_HOST_DEVICE - Params(): - swizzle_log_tile(0), - params_A(0), - params_B(0), - params_C(0), - params_D(0), - batch_count(0), - gemm_k_size(0), - mode(cutlass::gemm::GemmUniversalMode::kGemm), - ptr_A(nullptr), - ptr_B(nullptr), - ptr_C(nullptr), - ptr_D(nullptr), - batch_stride_A(0), - batch_stride_B(0), - batch_stride_C(0), - batch_stride_D(0), - ptr_gather_A_indices(nullptr), - ptr_gather_B_indices(nullptr), - ptr_scatter_D_indices(nullptr), - semaphore(nullptr) { } + /// Default constructor + Params() = default; - CUTLASS_HOST_DEVICE + /// Constructor Params( - Arguments const &args, - cutlass::gemm::GemmCoord const & grid_tiled_shape, - int gemm_k_size, - void *workspace = nullptr - ): - problem_size(args.problem_size), - grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + Arguments const &args, /// GEMM application arguments + int device_sms, /// Number of SMs on the device + int sm_occupancy) /// Kernel SM occupancy (in thread blocks) + : + ParamsBase(args, device_sms, sm_occupancy), params_A(args.lda ? make_Coord_with_padding(args.lda) : args.stride_a), params_B(args.ldb ? make_Coord_with_padding(args.ldb) : args.stride_b), params_C(args.ldc ? make_Coord_with_padding(args.ldc) : args.stride_c), params_D(args.ldd ? make_Coord_with_padding(args.ldd) : args.stride_d), output_op(args.epilogue), - mode(args.mode), - batch_count(args.batch_count), - gemm_k_size(gemm_k_size), ptr_A(const_cast(args.ptr_A)), ptr_B(const_cast(args.ptr_B)), ptr_C(const_cast(args.ptr_C)), @@ -330,19 +303,18 @@ public: batch_stride_A(args.batch_stride_A), batch_stride_B(args.batch_stride_B), batch_stride_C(args.batch_stride_C), - batch_stride_D(args.batch_stride_D), ptr_gather_A_indices(const_cast(args.ptr_gather_A_indices)), ptr_gather_B_indices(const_cast(args.ptr_gather_B_indices)), - ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)), - semaphore(static_cast(workspace)) { + ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)) + {} - } - - CUTLASS_HOST_DEVICE - void update( - Arguments const &args, - void *workspace = nullptr) { + /// Lightweight update given a subset of arguments. Problem geometry is assumed + /// to remain the same. 
+ void update(Arguments const &args) + { + CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); + // Update input/output pointers ptr_A = const_cast(args.ptr_A); ptr_B = const_cast(args.ptr_B); ptr_C = const_cast(args.ptr_C); @@ -352,37 +324,28 @@ public: ptr_gather_B_indices = const_cast(args.ptr_gather_B_indices); ptr_scatter_D_indices = const_cast(args.ptr_scatter_D_indices); - batch_stride_A = args.batch_stride_A; - batch_stride_B = args.batch_stride_B; - batch_stride_C = args.batch_stride_C; - batch_stride_D = args.batch_stride_D; - output_op = args.epilogue; - - semaphore = static_cast(workspace); - CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); } }; + /// Shared memory storage structure union SharedStorage { typename Mma::SharedStorage main_loop; typename Epilogue::SharedStorage epilogue; }; + public: // - // Methods + // Host dispatch API // - CUTLASS_DEVICE - GemmUniversal() { } - /// Determines whether kernel satisfies alignment static Status can_implement( - cutlass::gemm::GemmCoord const & problem_size) { - + cutlass::gemm::GemmCoord const & problem_size) + { CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); static int const kAlignmentA = (platform::is_same 1) { + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { int lock = 0; if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { diff --git a/include/cutlass/gemm/kernel/gemm_universal_streamk.h b/include/cutlass/gemm/kernel/gemm_universal_streamk.h new file mode 100644 index 00000000..b8bf3f80 --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_universal_streamk.h @@ -0,0 +1,1126 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/barrier.h" +#include "cutlass/block_striped.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock mapping function +> +struct GemmUniversalStreamk { +public: + + + // + // Types and constants + // + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + using ElementAccumulator = typename Mma::ElementC; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Number of workspace accumulation elements shared per per block + static int const kPeerAccumulators = Epilogue::kPeerAccumulators; + + /// Number of fragments per (thread) accumulator tile + static int const kAccumulatorFragments = Epilogue::kAccumulatorFragments; + + /// Number of numeric accumulation elements per fragment + static int const kAccumTileElements = sizeof(typename Mma::FragmentC) / sizeof(ElementAccumulator); + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Block-striped reduction utility + using BlockStripedReduceT = BlockStripedReduce; + + + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + typename EpilogueOutputOp::Params epilogue; + + void const * ptr_A; + void const * ptr_B; + void const * ptr_C; + void * ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + + typename LayoutA::Stride stride_a; + typename LayoutB::Stride stride_b; + typename LayoutC::Stride stride_c; + typename LayoutC::Stride stride_d; + + typename LayoutA::Stride::LongIndex lda; + typename LayoutB::Stride::LongIndex ldb; + typename LayoutC::Stride::LongIndex ldc; + typename LayoutC::Stride::LongIndex ldd; + + int sm_limit; /// Carvout override: 
when the above are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance + + + // + // Methods + // + + /// Default Constructor + Arguments(): + mode(GemmUniversalMode::kGemm), + batch_count(1), + ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), + sm_limit(-1) + {} + + /// Constructor + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + typename LayoutA::Stride stride_a, + typename LayoutB::Stride stride_b, + typename LayoutC::Stride stride_c, + typename LayoutC::Stride stride_d, + int sm_limit = -1 /// Carvout override: when the above are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance + ): + mode(mode), + problem_size(problem_size), + batch_count(batch_count), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), + stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), sm_limit(sm_limit) + { + CUTLASS_TRACE_HOST("GemmUniversalStreamk::Arguments::Arguments() - problem_size: " << problem_size); + } + + /// Constructor + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + typename LayoutA::Stride::LongIndex lda, + typename LayoutB::Stride::LongIndex ldb, + typename LayoutC::Stride::LongIndex ldc, + typename LayoutC::Stride::LongIndex ldd, + int sm_limit = -1 /// Carvout override: when the above are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance + ): + mode(mode), + problem_size(problem_size), + batch_count(batch_count), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), + lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), sm_limit(sm_limit) + { + stride_a = make_Coord(lda); + stride_b = make_Coord(ldb); + stride_c = make_Coord(ldc); + stride_d = make_Coord(ldd); + CUTLASS_TRACE_HOST("GemmUniversalStreamk::Arguments::Arguments() - problem_size: " << problem_size); + } + + /// Returns arguments for the transposed problem + Arguments transposed_problem() const + { + Arguments args(*this); + + std::swap(args.problem_size.m(), args.problem_size.n()); + std::swap(args.ptr_A, args.ptr_B); + std::swap(args.lda, args.ldb); + std::swap(args.stride_a, args.stride_b); + std::swap(args.batch_stride_A, args.batch_stride_B); + + return args; + } + }; + + + /// Parameters structure + struct Params + { + public: + + // + // Data members + // + + ThreadblockSwizzle block_mapping; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename EpilogueOutputOp::Params output_op; + + GemmUniversalMode mode; + + void * ptr_A; + void * ptr_B; + void * ptr_C; + void * ptr_D; + + int64_t batch_stride_A; + 
int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + + void *barrier_workspace; + ElementAccumulator *partials_workspace; + + protected: + + // + // Host-only dispatch-utilities + // + + /// Pad the given allocation size up to the nearest cache line + static size_t cacheline_align_up(size_t size) + { + static const int CACHELINE_SIZE = 128; + return (size + CACHELINE_SIZE - 1) / CACHELINE_SIZE * CACHELINE_SIZE; + } + + /// Get the workspace size needed for barrier + size_t get_barrier_workspace_size() const + { + // For atomic reduction, each SK-block needs a synchronization flag. For parallel reduction, + // each reduction block needs its own synchronization flag. + int sk_blocks = block_mapping.sk_regions * block_mapping.sk_blocks_per_region; + int num_flags = fast_max(sk_blocks, block_mapping.reduction_blocks); + + return cacheline_align_up(sizeof(typename Barrier::T) * num_flags); + } + + /// Get the workspace size needed for intermediate partial sums + size_t get_partials_workspace_size() const + { + // For atomic reduction, each SK-block can share one accumulator tile. For parallel reduction, + // each SK-block can share up to two accumulator tiles. + size_t tile_bytes_accumulators = sizeof(ElementAccumulator) * kPeerAccumulators * 2; + int sk_blocks = block_mapping.sk_regions * block_mapping.sk_blocks_per_region; + return cacheline_align_up(tile_bytes_accumulators * sk_blocks); + } + + + public: + + // + // Host dispatch API + // + + /// Default constructor + Params() = default; + + + /// Constructor + Params( + Arguments const &args, /// GEMM application arguments + int device_sms, /// Number of SMs on the device + int sm_occupancy) /// Kernel SM occupancy (in thread blocks) + : + params_A(args.lda ? make_Coord_with_padding(args.lda) : args.stride_a), + params_B(args.ldb ? make_Coord_with_padding(args.ldb) : args.stride_b), + params_C(args.ldc ? make_Coord_with_padding(args.ldc) : args.stride_c), + params_D(args.ldd ? make_Coord_with_padding(args.ldd) : args.stride_d), + output_op(args.epilogue), + mode(args.mode), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C(const_cast(args.ptr_C)), + ptr_D(args.ptr_D), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D), + barrier_workspace(nullptr), + partials_workspace(nullptr) + { + // Number of SMs to make available for StreamK decomposition + int avail_sms = (args.sm_limit == -1) ? + device_sms : + fast_min(args.sm_limit, device_sms); + + // Initialize the block mapping structure + block_mapping = ThreadblockSwizzle( + typename ThreadblockSwizzle::template KernelTraits(), + args.mode, + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count, + sm_occupancy, + avail_sms); + } + + + /// Returns the workspace size (in bytes) needed for these parameters + size_t get_workspace_size() const + { + return + get_barrier_workspace_size() + + get_partials_workspace_size(); + } + + + /// Assign and initialize the specified workspace buffer. Assumes + /// the memory allocated to workspace is at least as large as get_workspace_size(). 
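The GemmUniversalStreamk::Params constructor above turns the sm_limit carveout override into the avail_sms count handed to the threadblock swizzle. A standalone sketch of that selection (illustrative; the helper name is an assumption, and std::min stands in for fast_min):

    #include <algorithm>

    // sm_limit == -1 is the defaulted Arguments value: use every SM the device reports.
    // Otherwise the stream-K dispatch heuristics load-balance across at most sm_limit SMs.
    inline int streamk_available_sms(int sm_limit, int device_sms)
    {
      return (sm_limit == -1) ? device_sms : std::min(sm_limit, device_sms);
    }
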
+ Status init_workspace( + void *workspace, + cudaStream_t stream = nullptr) + { + uint8_t *ptr = static_cast(workspace); + + // Establish partials workspace + partials_workspace = nullptr; + size_t partials_workspace_bytes = get_partials_workspace_size(); + if (partials_workspace_bytes > 0) + { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + partials_workspace = reinterpret_cast(ptr); + ptr += partials_workspace_bytes; + } + + // Establish barrier workspace + barrier_workspace = nullptr; + size_t barrier_workspace_bytes = get_barrier_workspace_size(); + if (barrier_workspace_bytes > 0) + { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + barrier_workspace = reinterpret_cast(ptr); + ptr += barrier_workspace_bytes; + } + + // Zero-initialize barrier workspace + if (barrier_workspace) + { + size_t barrier_workspace_bytes = get_barrier_workspace_size(); + + CUTLASS_TRACE_HOST(" Initialize " << barrier_workspace_bytes << " barrier bytes"); + + cudaError_t result = cudaMemsetAsync( + barrier_workspace, + 0, + barrier_workspace_bytes, + stream); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + + /// Returns the GEMM volume in thread block tiles + cutlass::gemm::GemmCoord get_tiled_shape() const + { + return block_mapping.tiled_shape; + } + + + /// Returns the total number of thread blocks to launch + int get_grid_blocks() const + { + dim3 grid_dims = get_grid_dims(); + return grid_dims.x * grid_dims.y * grid_dims.z; + } + + + /// Returns the grid extents in thread blocks to launch + dim3 get_grid_dims() const + { + return block_mapping.get_grid_dims(); + } + + + /// Lightweight update given a subset of arguments. Problem geometry is assumed + /// to remain the same. 
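For callers, get_workspace_size() and init_workspace() above are meant to be used as a pair before launch. A hypothetical host-side helper showing that call order (the helper itself is an assumption for illustration, not an API added by this patch):

    #include <cstddef>
    #include <cuda_runtime.h>
    #include "cutlass/cutlass.h"

    // Query the barrier/partials scratch size, allocate it on the device, then let the
    // kernel Params zero-fill the barrier region asynchronously on 'stream'.
    template <typename KernelParams>
    bool allocate_and_init_workspace(KernelParams &params, void **workspace, cudaStream_t stream)
    {
      *workspace = nullptr;
      size_t bytes = params.get_workspace_size();
      if (bytes > 0 && cudaMalloc(workspace, bytes) != cudaSuccess) {
        return false;
      }
      return params.init_workspace(*workspace, stream) == cutlass::Status::kSuccess;
    }

The device-level adapter is expected to perform this sequence on the caller's behalf; the sketch only makes the call order explicit.
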
+ void update(Arguments const &args) + { + CUTLASS_TRACE_HOST("GemmUniversalStreamK::Params::update()"); + + // Update input/output pointers + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_C = const_cast(args.ptr_C); + ptr_D = args.ptr_D; + + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_C = args.batch_stride_C; + batch_stride_D = args.batch_stride_D; + + output_op = args.epilogue; + } + + }; + + /// Tile work descriptor + struct TileWorkDesc + { + /// The linear tile index + int tile_idx; + + /// The location of this tile (in threadblock-tile coordinates) in the output matrix + cutlass::gemm::GemmCoord tiled_coord; + + // The first global-scoped MAC-iteration this threadblock will perform for this tile + int iter_begin; + + // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile + int k_begin; + + // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile + int k_end; + + /// The number of remaining MAC-iterations this threadblock will perform for this tile + int k_iters_remaining; + + // Whether this block will perform the first iteration of this tile + CUTLASS_DEVICE + bool tile_started() + { + return (k_begin == 0); + } + + // Whether this block will perform the last iteration of this tile + CUTLASS_DEVICE + bool tile_finished(Params const ¶ms) + { + return (k_end == params.block_mapping.problem_size.k()); + } + }; + + + /// Shared memory storage structure + union SharedStorage + { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + +protected: + + // + // Data members + // + + /// GEMM problem parameters + Params const ¶ms; + + /// Shared storage reference + SharedStorage &shared_storage; + + /// ID within the threadblock + int thread_idx; + + /// ID of warp + int warp_idx; + + /// ID of each thread within a warp + int lane_idx; + + /// Block index + int block_idx; + + /// Threadblock scoped epilogue + Epilogue epilogue; + + +public: + + // + // Host-only dispatch API + // + + /// Determines whether the GEMM problem size satisfies this kernel's + /// alignment requirements + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) + { + CUTLASS_TRACE_HOST("GemmUniversalStreamk::can_implement()"); + + static int const kAlignmentA = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + + static int const kAlignmentB = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 
64 + : Mma::IteratorB::AccessType::kElements; + + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) || + (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || + (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) + { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + + /// Determines whether the GEMM problem satisfies this kernel's + /// alignment requirements + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + + +protected: + + // + // Device-only utility methods + // + + /// Iterator for fetching tile fragments from A + CUTLASS_DEVICE + typename Mma::IteratorA init_iterator_A(TileWorkDesc &tile_work) + { + // The input A matrix + ElementA *ptr_A = static_cast(params.ptr_A); + + // Update input pointers based on batched/array mode + if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += tile_work.tiled_coord.k() * params.batch_stride_A; + } + if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[tile_work.tiled_coord.k()]; + } + + int m_begin = tile_work.tiled_coord.m() * Mma::Shape::kM; + int m_end = params.block_mapping.problem_size.m(); + return Mma::IteratorA( + params.params_A, + ptr_A, + { m_end, tile_work.k_end }, + threadIdx.x, + { m_begin, tile_work.k_begin }); + + } + + + /// Iterator for fetching tile fragments from B + CUTLASS_DEVICE + typename Mma::IteratorB init_iterator_B(TileWorkDesc &tile_work) + { + // The input B matrix + ElementB *ptr_B = static_cast(params.ptr_B); + + // Update input pointers based on batched/array mode + if (params.mode == GemmUniversalMode::kBatched) { + ptr_B += tile_work.tiled_coord.k() * params.batch_stride_B; + } + if (params.mode == GemmUniversalMode::kArray) { + ptr_B = static_cast(params.ptr_B)[tile_work.tiled_coord.k()]; + } + + int n_begin = tile_work.tiled_coord.n() * Mma::Shape::kN; + int n_end = params.block_mapping.problem_size.n(); + return Mma::IteratorB( + params.params_B, + ptr_B, + { tile_work.k_end, n_end }, + threadIdx.x, + { tile_work.k_begin, n_begin }); + } + + + CUTLASS_DEVICE + void init_dp_tile_work( + TileWorkDesc &tile_work, + int tile_idx) + { + // The linear tile index + tile_work.tile_idx = tile_idx; + + // The first global-scoped MAC-iteration this threadblock will perform for this tile + tile_work.iter_begin = tile_idx * params.block_mapping.iters_per_tile; + + // The number of MAC-iterations this threadblock will perform for this tile + tile_work.k_iters_remaining = params.block_mapping.iters_per_tile; + + // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile + tile_work.k_begin = 0; + + // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile + tile_work.k_end = params.block_mapping.problem_size.k(); + + // The location of this tile (in threadblock-tile coordinates) in the output matrix + tile_work.tiled_coord = params.block_mapping.get_tile_offset(tile_work.tile_idx); + } + + + CUTLASS_DEVICE + void init_sk_tile_work( + TileWorkDesc &tile_work, + int tile_idx, + int block_iter_begin, + int block_iter_end) + { + // The linear tile index + tile_work.tile_idx = tile_idx; + + // The first global-scoped MAC-iteration for this tile + int tile_iter_begin = 
tile_idx * params.block_mapping.iters_per_tile; + + // The first global-scoped MAC-iteration this threadblock will perform for this tile + tile_work.iter_begin = max(block_iter_begin, tile_iter_begin); + + // The first tile-scoped MAC-iteration this threadblock will perform for this tile + int k_iter_begin = tile_work.iter_begin - tile_iter_begin; + + // The last (one past) tile-scoped MAC-iteration this threadblock will perform for this tile + int k_iter_end = block_iter_end - tile_iter_begin; + + // The number of MAC-iterations this threadblock will perform for this tile + tile_work.k_iters_remaining = k_iter_end - k_iter_begin; + + // The starting index in the k-domain for MAC-iterations this threadblock will perform for this tile + tile_work.k_begin = k_iter_begin * Mma::Shape::kK; + + // The ending index (one-past) in the k-domain for MAC-iterations this threadblock will perform for this tile + tile_work.k_end = min( + params.block_mapping.problem_size.k(), // extent of k domain + (k_iter_end * Mma::Shape::kK)); // extent of the threadblock's global iteration assignment + + // The location of this tile (in threadblock-tile coordinates) in the output matrix + tile_work.tiled_coord = params.block_mapping.get_tile_offset(tile_work.tile_idx); + } + + + /// Share accumulators with peers + CUTLASS_DEVICE + void share_accumulators(typename Mma::FragmentC const &accumulator_tile, int first_block_idx) + { + int block_tile_offset = first_block_idx * kPeerAccumulators; + + if (block_idx == first_block_idx) + { + // First peer initializes the workspace partials + BlockStripedReduceT::store(params.partials_workspace + block_tile_offset, accumulator_tile, thread_idx); + } + else + { + // Subsequent peers atomically accumulate into the workspace partials + if (ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) + { + // Non-deterministic reduction order: wait for the first peer to have initialized the partials before we add to them + Barrier::wait_lt(params.barrier_workspace, thread_idx, first_block_idx, 1); + } + else + { + // Turnstile reduction order: wait until the previous peer has written + int wait_count = block_idx - first_block_idx; + Barrier::wait_eq(params.barrier_workspace, thread_idx, first_block_idx, wait_count); + } + + // Perform reduction in workspace + BlockStripedReduceT::reduce(params.partials_workspace + block_tile_offset, accumulator_tile, thread_idx); + } + + // Signal our arrival + Barrier::arrive_inc(params.barrier_workspace, thread_idx, first_block_idx); + } + + + /// Acquire accumulators from peers + CUTLASS_DEVICE + void acquire_accumulators_atomic( + typename Mma::FragmentC &accumulator_tile, + int first_block_idx) + { + // Wait for arrival + int num_carry_in = block_idx - first_block_idx; + Barrier::wait_eq_reset(params.barrier_workspace, thread_idx, first_block_idx, num_carry_in); + + // Load and add peer-partials accumulator tile to local accumulator tile + int block_tile_offset = first_block_idx * kPeerAccumulators; + BlockStripedReduceT::load_add(accumulator_tile, params.partials_workspace + block_tile_offset, thread_idx); + } + + + /// Perform epilogue computations and output + CUTLASS_DEVICE + void do_epilogue( + TileWorkDesc &tile_work, + typename Mma::FragmentC &accumulator_tile) + { + ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_D = static_cast(params.ptr_D); + + // Update pointers for batched/array mode(s) + if (params.mode == GemmUniversalMode::kBatched) { + ptr_C += tile_work.tiled_coord.k() * params.batch_stride_C; + ptr_D 
+= tile_work.tiled_coord.k() * params.batch_stride_D; + } + if (params.mode == GemmUniversalMode::kArray) { + ptr_C = static_cast(params.ptr_C)[tile_work.tiled_coord.k()]; + ptr_D = static_cast(params.ptr_D)[tile_work.tiled_coord.k()]; + } + + // Location of this tile in item-coords + MatrixCoord threadblock_item_begin( + tile_work.tiled_coord.m() * Mma::Shape::kM, + tile_work.tiled_coord.n() * Mma::Shape::kN + ); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + ptr_C, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Execute the epilogue operator to update the destination tensor. + epilogue( + EpilogueOutputOp(params.output_op), + iterator_D, + accumulator_tile, + iterator_C); + } + + + CUTLASS_DEVICE + void separate_reduction(int reduce_idx) + { + int peer_idx_begin, peer_idx_last, reduce_tile_idx, reduce_fragment_idx; + + // Reduce by sk-tile (every tile contributed to by one or more blocks) + reduce_tile_idx = reduce_idx / kAccumulatorFragments; + reduce_fragment_idx = reduce_idx % kAccumulatorFragments; + + int iter_tile_first = reduce_tile_idx * params.block_mapping.iters_per_tile; + int iter_tile_last = iter_tile_first + params.block_mapping.iters_per_tile - 1; + + peer_idx_begin = params.block_mapping.get_sk_block_idx(iter_tile_first); + peer_idx_last = params.block_mapping.get_sk_block_idx(iter_tile_last); + + // Wait for peers to complete + int peer_idx_end = peer_idx_last + 1; + int num_peers = peer_idx_end - peer_idx_begin; + Barrier::wait_eq_reset( + params.barrier_workspace, + thread_idx, + (reduce_tile_idx * kAccumulatorFragments) + reduce_fragment_idx, + num_peers); + + /// The location of this tile (in threadblock-tile coordinates) in the output matrix + GemmCoord tiled_coord = params.block_mapping.get_tile_offset(reduce_tile_idx); + + // Location of this tile in item-coords + MatrixCoord threadblock_item_begin( + tiled_coord.m() * Mma::Shape::kM, + tiled_coord.n() * Mma::Shape::kN + ); + + ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_D = static_cast(params.ptr_D); + + // Update pointers for batched/array mode(s) + if (params.mode == GemmUniversalMode::kBatched) { + ptr_C += tiled_coord.k() * params.batch_stride_C; + ptr_D += tiled_coord.k() * params.batch_stride_D; + } + if (params.mode == GemmUniversalMode::kArray) { + ptr_C = static_cast(params.ptr_C)[tiled_coord.k()]; + ptr_D = static_cast(params.ptr_D)[tiled_coord.k()]; + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + ptr_C, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.block_mapping.problem_size.mn(), + thread_idx, + threadblock_item_begin); + + // Execute the epilogue operator to update the destination tensor. 
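separate_reduction() above assigns each reduction block one accumulator fragment of one stream-K tile by decomposing a linear index. A standalone restatement of that index math (illustrative; the names are placeholders for the kernel's reduce_idx and kAccumulatorFragments):

    // tile_idx selects the sk-tile being reduced; fragment_idx selects which of its
    // accumulator fragments this reduction block handles.
    struct ReduceAssignment
    {
      int tile_idx;
      int fragment_idx;
    };

    inline ReduceAssignment decompose_reduce_idx(int reduce_idx, int accumulator_fragments)
    {
      return { reduce_idx / accumulator_fragments, reduce_idx % accumulator_fragments };
    }
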
+ epilogue.reduce( + peer_idx_begin, + peer_idx_end, + reduce_fragment_idx, + params.partials_workspace, + EpilogueOutputOp(params.output_op), + iterator_D, + iterator_C); + } + + + CUTLASS_DEVICE + void process_tile( + TileWorkDesc tile_work, + int dp_start_block_idx, + int block_iter_begin) + { + // Initialize input iterators + typename Mma::IteratorA iterator_A = init_iterator_A(tile_work); + typename Mma::IteratorB iterator_B = init_iterator_B(tile_work); + + // Initialize accumulators + typename Mma::FragmentC accumulator_tile; + accumulator_tile.clear(); + + // Perform this tile's range of multiply-accumulate (MAC) iterations + Mma mma( + shared_storage.main_loop, + thread_idx, + warp_idx, + lane_idx); + + mma(tile_work.k_iters_remaining, accumulator_tile, iterator_A, iterator_B, accumulator_tile); + + if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kAtomic) || + (params.block_mapping.reduction_blocks == 0) || + (block_idx >= dp_start_block_idx)) + { + // + // Cooperative SK peer reduction or DP block + // + + int first_block_idx = params.block_mapping.get_first_block_idx(tile_work.tile_idx, block_idx); + + if (!tile_work.tile_finished(params)) { + // Non "finishing" SK blocks must share their partial accumulator sums through global scratch workspace + share_accumulators(accumulator_tile, first_block_idx); + } + else + { + // DP blocks and "finishing" SK blocks must perform epilogue operations and write the output tile + if (!tile_work.tile_started()) + { + // A "finishing" SK block must first aggregate its accumulator partial sums with those shared by peer threadblocks + acquire_accumulators_atomic(accumulator_tile, first_block_idx); + } + + do_epilogue(tile_work, accumulator_tile); + } + } + else + { + // + // Separate peer reduction + // + + // Share accumulator partial sums with peer threadblock(s) through scratch workspace + epilogue.share(block_idx, params.partials_workspace, accumulator_tile, tile_work.tile_started()); + + // Signal arrival + Barrier::arrive_range_inc( + params.barrier_workspace, + thread_idx, + tile_work.tile_idx * kAccumulatorFragments, + kAccumulatorFragments); + } + } + + + /// Executes one GEMM + CUTLASS_DEVICE + void gemm() + { + // Initialize block's iteration range + int tile_idx, block_iter_begin, block_iters_remaining; + + int sk_padding_start_block_idx = params.block_mapping.sk_regions * params.block_mapping.sk_blocks_per_region; + int dp_start_block_idx = params.block_mapping.sk_waves * params.block_mapping.avail_sms; + int reduce_start_block_idx = dp_start_block_idx + params.block_mapping.dp_blocks; + int grid_padding_start_block_idx = reduce_start_block_idx + params.block_mapping.reduction_blocks; + + if (block_idx < sk_padding_start_block_idx) + { + // This is a SK block + int block_iter_end; + params.block_mapping.get_iter_extents(block_idx, block_iter_begin, block_iter_end); + block_iters_remaining = block_iter_end - block_iter_begin; + + tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1); + } + else if (block_idx < dp_start_block_idx) + { + // This is a filler block + return; + } + else if (block_idx < reduce_start_block_idx) + { + // This is a DP block + int dp_block_idx = block_idx - dp_start_block_idx; + int first_dp_tile = (params.block_mapping.cohort_raster) ? 
0 : params.block_mapping.sk_tiles; + + // Blocks in first DP wave get configured number of tiles + tile_idx = first_dp_tile + dp_block_idx; + int tile_allottment = params.block_mapping.dp_first_wave_tiles; + + // Blocks in subsequent DP waves get 1 tile + if (dp_block_idx >= params.block_mapping.avail_sms) { + tile_allottment = 1; + tile_idx += (params.block_mapping.dp_first_wave_tiles - 1) * params.block_mapping.avail_sms; + } + + block_iter_begin = 0; + block_iters_remaining = params.block_mapping.iters_per_tile * tile_allottment; + } + else if ((ThreadblockSwizzle::kReductionStrategy == ThreadblockSwizzle::kMixed) && + (block_idx < grid_padding_start_block_idx)) + { + // This is a reduction threadblock + int reduce_block_idx = block_idx - reduce_start_block_idx; + separate_reduction(reduce_block_idx); + return; + } + else + { + // This is a filler block + return; + } + + // Iteration-processing loop body + CUTLASS_PRAGMA_NO_UNROLL + while (true) + { + // Initialize tile work descriptor + TileWorkDesc tile_work; + if (block_idx >= dp_start_block_idx) + { + init_dp_tile_work(tile_work, tile_idx); + + // DP blocks exit if out of bounds or overlap an SK tile (only possible during cohort rasterization, where dp_first_wave_tiles must be 1) + if ((tile_idx < params.block_mapping.sk_tiles) || + (tile_work.tiled_coord.m() >= params.block_mapping.tiled_shape.m()) || + (tile_work.tiled_coord.n() >= params.block_mapping.tiled_shape.n())) + { + break; + } + } + else + { + init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining); + } + + // Perform this block's share of work for this tile + process_tile(tile_work, dp_start_block_idx, block_iter_begin); + + // Update remaining work for this block + block_iters_remaining -= tile_work.k_iters_remaining; + if (block_iters_remaining == 0) { + // Done + break; + } + + // Continue to next tile + __syncthreads(); + + if (block_idx >= dp_start_block_idx) + { + // DP block consume their tiles at stride + tile_idx += params.block_mapping.avail_sms; + } + else + { + // SK blocks consume their tiles in backwards order + tile_idx--; + } + } + + } + +public: + + // + // Device-only API + // + + // Factory invocation + CUTLASS_DEVICE + static void invoke( + Params const ¶ms, + SharedStorage &shared_storage) + { + GemmUniversalStreamk op(params, shared_storage); + op(); + } + + + // Constructor + CUTLASS_DEVICE + GemmUniversalStreamk( + Params const ¶ms, + SharedStorage &shared_storage) + : + params(params), + shared_storage(shared_storage), + thread_idx(threadIdx.x), + warp_idx(__shfl_sync(0xffffffff, threadIdx.x / 32, 0)), // broadcast the warp_id computed by lane 0 to ensure dependent code + lane_idx(threadIdx.x % 32), + block_idx(params.block_mapping.get_block_idx()), + epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx) + {} + + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()() + { + // Do the GEMM + gemm(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h index 264fde93..9e7dfc98 100644 --- a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h +++ b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h @@ -36,10 +36,12 @@ #include 
"cutlass/cutlass.h" #include "cutlass/fast_math.h" +#include "cutlass/layout/layout.h" #include "cutlass/gemm/gemm.h" #include "cutlass/matrix_coord.h" #include "cutlass/complex.h" #include "cutlass/semaphore.h" +#include "cutlass/gemm/kernel/params_universal_base.h" #include "cutlass/trace.h" @@ -51,12 +53,21 @@ namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + bool IsSingleSource = Epilogue_::kIsSingleSource +> +struct GemmWithFusedEpilogue; + +// GemmWithFusedEpilogue with two sources template < typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate typename Epilogue_, ///! Epilogue typename ThreadblockSwizzle_ ///! Threadblock swizzling function > -struct GemmWithFusedEpilogue { +struct GemmWithFusedEpilogue { public: using Mma = Mma_; @@ -101,21 +112,18 @@ public: // /// Argument structure - struct Arguments { + struct Arguments : UniversalArgumentsBase{ // // Data members // - GemmUniversalMode mode; - GemmCoord problem_size; - int batch_count; - typename EpilogueOutputOp::Params epilogue; void const * ptr_A; void const * ptr_B; - void const * ptr_C; + void const * ptr_C1; + void const * ptr_C2; void * ptr_D; void * ptr_Vector; @@ -123,14 +131,15 @@ public: int64_t batch_stride_A; int64_t batch_stride_B; - int64_t batch_stride_C; - int64_t batch_stride_D; + int64_t batch_stride_C1; + int64_t batch_stride_C2; int64_t batch_stride_Vector; int64_t batch_stride_Tensor; typename LayoutA::Stride::Index lda; typename LayoutB::Stride::Index ldb; - typename LayoutC::Stride::Index ldc; + typename LayoutC::Stride::Index ldc1; + typename LayoutC::Stride::Index ldc2; typename LayoutC::Stride::Index ldd; typename LayoutC::Stride::Index ldr; typename LayoutC::Stride::Index ldt; @@ -140,9 +149,12 @@ public: // Arguments(): - mode(GemmUniversalMode::kGemm), - batch_count(1), - ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr) { } + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C1(nullptr), + ptr_C2(nullptr), + ptr_D(nullptr) + {} /// constructs an arguments structure Arguments( @@ -152,37 +164,38 @@ public: typename EpilogueOutputOp::Params epilogue, void const * ptr_A, void const * ptr_B, - void const * ptr_C, + void const * ptr_C1, + void const * ptr_C2, void * ptr_D, void * ptr_Vector, void * ptr_Tensor, int64_t batch_stride_A, int64_t batch_stride_B, - int64_t batch_stride_C, + int64_t batch_stride_C1, + int64_t batch_stride_C2, int64_t batch_stride_D, int64_t batch_stride_Vector, int64_t batch_stride_Tensor, typename LayoutA::Stride::Index lda, typename LayoutB::Stride::Index ldb, - typename LayoutC::Stride::Index ldc, + typename LayoutC::Stride::Index ldc1, + typename LayoutC::Stride::Index ldc2, typename LayoutC::Stride::Index ldd, typename LayoutC::Stride::Index ldr, - typename LayoutC::Stride::Index ldt - ): - mode(mode), - problem_size(problem_size), - batch_count(batch_count), + typename LayoutC::Stride::Index ldt) + : + UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), epilogue(epilogue), - ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C1(ptr_C1), ptr_C2(ptr_C2), ptr_D(ptr_D), ptr_Vector(ptr_Vector), ptr_Tensor(ptr_Tensor), batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), - batch_stride_C(batch_stride_C), - batch_stride_D(batch_stride_D), + 
batch_stride_C1(batch_stride_C1), + batch_stride_C2(batch_stride_C2), batch_stride_Vector(batch_stride_Vector), batch_stride_Tensor(batch_stride_Tensor), - lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ldr(ldr), ldt(ldt) + lda(lda), ldb(ldb), ldc1(ldc1), ldc2(ldc2), ldd(ldd), ldr(ldr), ldt(ldt) { CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size); CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); @@ -204,35 +217,44 @@ public: } }; + // // Structure for precomputing values in host memory and passing to kernels // /// Parameters structure - struct Params { + struct Params : UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC> + { + using ParamsBase = UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC>; - cutlass::gemm::GemmCoord problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; + // + // Data members + // typename Mma::IteratorA::Params params_A; typename Mma::IteratorB::Params params_B; - typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::Params params_C1; + typename Epilogue::OutputTileIterator::Params params_C2; typename Epilogue::OutputTileIterator::Params params_D; typename Epilogue::TensorTileIterator::Params params_Tensor; - typename EpilogueOutputOp::Params output_op; - - GemmUniversalMode mode; - int batch_count; - int gemm_k_size; - void * ptr_A; void * ptr_B; - void * ptr_C; + void * ptr_C1; + void * ptr_C2; void * ptr_D; - + void * ptr_Vector; typename LayoutC::Stride::Index ldr; @@ -240,78 +262,47 @@ public: int64_t batch_stride_A; int64_t batch_stride_B; - int64_t batch_stride_C; - int64_t batch_stride_D; + int64_t batch_stride_C1; + int64_t batch_stride_C2; int64_t batch_stride_Vector; int64_t batch_stride_Tensor; - int *semaphore; - // - // Methods + // Host dispatch API // - CUTLASS_HOST_DEVICE - Params(): - swizzle_log_tile(0), - params_A(0), - params_B(0), - params_C(0), - params_D(0), - batch_count(0), - gemm_k_size(0), - mode(cutlass::gemm::GemmUniversalMode::kGemm), - ptr_A(nullptr), - ptr_B(nullptr), - ptr_C(nullptr), - ptr_D(nullptr), - ptr_Vector(nullptr), - ldr(0), - ptr_Tensor(nullptr), - batch_stride_A(0), - batch_stride_B(0), - batch_stride_C(0), - batch_stride_D(0), - batch_stride_Vector(0), - batch_stride_Tensor(0), - semaphore(nullptr) { } + /// Default constructor + Params() = default; - CUTLASS_HOST_DEVICE + /// Constructor Params( - Arguments const &args, - cutlass::gemm::GemmCoord const & grid_tiled_shape, - int gemm_k_size, - void *workspace = nullptr - ): - problem_size(args.problem_size), - grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + Arguments const &args, /// GEMM application arguments + int device_sms, /// Number of SMs on the device + int sm_occupancy) /// Kernel SM occupancy (in thread blocks) + : + ParamsBase(args, device_sms, sm_occupancy), params_A(args.lda), params_B(args.ldb), - params_C(args.ldc), + params_C1(args.ldc1), + params_C2(args.ldc2), params_D(args.ldd), params_Tensor(args.ldt), output_op(args.epilogue), - mode(args.mode), - batch_count(args.batch_count), - gemm_k_size(gemm_k_size), ptr_A(const_cast(args.ptr_A)), ptr_B(const_cast(args.ptr_B)), - ptr_C(const_cast(args.ptr_C)), + ptr_C1(const_cast(args.ptr_C1)), + ptr_C2(const_cast(args.ptr_C2)), ptr_D(args.ptr_D), - ptr_Vector(args.ptr_Vector), + ptr_Vector(args.ptr_Vector), 
ldr(args.ldr), ptr_Tensor(args.ptr_Tensor), - batch_stride_A(args.batch_stride_A), batch_stride_B(args.batch_stride_B), - batch_stride_C(args.batch_stride_C), - batch_stride_D(args.batch_stride_D), + batch_stride_C1(args.batch_stride_C1), + batch_stride_C2(args.batch_stride_C2), batch_stride_Vector(args.batch_stride_Vector), - batch_stride_Tensor(args.batch_stride_Tensor), - - semaphore(static_cast(workspace)) { - + batch_stride_Tensor(args.batch_stride_Tensor) + { CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::Params() - problem_size: " << problem_size); CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); @@ -319,31 +310,23 @@ public: CUTLASS_TRACE_HOST(" ldt: " << args.ldt); } + /// Lightweight update given a subset of arguments. Problem geometry is assumed + /// to remain the same. CUTLASS_HOST_DEVICE - void update( - Arguments const &args, - void *workspace = nullptr) { - + void update(Arguments const &args) + { ptr_A = const_cast(args.ptr_A); ptr_B = const_cast(args.ptr_B); - ptr_C = const_cast(args.ptr_C); + ptr_C1 = const_cast(args.ptr_C1); + ptr_C2 = const_cast(args.ptr_C2); ptr_D = args.ptr_D; ptr_Vector = args.ptr_Vector; ldr = args.ldr; ptr_Tensor = args.ptr_Tensor; - batch_stride_A = args.batch_stride_A; - batch_stride_B = args.batch_stride_B; - batch_stride_C = args.batch_stride_C; - batch_stride_D = args.batch_stride_D; - batch_stride_Vector = args.batch_stride_Vector; - batch_stride_Tensor = args.batch_stride_Tensor; - output_op = args.epilogue; - semaphore = static_cast(workspace); - CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()"); CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); @@ -351,6 +334,7 @@ public: } }; + /// Shared memory storage structure union SharedStorage { typename Mma::SharedStorage main_loop; @@ -360,12 +344,9 @@ public: public: // - // Methods + // Host dispatch API // - CUTLASS_DEVICE - GemmWithFusedEpilogue() { } - /// Determines whether kernel satisfies alignment static Status can_implement( cutlass::gemm::GemmCoord const & problem_size) { @@ -431,10 +412,20 @@ public: return can_implement(args.problem_size); } - static size_t get_extra_workspace_size(Arguments const &args, - cutlass::gemm::GemmCoord const &grid_tiled_shape) { +public: - return 0; + // + // Device-only API + // + + // Factory invocation + CUTLASS_DEVICE + static void invoke( + Params const ¶ms, + SharedStorage &shared_storage) + { + GemmWithFusedEpilogue op; + op(params, shared_storage); } #define SPLIT_K_ENABLED 1 @@ -563,7 +554,8 @@ public: int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_C1 = static_cast(params.ptr_C1); + ElementC *ptr_C2 = static_cast(params.ptr_C2); ElementC *ptr_D = static_cast(params.ptr_D); typename Epilogue::ElementTensor *ptr_Tensor = static_cast(params.ptr_Tensor); @@ -581,7 +573,720 @@ public: if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() == 1) { - // Tile iterator loading from source tensor. + // Tile iterators loading from source tensors. 
+ typename Epilogue::OutputTileIterator iterator_C1( + params.params_C1, + ptr_C1, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + typename Epilogue::OutputTileIterator iterator_C2( + params.params_C2, + ptr_C2, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Additional tensor to load from + typename Epilogue::TensorTileIterator tensor_iterator( + params.params_Tensor, + // Only the final block outputs Tensor + ptr_Tensor, + params.problem_size.mn(), + thread_idx, + threadblock_offset); + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Move to appropriate location for this output tile + if (ptr_Vector) { + ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, + ptr_Vector, + iterator_D, + accumulators, + iterator_C1, + iterator_C2, + tensor_iterator, + params.problem_size.mn(), + threadblock_offset); + + return; + } + + // + // Slower path when split-K or batching is needed + // + + + #if SPLIT_K_ENABLED + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == GemmUniversalMode::kGemm) { + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + } + else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1; + if (ptr_C2) { + ptr_C2 += threadblock_tile_offset.k() * params.batch_stride_C2; + } + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + if (ptr_Tensor) { + ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor; + } + if (ptr_Vector) { + ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector; + } + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_C1 = static_cast(params.ptr_C1)[threadblock_tile_offset.k()]; + if (ptr_C2) { + ptr_C2 = static_cast(params.ptr_C2)[threadblock_tile_offset.k()]; + } + ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; + if (ptr_Tensor) { + ptr_Tensor = static_cast(params.ptr_Tensor)[threadblock_tile_offset.k()]; + } + if (ptr_Vector) { + ptr_Vector = static_cast(params.ptr_Vector)[threadblock_tile_offset.k()]; + } + } + #endif + + // Tile iterators loading from source tensors. + typename Epilogue::OutputTileIterator iterator_C1( + params.params_C1, + ptr_C1, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + typename Epilogue::OutputTileIterator iterator_C2( + params.params_C2, + ptr_C2, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. 
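A minimal sketch of the per-batch pointer resolution performed just above, assuming batch_idx stands for threadblock_tile_offset.k(); the element type follows this kernel's output iterator, and the variable names are illustrative only:

// kArray mode: each ptr_* parameter is a device array of per-batch pointers.
ElementC * const *ptr_C1_array = static_cast<ElementC * const *>(params.ptr_C1);
ElementC *ptr_C1_for_batch = ptr_C1_array[batch_idx];

// kBatched mode: a single base pointer advanced by the batch stride (in elements).
ElementC *ptr_C1_batched = static_cast<ElementC *>(params.ptr_C1) + batch_idx * params.batch_stride_C1;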
+ typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Additional tensor to load from + typename Epilogue::TensorTileIterator tensor_iterator( + params.params_Tensor, + // Only the final block outputs Tensor + ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && + (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) + ? nullptr + : ptr_Tensor, + params.problem_size.mn(), + thread_idx, + threadblock_offset); + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + #if SPLIT_K_ENABLED + // Wait on the semaphore - this latency may have been covered by iterator construction + if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C1 = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + + } + #endif + + // Move to appropriate location for this output tile + if (ptr_Vector) { + ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, + // Only the final block uses Vector + ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && + (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) + ? nullptr + : ptr_Vector, + iterator_D, + accumulators, + iterator_C1, + iterator_C2, + tensor_iterator, + params.problem_size.mn(), + threadblock_offset); + + // + // Release the semaphore + // + + #if SPLIT_K_ENABLED + if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + #endif + } +}; + +// GemmWithFusedEpilogue with one source +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function +> +struct GemmWithFusedEpilogue { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, + 128 / sizeof_bits::value + ); + + // + // Structures + // + + /// Argument structure + struct Arguments : UniversalArgumentsBase + { + // + // Data members + // + + typename EpilogueOutputOp::Params epilogue; + + void const * ptr_A; + void const * ptr_B; + void const * ptr_C; + void * ptr_D; + + void * ptr_Vector; + void * ptr_Tensor; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_Vector; + int64_t batch_stride_Tensor; + + typename LayoutA::Stride::Index lda; + typename LayoutB::Stride::Index ldb; + typename LayoutC::Stride::Index ldc; + typename LayoutC::Stride::Index ldd; + typename LayoutC::Stride::Index ldr; + typename LayoutC::Stride::Index ldt; + + // + // Methods + // + + Arguments(): + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr) + {} + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + void * ptr_Vector, + void * ptr_Tensor, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + int64_t batch_stride_Vector, + int64_t batch_stride_Tensor, + typename LayoutA::Stride::Index lda, + typename LayoutB::Stride::Index ldb, + typename LayoutC::Stride::Index ldc, + typename LayoutC::Stride::Index ldd, + typename LayoutC::Stride::Index ldr, + typename LayoutC::Stride::Index ldt) + : + UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + ptr_Vector(ptr_Vector), + ptr_Tensor(ptr_Tensor), + batch_stride_A(batch_stride_A), + batch_stride_B(batch_stride_B), + batch_stride_C(batch_stride_C), + batch_stride_Vector(batch_stride_Vector), + batch_stride_Tensor(batch_stride_Tensor), + lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), 
ldr(ldr), ldt(ldt) + { + CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size); + CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + CUTLASS_TRACE_HOST(" ldr: " << this->ldr); + CUTLASS_TRACE_HOST(" ldt: " << this->ldt); + } + + /// Returns arguments for the transposed problem + Arguments transposed_problem() const { + Arguments args(*this); + + std::swap(args.problem_size.m(), args.problem_size.n()); + std::swap(args.ptr_A, args.ptr_B); + std::swap(args.lda, args.ldb); + std::swap(args.batch_stride_A, args.batch_stride_B); + + return args; + } + }; + + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params : UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC> + { + using ParamsBase = UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC>; + + // + // Data members + // + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::TensorTileIterator::Params params_Tensor; + + typename EpilogueOutputOp::Params output_op; + + void * ptr_A; + void * ptr_B; + void * ptr_C; + void * ptr_D; + + void * ptr_Vector; + typename LayoutC::Stride::Index ldr; + + void * ptr_Tensor; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_Vector; + int64_t batch_stride_Tensor; + + // + // Host dispatch API + // + + /// Default constructor + Params() = default; + + /// Constructor + Params( + Arguments const &args, /// GEMM application arguments + int device_sms, /// Number of SMs on the device + int sm_occupancy) /// Kernel SM occupancy (in thread blocks) + : + ParamsBase(args, device_sms, sm_occupancy), + params_A(args.lda), + params_B(args.ldb), + params_C(args.ldc), + params_D(args.ldd), + params_Tensor(args.ldt), + output_op(args.epilogue), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C(const_cast(args.ptr_C)), + ptr_D(args.ptr_D), + ptr_Vector(args.ptr_Vector), + ldr(args.ldr), + ptr_Tensor(args.ptr_Tensor), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + batch_stride_C(args.batch_stride_C), + batch_stride_Vector(args.batch_stride_Vector), + batch_stride_Tensor(args.batch_stride_Tensor) + { + CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::Params() - problem_size: " << problem_size); + CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + CUTLASS_TRACE_HOST(" ldr: " << this->ldr); + CUTLASS_TRACE_HOST(" ldt: " << args.ldt); + } + + /// Lightweight update given a subset of arguments. Problem geometry is assumed + /// to remain the same. 
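The Params constructor above now takes the device's SM count and the kernel's SM occupancy instead of a precomputed grid shape and a raw workspace pointer, and workspace sizing/initialization moves onto the UniversalParamsBase helpers. A minimal host-side dispatch sketch under that flow (illustrative only: the GemmKernel alias, the direct cutlass::Kernel<> launch, and the error handling are assumptions; production code would normally go through the device-level GemmUniversal adaptors):

#include <cuda_runtime.h>
#include "cutlass/device_kernel.h"   // cutlass::Kernel<> entry point (assumed available)

template <typename GemmKernel>
cudaError_t run_gemm_with_fused_epilogue(
    typename GemmKernel::Arguments const &args,
    int device_id,
    cudaStream_t stream)
{
  // Query the SM count and the kernel's per-SM occupancy.
  int device_sms = 0;
  cudaDeviceGetAttribute(&device_sms, cudaDevAttrMultiProcessorCount, device_id);

  int smem_bytes = int(sizeof(typename GemmKernel::SharedStorage));
  int sm_occupancy = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &sm_occupancy, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_bytes);

  // Precompute kernel parameters, then size and zero-initialize the workspace
  // (split-K semaphore locks or split-K-parallel partials, depending on mode).
  typename GemmKernel::Params params(args, device_sms, sm_occupancy);

  void *workspace = nullptr;
  if (params.get_workspace_size()) {
    if (cudaMalloc(&workspace, params.get_workspace_size()) != cudaSuccess) {
      return cudaErrorMemoryAllocation;
    }
  }
  if (params.init_workspace(workspace, stream) != cutlass::Status::kSuccess) {
    return cudaErrorUnknown;
  }

  // Launch one thread block per output tile (times split-K slices). Tiles needing
  // more than 48 KB of shared memory would additionally require cudaFuncSetAttribute.
  cutlass::Kernel<GemmKernel><<<params.get_grid_dims(), GemmKernel::kThreadCount,
                                smem_bytes, stream>>>(params);

  cudaError_t result = cudaGetLastError();
  cudaStreamSynchronize(stream);
  if (workspace) {
    cudaFree(workspace);
  }
  return result;
}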
+ CUTLASS_HOST_DEVICE + void update(Arguments const &args) + { + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_C = const_cast(args.ptr_C); + ptr_D = args.ptr_D; + + ptr_Vector = args.ptr_Vector; + ldr = args.ldr; + ptr_Tensor = args.ptr_Tensor; + + output_op = args.epilogue; + + CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()"); + CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + CUTLASS_TRACE_HOST(" ldr: " << this->ldr); + } + }; + + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + +public: + + // + // Host dispatch API + // + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) { + + CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + +public: + + // + // Device-only API + // + + // Factory invocation + CUTLASS_DEVICE + static void invoke( + Params const ¶ms, + SharedStorage &shared_storage) + { + GemmWithFusedEpilogue op; + op(params, shared_storage); + } + + #define SPLIT_K_ENABLED 1 + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= 
threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A = static_cast(params.ptr_A); + ElementB *ptr_B = static_cast(params.ptr_B); + + + #if SPLIT_K_ENABLED + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } + #endif + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_D = static_cast(params.ptr_D); + typename Epilogue::ElementTensor *ptr_Tensor = static_cast(params.ptr_Tensor); + + // Define the reduction output pointer and move to the appropriate place + typename Epilogue::ElementVector *ptr_Vector = + static_cast(params.ptr_Vector); + + // + // Fetch pointers based on mode. + // + + // + // Special path when split-K not enabled. + // + + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() == 1) { + + // Tile iterators loading from source tensors. 
typename Epilogue::OutputTileIterator iterator_C( params.params_C, ptr_C, @@ -679,7 +1384,7 @@ public: } #endif - // Tile iterator loading from source tensor. + // Tile iterators loading from source tensors. typename Epilogue::OutputTileIterator iterator_C( params.params_C, ptr_C, diff --git a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h deleted file mode 100644 index 4d62b914..00000000 --- a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue_v2.h +++ /dev/null @@ -1,854 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Gemm kernel with fused reduction operation. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/complex.h" -#include "cutlass/semaphore.h" - -#include "cutlass/trace.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function -> -struct GemmWithFusedEpilogueV2 { -public: - - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; - - static ComplexTransform const kTransformA = Mma::kTransformA; - static ComplexTransform const kTransformB = Mma::kTransformB; - using Operator = typename Mma::Operator; - - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - /// Split-K preserves splits that are 128b aligned - static int const kSplitKAlignment = const_max( - 128 / sizeof_bits::value, - 128 / sizeof_bits::value - ); - - // - // Structures - // - - /// Argument structure - struct Arguments { - - // - // Data members - // - - GemmUniversalMode mode; - GemmCoord problem_size; - int batch_count; - - typename EpilogueOutputOp::Params epilogue; - - void const * ptr_A; - void const * ptr_B; - void const * ptr_C1; - void const * ptr_C2; - void * ptr_D; - - void * ptr_Vector; - void * ptr_Tensor; - - int64_t batch_stride_A; - int64_t batch_stride_B; - int64_t batch_stride_C1; - int64_t batch_stride_C2; - int64_t batch_stride_D; - int64_t batch_stride_Vector; - int64_t batch_stride_Tensor; - - typename LayoutA::Stride::Index lda; - typename LayoutB::Stride::Index ldb; - typename LayoutC::Stride::Index ldc1; - typename LayoutC::Stride::Index ldc2; - typename LayoutC::Stride::Index ldd; - typename LayoutC::Stride::Index ldr; - typename LayoutC::Stride::Index ldt; - - // - // Methods - // - - Arguments(): - mode(GemmUniversalMode::kGemm), - batch_count(1), - ptr_A(nullptr), ptr_B(nullptr), ptr_C1(nullptr), ptr_C2(nullptr), ptr_D(nullptr) { } - - /// constructs an arguments structure - Arguments( - GemmUniversalMode mode, - GemmCoord problem_size, - int batch_count, - typename EpilogueOutputOp::Params epilogue, - void const * ptr_A, - void const * ptr_B, - void const * ptr_C1, - void const * ptr_C2, - void * ptr_D, - void * ptr_Vector, - void * ptr_Tensor, - int64_t batch_stride_A, - int64_t batch_stride_B, - int64_t batch_stride_C1, - int64_t batch_stride_C2, - int64_t batch_stride_D, - int64_t batch_stride_Vector, - int64_t batch_stride_Tensor, - typename LayoutA::Stride::Index lda, - typename LayoutB::Stride::Index ldb, - typename LayoutC::Stride::Index ldc1, - typename LayoutC::Stride::Index ldc2, - typename LayoutC::Stride::Index ldd, - typename LayoutC::Stride::Index ldr, - typename LayoutC::Stride::Index ldt - ): - mode(mode), - problem_size(problem_size), - batch_count(batch_count), - epilogue(epilogue), - ptr_A(ptr_A), 
ptr_B(ptr_B), ptr_C1(ptr_C1), ptr_C2(ptr_C2), ptr_D(ptr_D), - ptr_Vector(ptr_Vector), - ptr_Tensor(ptr_Tensor), - batch_stride_A(batch_stride_A), - batch_stride_B(batch_stride_B), - batch_stride_C1(batch_stride_C1), - batch_stride_C2(batch_stride_C2), - batch_stride_D(batch_stride_D), - batch_stride_Vector(batch_stride_Vector), - batch_stride_Tensor(batch_stride_Tensor), - lda(lda), ldb(ldb), ldc1(ldc1), ldc2(ldc2), ldd(ldd), ldr(ldr), ldt(ldt) - { - CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size); - CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); - CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); - CUTLASS_TRACE_HOST(" ldr: " << this->ldr); - CUTLASS_TRACE_HOST(" ldt: " << this->ldt); - } - - /// constructs an arguments structure - Arguments( - GemmUniversalMode mode, - GemmCoord problem_size, - int batch_count, - typename EpilogueOutputOp::Params epilogue, - void const * ptr_A, - void const * ptr_B, - void const * ptr_C, - void * ptr_D, - void * ptr_Vector, - void * ptr_Tensor, - int64_t batch_stride_A, - int64_t batch_stride_B, - int64_t batch_stride_C, - int64_t batch_stride_D, - int64_t batch_stride_Vector, - int64_t batch_stride_Tensor, - typename LayoutA::Stride::Index lda, - typename LayoutB::Stride::Index ldb, - typename LayoutC::Stride::Index ldc, - typename LayoutC::Stride::Index ldd, - typename LayoutC::Stride::Index ldr, - typename LayoutC::Stride::Index ldt - ): Arguments( - mode, problem_size, batch_count, epilogue, - ptr_A, ptr_B, ptr_C, nullptr, ptr_D, ptr_Vector, ptr_Tensor, - batch_stride_A, batch_stride_B, batch_stride_C, 0, batch_stride_D, - batch_stride_Vector, batch_stride_Tensor, - lda, ldb, ldc, 0, ldd, ldr, ldt) {} - - /// Returns arguments for the transposed problem - Arguments transposed_problem() const { - Arguments args(*this); - - std::swap(args.problem_size.m(), args.problem_size.n()); - std::swap(args.ptr_A, args.ptr_B); - std::swap(args.lda, args.ldb); - std::swap(args.batch_stride_A, args.batch_stride_B); - - return args; - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params { - - cutlass::gemm::GemmCoord problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - - typename Mma::IteratorA::Params params_A; - typename Mma::IteratorB::Params params_B; - typename Epilogue::OutputTileIterator::Params params_C1; - typename Epilogue::OutputTileIterator::Params params_C2; - typename Epilogue::OutputTileIterator::Params params_D; - typename Epilogue::TensorTileIterator::Params params_Tensor; - - typename EpilogueOutputOp::Params output_op; - - - GemmUniversalMode mode; - int batch_count; - int gemm_k_size; - - void * ptr_A; - void * ptr_B; - void * ptr_C1; - void * ptr_C2; - void * ptr_D; - - void * ptr_Vector; - typename LayoutC::Stride::Index ldr; - - void * ptr_Tensor; - - int64_t batch_stride_A; - int64_t batch_stride_B; - int64_t batch_stride_C1; - int64_t batch_stride_C2; - int64_t batch_stride_D; - int64_t batch_stride_Vector; - int64_t batch_stride_Tensor; - - int *semaphore; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params(): - swizzle_log_tile(0), - params_A(0), - params_B(0), - params_C1(0), - params_C2(0), - params_D(0), - batch_count(0), - gemm_k_size(0), - mode(cutlass::gemm::GemmUniversalMode::kGemm), - ptr_A(nullptr), - ptr_B(nullptr), - ptr_C1(nullptr), - ptr_C2(nullptr), - ptr_D(nullptr), - ptr_Vector(nullptr), - ldr(0), - 
ptr_Tensor(nullptr), - batch_stride_A(0), - batch_stride_B(0), - batch_stride_C1(0), - batch_stride_C2(0), - batch_stride_D(0), - batch_stride_Vector(0), - batch_stride_Tensor(0), - semaphore(nullptr) { } - - CUTLASS_HOST_DEVICE - Params( - Arguments const &args, - cutlass::gemm::GemmCoord const & grid_tiled_shape, - int gemm_k_size, - void *workspace = nullptr - ): - problem_size(args.problem_size), - grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), - params_A(args.lda), - params_B(args.ldb), - params_C1(args.ldc1), - params_C2(args.ldc2), - params_D(args.ldd), - params_Tensor(args.ldt), - output_op(args.epilogue), - mode(args.mode), - batch_count(args.batch_count), - gemm_k_size(gemm_k_size), - ptr_A(const_cast(args.ptr_A)), - ptr_B(const_cast(args.ptr_B)), - ptr_C1(const_cast(args.ptr_C1)), - ptr_C2(const_cast(args.ptr_C2)), - ptr_D(args.ptr_D), - ptr_Vector(args.ptr_Vector), - ldr(args.ldr), - ptr_Tensor(args.ptr_Tensor), - - batch_stride_A(args.batch_stride_A), - batch_stride_B(args.batch_stride_B), - batch_stride_C1(args.batch_stride_C1), - batch_stride_C2(args.batch_stride_C2), - batch_stride_D(args.batch_stride_D), - batch_stride_Vector(args.batch_stride_Vector), - batch_stride_Tensor(args.batch_stride_Tensor), - - semaphore(static_cast(workspace)) { - - CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::Params() - problem_size: " << problem_size); - CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); - CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); - CUTLASS_TRACE_HOST(" ldr: " << this->ldr); - CUTLASS_TRACE_HOST(" ldt: " << args.ldt); - } - - CUTLASS_HOST_DEVICE - void update( - Arguments const &args, - void *workspace = nullptr) { - - ptr_A = const_cast(args.ptr_A); - ptr_B = const_cast(args.ptr_B); - ptr_C1 = const_cast(args.ptr_C1); - ptr_C2 = const_cast(args.ptr_C2); - ptr_D = args.ptr_D; - - ptr_Vector = args.ptr_Vector; - ldr = args.ldr; - ptr_Tensor = args.ptr_Tensor; - - batch_stride_A = args.batch_stride_A; - batch_stride_B = args.batch_stride_B; - batch_stride_C1 = args.batch_stride_C1; - batch_stride_C2 = args.batch_stride_C2; - batch_stride_D = args.batch_stride_D; - batch_stride_Vector = args.batch_stride_Vector; - batch_stride_Tensor = args.batch_stride_Tensor; - - output_op = args.epilogue; - - semaphore = static_cast(workspace); - - CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()"); - CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); - CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); - CUTLASS_TRACE_HOST(" ldr: " << this->ldr); - } - }; - - /// Shared memory storage structure - union SharedStorage { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - -public: - - // - // Methods - // - - CUTLASS_DEVICE - GemmWithFusedEpilogueV2() { } - - /// Determines whether kernel satisfies alignment - static Status can_implement( - cutlass::gemm::GemmCoord const & problem_size) { - - CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::can_implement()"); - - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - bool isAMisaligned = false; - bool isBMisaligned = false; - bool isCMisaligned = false; - - if (platform::is_same::value) { - isAMisaligned = problem_size.k() % kAlignmentA; - } else if (platform::is_same::value) { - 
isAMisaligned = problem_size.m() % kAlignmentA; - } else if (platform::is_same>::value - || platform::is_same>::value) { - isAMisaligned = problem_size.k() % kAlignmentA; - } - - if (platform::is_same::value) { - isBMisaligned = problem_size.n() % kAlignmentB; - } else if (platform::is_same::value) { - isBMisaligned = problem_size.k() % kAlignmentB; - } else if (platform::is_same>::value - || platform::is_same>::value) { - isBMisaligned = problem_size.k() % kAlignmentB; - } - - if (platform::is_same::value) { - isCMisaligned = problem_size.n() % kAlignmentC; - } else if (platform::is_same::value) { - isCMisaligned = problem_size.m() % kAlignmentC; - } else if (platform::is_same>::value - || platform::is_same>::value) { - isCMisaligned = problem_size.n() % kAlignmentC; - } - - if (isAMisaligned) { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); - return Status::kErrorMisalignedOperand; - } - - if (isBMisaligned) { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); - return Status::kErrorMisalignedOperand; - } - - if (isCMisaligned) { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); - return Status::kErrorMisalignedOperand; - } - - CUTLASS_TRACE_HOST(" returning kSuccess"); - - return Status::kSuccess; - } - - static Status can_implement(Arguments const &args) { - return can_implement(args.problem_size); - } - - static size_t get_extra_workspace_size(Arguments const &args, - cutlass::gemm::GemmCoord const &grid_tiled_shape) { - - return 0; - } - - #define SPLIT_K_ENABLED 1 - - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || - params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { - - return; - } - - int offset_k = 0; - int problem_size_k = params.problem_size.k(); - - ElementA *ptr_A = static_cast(params.ptr_A); - ElementB *ptr_B = static_cast(params.ptr_B); - - - #if SPLIT_K_ENABLED - // - // Fetch pointers based on mode. 
- // - if (params.mode == GemmUniversalMode::kGemm || - params.mode == GemmUniversalMode::kGemmSplitKParallel) { - - if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { - - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; - } - - offset_k = threadblock_tile_offset.k() * params.gemm_k_size; - } - else if (params.mode == GemmUniversalMode::kBatched) { - ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; - ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; - } - else if (params.mode == GemmUniversalMode::kArray) { - ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; - ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; - } - #endif - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - offset_k, - }; - - cutlass::MatrixCoord tb_offset_B{ - offset_k, - threadblock_tile_offset.n() * Mma::Shape::kN - }; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.params_A, - ptr_A, - {params.problem_size.m(), problem_size_k}, - thread_idx, - tb_offset_A); - - typename Mma::IteratorB iterator_B( - params.params_B, - ptr_B, - {problem_size_k, params.problem_size.n()}, - thread_idx, - tb_offset_B); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute threadblock-scoped matrix multiply-add - mma( - gemm_k_iterations, - accumulators, - iterator_A, - iterator_B, - accumulators); - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - // - // Masked tile iterators constructed from members - // - - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - //assume identity swizzle - MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.n() * Mma::Shape::kN - ); - - int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - - ElementC *ptr_C1 = static_cast(params.ptr_C1); - ElementC *ptr_C2 = static_cast(params.ptr_C2); - ElementC *ptr_D = static_cast(params.ptr_D); - typename Epilogue::ElementTensor *ptr_Tensor = static_cast(params.ptr_Tensor); - - // Define the reduction output pointer and move to the appropriate place - typename Epilogue::ElementVector *ptr_Vector = - static_cast(params.ptr_Vector); - - // - // Fetch pointers based on mode. - // - - // - // Special path when split-K not enabled. - // - - if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() == 1) { - - // Tile iterator loading from source tensor. 
- typename Epilogue::OutputTileIterator iterator_C1( - params.params_C1, - ptr_C1, - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); - typename Epilogue::OutputTileIterator iterator_C2( - params.params_C2, - ptr_C2, - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D( - params.params_D, - ptr_D, - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); - - // Additional tensor to load from - typename Epilogue::TensorTileIterator tensor_iterator( - params.params_Tensor, - // Only the final block outputs Tensor - ptr_Tensor, - params.problem_size.mn(), - thread_idx, - threadblock_offset); - - // Construct the epilogue - Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); - - // Move to appropriate location for this output tile - if (ptr_Vector) { - ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; - } - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, - ptr_Vector, - iterator_D, - accumulators, - iterator_C1, - iterator_C2, - tensor_iterator, - params.problem_size.mn(), - threadblock_offset); - - return; - } - - // - // Slower path when split-K or batching is needed - // - - - #if SPLIT_K_ENABLED - // Construct the semaphore. - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - if (params.mode == GemmUniversalMode::kGemm) { - - // If performing a reduction via split-K, fetch the initial synchronization - if (params.grid_tiled_shape.k() > 1) { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } - } - else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { - ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; - } - else if (params.mode == GemmUniversalMode::kBatched) { - ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1; - if (ptr_C2) { - ptr_C2 += threadblock_tile_offset.k() * params.batch_stride_C2; - } - ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; - if (ptr_Tensor) { - ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor; - } - if (ptr_Vector) { - ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector; - } - } - else if (params.mode == GemmUniversalMode::kArray) { - ptr_C1 = static_cast(params.ptr_C1)[threadblock_tile_offset.k()]; - if (ptr_C2) { - ptr_C2 = static_cast(params.ptr_C2)[threadblock_tile_offset.k()]; - } - ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; - if (ptr_Tensor) { - ptr_Tensor = static_cast(params.ptr_Tensor)[threadblock_tile_offset.k()]; - } - if (ptr_Vector) { - ptr_Vector = static_cast(params.ptr_Vector)[threadblock_tile_offset.k()]; - } - } - #endif - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C1( - params.params_C1, - ptr_C1, - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); - typename Epilogue::OutputTileIterator iterator_C2( - params.params_C2, - ptr_C2, - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); - - // Tile iterator writing to destination tensor. 
- typename Epilogue::OutputTileIterator iterator_D( - params.params_D, - ptr_D, - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); - - // Additional tensor to load from - typename Epilogue::TensorTileIterator tensor_iterator( - params.params_Tensor, - // Only the final block outputs Tensor - ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && - (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) - ? nullptr - : ptr_Tensor, - params.problem_size.mn(), - thread_idx, - threadblock_offset); - - // Construct the epilogue - Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); - - #if SPLIT_K_ENABLED - // Wait on the semaphore - this latency may have been covered by iterator construction - if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_offset.k()) { - iterator_C1 = iterator_D; - } - - semaphore.wait(threadblock_tile_offset.k()); - - } - #endif - - // Move to appropriate location for this output tile - if (ptr_Vector) { - ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; - } - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, - // Only the final block uses Vector - ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && - (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) - ? nullptr - : ptr_Vector, - iterator_D, - accumulators, - iterator_C1, - iterator_C2, - tensor_iterator, - params.problem_size.mn(), - threadblock_offset); - - // - // Release the semaphore - // - - #if SPLIT_K_ENABLED - if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { - - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { - - // The final threadblock resets the semaphore for subsequent grids. 
- lock = 0; - } - else { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_offset.k() + 1; - } - - semaphore.release(lock); - } - #endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemm_with_k_reduction.h b/include/cutlass/gemm/kernel/gemm_with_k_reduction.h index 93e5ed43..6d69735c 100644 --- a/include/cutlass/gemm/kernel/gemm_with_k_reduction.h +++ b/include/cutlass/gemm/kernel/gemm_with_k_reduction.h @@ -42,6 +42,7 @@ #include "cutlass/complex.h" #include "cutlass/semaphore.h" #include "cutlass/layout/pitch_linear.h" +#include "cutlass/gemm/kernel/params_universal_base.h" #include "cutlass/trace.h" @@ -105,16 +106,12 @@ public: // /// Argument structure - struct Arguments { - + struct Arguments : UniversalArgumentsBase + { // // Data members // - GemmUniversalMode mode; - GemmCoord problem_size; - int batch_count; - typename EpilogueOutputOp::Params epilogue; void const * ptr_A; @@ -126,7 +123,6 @@ public: int64_t batch_stride_A; int64_t batch_stride_B; int64_t batch_stride_C; - int64_t batch_stride_D; int64_t batch_stride_gemm_k_reduction; typename LayoutA::Stride::Index lda; @@ -138,11 +134,14 @@ public: // // Methods // - - Arguments(): - mode(GemmUniversalMode::kGemm), - batch_count(1), - ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), ptr_gemm_k_reduction(nullptr) { } + + Arguments() : + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_gemm_k_reduction(nullptr) + {} /// constructs an arguments structure Arguments( @@ -164,23 +163,21 @@ public: typename LayoutB::Stride::Index ldb, typename LayoutC::Stride::Index ldc, typename LayoutC::Stride::Index ldd, - typename LayoutGemmKReduction::Stride::Index ld_gemm_k_reduction - ): - mode(mode), - problem_size(problem_size), - batch_count(batch_count), - epilogue(epilogue), - ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), ptr_gemm_k_reduction(ptr_gemm_k_reduction), - batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), batch_stride_gemm_k_reduction(batch_stride_gemm_k_reduction), - lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ld_gemm_k_reduction(ld_gemm_k_reduction) { - + typename LayoutGemmKReduction::Stride::Index ld_gemm_k_reduction) + : + UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), ptr_gemm_k_reduction(ptr_gemm_k_reduction), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_gemm_k_reduction(batch_stride_gemm_k_reduction), + lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ld_gemm_k_reduction(ld_gemm_k_reduction) + { CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); - } + } /// Returns arguments for the transposed problem Arguments transposed_problem() const { Arguments args(*this); - + std::swap(args.problem_size.m(), args.problem_size.n()); std::swap(args.ptr_A, args.ptr_B); std::swap(args.lda, args.ldb); @@ -190,16 +187,29 @@ public: } }; + // // Structure for precomputing values in host memory and passing to kernels // /// Parameters structure - struct Params { + struct Params : UniversalParamsBase< + ThreadblockSwizzle, 
+ ThreadblockShape, + ElementA, + ElementB, + ElementC> + { + using ParamsBase = UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC>; - cutlass::gemm::GemmCoord problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; + // + // Data members + // typename Mma::IteratorA::Params params_A; typename Mma::IteratorB::Params params_B; @@ -208,10 +218,6 @@ public: typename EpilogueOutputOp::Params output_op; - GemmUniversalMode mode; - int batch_count; - int gemm_k_size; - void * ptr_A; void * ptr_B; void * ptr_C; @@ -221,97 +227,86 @@ public: int64_t batch_stride_A; int64_t batch_stride_B; int64_t batch_stride_C; - int64_t batch_stride_D; int64_t batch_stride_gemm_k_reduction; - int *semaphore; - // - // Methods + // Host dispatch API // - CUTLASS_HOST_DEVICE - Params(): - swizzle_log_tile(0), - params_A(0), - params_B(0), - params_C(0), - params_D(0), - batch_count(0), - gemm_k_size(0), - mode(cutlass::gemm::GemmUniversalMode::kGemm), - ptr_A(nullptr), - ptr_B(nullptr), - ptr_C(nullptr), - ptr_D(nullptr), - ptr_gemm_k_reduction(nullptr), - batch_stride_A(0), - batch_stride_B(0), - batch_stride_C(0), - batch_stride_D(0), - batch_stride_gemm_k_reduction(0), - semaphore(nullptr) { } + /// Default constructor + Params() = default; - CUTLASS_HOST_DEVICE + /// Constructor Params( - Arguments const &args, - cutlass::gemm::GemmCoord const & grid_tiled_shape, - int gemm_k_size, - void *workspace = nullptr - ): - problem_size(args.problem_size), - grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + Arguments const &args, /// GEMM application arguments + int device_sms, /// Number of SMs on the device + int sm_occupancy) /// Kernel SM occupancy (in thread blocks) + : + ParamsBase(args, device_sms, sm_occupancy), params_A(args.lda), params_B(args.ldb), params_C(args.ldc), params_D(args.ldd), output_op(args.epilogue), - mode(args.mode), - batch_count(args.batch_count), - gemm_k_size(gemm_k_size), ptr_A(const_cast(args.ptr_A)), ptr_B(const_cast(args.ptr_B)), ptr_C(const_cast(args.ptr_C)), batch_stride_A(args.batch_stride_A), batch_stride_B(args.batch_stride_B), batch_stride_C(args.batch_stride_C), - batch_stride_D(args.batch_stride_D), batch_stride_gemm_k_reduction(args.batch_stride_gemm_k_reduction), - semaphore(static_cast(workspace)) { + ptr_D(args.ptr_D), + ptr_gemm_k_reduction(args.ptr_gemm_k_reduction) + {} - CUTLASS_TRACE_HOST("GemmUniversal::Params::Params() - problem_size: " << problem_size); + /// Assign and initialize the specified workspace buffer. Assumes + /// the memory allocated to workspace is at least as large as get_workspace_size(). 
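For the split-K-parallel case handled by init_workspace() and get_workspace_size() just below, both sets of partials share a single workspace allocation; a sketch of the implied layout, with k = grid_tiled_shape.k():

// Workspace layout in kGemmSplitKParallel mode (element type ElementC):
//
//   byte offset 0                                    : D partials,           k * batch_stride_D elements
//   byte offset sizeof(ElementC) * batch_stride_D * k : k-reduction partials, k * batch_stride_gemm_k_reduction elements
//
// so ptr_D points at the start of the workspace and ptr_gemm_k_reduction at the
// second region; other modes keep ptr_D / ptr_gemm_k_reduction as supplied in Arguments.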
+ Status init_workspace( + void *workspace, + cudaStream_t stream = nullptr) + { + CUTLASS_TRACE_HOST("GemmUniversal::Params::Params() - problem_size: " << this->problem_size); - if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + if (this->mode == GemmUniversalMode::kGemmSplitKParallel) { ptr_D = workspace; ptr_gemm_k_reduction = static_cast(workspace) - + sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); - } else { - ptr_D = args.ptr_D; - ptr_gemm_k_reduction = args.ptr_gemm_k_reduction; + + sizeof(ElementC) * size_t(this->batch_stride_D) * size_t(this->grid_tiled_shape.k()); + + return Status::kSuccess; } + + return ParamsBase::init_workspace(workspace, stream); } - CUTLASS_HOST_DEVICE - void update( - Arguments const &args, - void *workspace = nullptr) { + /// Returns the workspace size (in bytes) needed for this problem geometry + size_t get_workspace_size() const + { + size_t workspace_bytes = ParamsBase::get_workspace_size(); + if (this->mode == GemmUniversalMode::kGemmSplitKParallel) + { + // Split-K parallel always requires a temporary workspace + workspace_bytes += + sizeof(ElementC) * + size_t(batch_stride_gemm_k_reduction) * + size_t(this->grid_tiled_shape.k()); + } + + return workspace_bytes; + } + + /// Lightweight update given a subset of arguments. Problem geometry is assumed + /// to remain the same. + void update(Arguments const &args) + { ptr_A = const_cast(args.ptr_A); ptr_B = const_cast(args.ptr_B); ptr_C = const_cast(args.ptr_C); ptr_D = args.ptr_D; ptr_gemm_k_reduction = args.ptr_gemm_k_reduction; - batch_stride_A = args.batch_stride_A; - batch_stride_B = args.batch_stride_B; - batch_stride_C = args.batch_stride_C; - batch_stride_D = args.batch_stride_D; - batch_stride_gemm_k_reduction = args.batch_stride_gemm_k_reduction; - output_op = args.epilogue; - semaphore = static_cast(workspace); CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); } }; @@ -322,15 +317,13 @@ public: typename Epilogue::SharedStorage epilogue; }; + public: // - // Methods + // Host dispatch API // - CUTLASS_DEVICE - GemmWithKReduction() { } - /// Determines whether kernel satisfies alignment static Status can_implement( cutlass::gemm::GemmCoord const & problem_size) { @@ -410,26 +403,29 @@ public: return Status::kSuccess; } + static Status can_implement(Arguments const &args) { return can_implement(args.problem_size); } - static size_t get_extra_workspace_size(Arguments const &args, - cutlass::gemm::GemmCoord const &grid_tiled_shape) { - size_t workspace_bytes = 0; - if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { - - // Split-K parallel always requires a temporary workspace - workspace_bytes = - sizeof(ElementC) * - size_t(args.batch_stride_gemm_k_reduction) * - size_t(grid_tiled_shape.k()); - } +public: - return workspace_bytes; + // + // Device-only API + // + + // Factory invocation + CUTLASS_DEVICE + static void invoke( + Params const ¶ms, + SharedStorage &shared_storage) + { + GemmWithKReduction op; + op(params, shared_storage); } - + + /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage) { diff --git a/include/cutlass/gemm/kernel/grouped_problem_visitor.h b/include/cutlass/gemm/kernel/grouped_problem_visitor.h index c5321153..7e631d69 100644 --- a/include/cutlass/gemm/kernel/grouped_problem_visitor.h +++ b/include/cutlass/gemm/kernel/grouped_problem_visitor.h @@ -126,11 +126,7 @@ struct BaseGroupedProblemVisitor { /// Get the grid shape CUTLASS_HOST_DEVICE static cutlass::gemm::GemmCoord 
grid_shape(const cutlass::gemm::GemmCoord& problem) { - - return cutlass::gemm::GemmCoord( - ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), - ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), - 1); + return ProblemSizeHelper::grid_shape(problem); } /// Gets the global tile index @@ -346,7 +342,7 @@ struct GroupedProblemVisitor : public BaseGroupedProblemVisitor { static_assert(PrefetchTileCount > 0, - "GroupedProblemVisitor with GroupScheduleMode `kHost` currently requires prefetching to shared memory"); + "GroupedProblemVisitor with GroupScheduleMode `kHostPrecompute` currently requires prefetching to shared memory"); using Base = BaseGroupedProblemVisitor; using Params = typename Base::Params; diff --git a/include/cutlass/gemm/kernel/params_universal_base.h b/include/cutlass/gemm/kernel/params_universal_base.h new file mode 100644 index 00000000..1f56a12a --- /dev/null +++ b/include/cutlass/gemm/kernel/params_universal_base.h @@ -0,0 +1,245 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Base functionality for common types of universal GEMM kernel parameters +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/trace.h" +#include "cutlass/gemm/gemm.h" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Argument structure +struct UniversalArgumentsBase +{ + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + int64_t batch_stride_D; + + // + // Methods + // + + UniversalArgumentsBase() : + mode(GemmUniversalMode::kGemm), + batch_count(1), + batch_stride_D(0) + {} + + /// constructs an arguments structure + UniversalArgumentsBase( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + int64_t batch_stride_D) + : + mode(mode), + problem_size(problem_size), + batch_count(batch_count), + batch_stride_D(batch_stride_D) + { + CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); + } +}; + + +/// Parameters structure +template < + typename ThreadblockSwizzle, + typename ThreadblockShape, + typename ElementA, + typename ElementB, + typename ElementC> +struct UniversalParamsBase +{ + // + // Data members + // + + GemmCoord problem_size; + GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + int64_t batch_stride_D; + + int *semaphore; + + + // + // Host dispatch API + // + + /// Default constructor + UniversalParamsBase() = default; + + + /// Constructor + UniversalParamsBase( + UniversalArgumentsBase const &args, /// GEMM application arguments + int device_sms, /// Number of SMs on the device + int sm_occupancy) /// Kernel SM occupancy (in thread blocks) + : + problem_size(args.problem_size), + mode(args.mode), + batch_count(args.batch_count), + batch_stride_D(args.batch_stride_D), + semaphore(nullptr) + { + ThreadblockSwizzle swizzle; + + // Get GEMM volume in thread block tiles + grid_tiled_shape = swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + swizzle_log_tile = swizzle.get_log_tile(grid_tiled_shape); + + // Determine extent of K-dimension assigned to each block + gemm_k_size = args.problem_size.k(); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) + { + int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + } + + + /// Returns the workspace size (in bytes) needed for this problem geometry + size_t get_workspace_size() const + { + size_t workspace_bytes = 0; + if (mode == GemmUniversalMode::kGemmSplitKParallel) + { + // Split-K parallel always requires a temporary workspace + workspace_bytes = + sizeof(ElementC) * + size_t(batch_stride_D) * + size_t(grid_tiled_shape.k()); + } + else if (mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) + { + // Serial split-K only requires a temporary workspace if the number of partitions along the + // GEMM K dimension is greater than one. 
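A worked example of the K-slicing and serial split-K workspace sizing above, with illustrative values:

// With half-precision A and B operands, kAlignK = max(128/16, 128/16, 1) = 8.
// For problem_size.k() = 4096 and batch_count = 3 split-K slices in kGemm mode:
//   gemm_k_size          = round_up(ceil_div(4096, 3), 8) = round_up(1366, 8) = 1368
//   grid_tiled_shape.k() = ceil_div(4096, 1368)           = 3
// Serial split-K then reserves one int lock per output tile: an 8 x 16 tile grid
// needs 8 * 16 * sizeof(int) = 512 bytes, zero-filled by init_workspace(); slice k
// waits for its tile's lock to reach k, and the final slice resets it to zero.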
+ workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + return workspace_bytes; + } + + + /// Assign and initialize the specified workspace buffer. Assumes + /// the memory allocated to workspace is at least as large as get_workspace_size(). + Status init_workspace( + void *workspace, + cudaStream_t stream = nullptr) + { + semaphore = static_cast(workspace); + // Zero-initialize entire workspace + if (semaphore) + { + size_t workspace_bytes = get_workspace_size(); + + CUTLASS_TRACE_HOST(" Initialize " << workspace_bytes << " workspace bytes"); + + cudaError_t result = cudaMemsetAsync( + semaphore, + 0, + workspace_bytes, + stream); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + + /// Returns the GEMM volume in thread block tiles + GemmCoord get_tiled_shape() const + { + return grid_tiled_shape; + } + + + /// Returns the total number of thread blocks to launch + int get_grid_blocks() const + { + dim3 grid_dims = get_grid_dims(); + return grid_dims.x * grid_dims.y * grid_dims.z; + } + + + /// Returns the grid extents in thread blocks to launch + dim3 get_grid_dims() const + { + return ThreadblockSwizzle().get_grid_shape(grid_tiled_shape); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/rank_2k_grouped.h b/include/cutlass/gemm/kernel/rank_2k_grouped.h index 91e7767c..db4eab0c 100644 --- a/include/cutlass/gemm/kernel/rank_2k_grouped.h +++ b/include/cutlass/gemm/kernel/rank_2k_grouped.h @@ -384,13 +384,6 @@ public: return Status::kSuccess; } - static size_t get_extra_workspace_size( - Arguments const &args, - cutlass::gemm::GemmCoord const &grid_tiled_shape) { - - return 0; - } - /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage) { diff --git a/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h b/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h index 70b00b62..e0e20936 100644 --- a/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h +++ b/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h @@ -280,6 +280,14 @@ template struct Rank2KGroupedProblemSizeHelper { using OffsetHelper = Rank2KGroupedProblemVisitorOffsetHelper; + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { + return cutlass::gemm::GemmCoord( + ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), + 1); + } + CUTLASS_HOST_DEVICE static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { // Return the number of tiles at or below the diagonal (or at and above diff --git a/include/cutlass/gemm/thread/mma_sm50.h b/include/cutlass/gemm/thread/mma_sm50.h index 25336424..82979a92 100644 --- a/include/cutlass/gemm/thread/mma_sm50.h +++ b/include/cutlass/gemm/thread/mma_sm50.h @@ -144,7 +144,10 @@ struct MmaGeneric { CUTLASS_PRAGMA_UNROLL for (int k = 0; k < Shape::kK; ++k) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 860) - if (kMultipleOf2 && platform::is_same::value && platform::is_same::value && platform::is_same::value) { + if 
(kMultipleOf2 && + platform::is_same::value && + platform::is_same::value && + platform::is_same::value) { //2x2 zigzag - m and n loops to increment by 2. Inner loop to process 4 multiply-adds in a 2x2 tile. CUTLASS_PRAGMA_UNROLL @@ -250,6 +253,184 @@ struct MmaGeneric { }; +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Matrix multiply-add operation - assumes operand B is not changing +struct MmaComplexF32_Column { + + using Shape = gemm::GemmShape<1, 1, 1>; + using ElementC = complex; + + CUTLASS_HOST_DEVICE + void operator()( + Array, 1> &d, + Array, 1> const &a, + Array, 1> const &b, + Array, 1> const &c + ) { + + d[0].real() = a[0].real() * b[0].real() + c[0].real(); + d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); + d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); + d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); + } +}; + +/// Matrix multiply-add operation - assumes operand A is not changing +struct MmaComplexF32_Corner { + + using Shape = gemm::GemmShape<1, 1, 1>; + using ElementC = complex; + + CUTLASS_HOST_DEVICE + void operator()( + Array, 1> &d, + Array, 1> const &a, + Array, 1> const &b, + Array, 1> const &c + ) { + + d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); + d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); + d[0].real() = a[0].real() * b[0].real() + c[0].real(); + d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); + } +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Gemplate that handles all packed matrix layouts +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Layout of A matrix (concept: layout::MapFunc) + typename LayoutA_, + /// Layout of B matrix (concept: layout::MapFunc) + typename LayoutB_, + /// Layout of C matrix (concept: layout::MapFunc) + typename LayoutC_ +> +struct MmaGeneric< + Shape_, + complex, + LayoutA_, + complex, + LayoutB_, + complex, + LayoutC_, + arch::OpMultiplyAdd> { + + /// Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + /// Data type of operand A + using ElementA = complex; + + /// Layout of A matrix (concept: layout::MapFunc) + using LayoutA = LayoutA_; + + /// Data type of operand B + using ElementB = complex; + + /// Layout of B matrix (concept: layout::MapFunc) + using LayoutB = LayoutB_; + + /// Element type of operand C + using ElementC = complex; + + /// Layout of C matrix (concept: layout::MapFunc) + using LayoutC = LayoutC_; + + /// Underlying mathematical operator + using Operator = arch::OpMultiplyAdd; + + /// A operand storage + using FragmentA = Array; + + /// B operand storage + using FragmentB = Array; + + /// C operand storage + using FragmentC = Array; + + /// Instruction + using MmaOp = arch::Mma< + gemm::GemmShape<1,1,1>, + 1, + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + Operator>; + + // + // Methods + // + + /// Computes a matrix product D = A * B + C + CUTLASS_HOST_DEVICE + void operator()( + FragmentC & D, + FragmentA const & A, + FragmentB const & B, + FragmentC const & C) { + + TensorRef a_ref( + reinterpret_cast(&A), LayoutA::packed({Shape::kM, Shape::kK})); + + TensorRef b_ref( + reinterpret_cast(&B), LayoutB::packed({Shape::kK, Shape::kN})); + + TensorRef d_ref( + reinterpret_cast(&D), LayoutC::packed(make_Coord(Shape::kM, Shape::kN))); + + detail::MmaComplexF32_Column mma_column; + 
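Given that the surrounding operator() passes the accumulator fragment as both c and d, the two helpers declared here evaluate the same complex multiply-accumulate d = a * b + c and differ only in the order in which the four real partial products are issued. A standalone sketch of that expansion, using std::complex purely for illustration:

#include <complex>

// d = a * b + c expanded into four real fused multiply-adds, following the
// ordering used by MmaComplexF32_Column above.
static std::complex<float> sketch_complex_mac(std::complex<float> a,
                                              std::complex<float> b,
                                              std::complex<float> c) {
  float d_real = a.real() * b.real() + c.real();  //  ar*br
  float d_imag = a.real() * b.imag() + c.imag();  //  ar*bi
  d_real = -a.imag() * b.imag() + d_real;         // -ai*bi
  d_imag =  a.imag() * b.real() + d_imag;         //  ai*br
  return {d_real, d_imag};
}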
detail::MmaComplexF32_Corner mma_corner; + + // Copy accumulators + D = C; + + // Compute matrix product + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < Shape::kK; ++k) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Shape::kN; ++n) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Shape::kM; ++m) { + + int m_serpentine = (n % 2) ? (Shape::kM - 1 - m) : m; + + MatrixCoord mn(m_serpentine, n); + MatrixCoord mk(m_serpentine, k); + MatrixCoord kn(k, n); + + Array d; + Array a; + Array b; + + d[0] = d_ref.at(mn); + a[0] = a_ref.at(mk); + b[0] = b_ref.at(kn); + + if ((m == 0 && n) || m == Shape::kM - 1) { + mma_corner(d, a, b, d); + } + else { + mma_column(d, a, b, d); + } + + d_ref.at(mn) = d[0]; + } + } + } + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Gemplate that handles conventional layouts for FFMA and DFMA GEMM diff --git a/include/cutlass/gemm/threadblock/default_ell_mma.h b/include/cutlass/gemm/threadblock/default_ell_mma.h new file mode 100644 index 00000000..3dae8564 --- /dev/null +++ b/include/cutlass/gemm/threadblock/default_ell_mma.h @@ -0,0 +1,734 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Default template for a Blocked-Ell MMA. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) +#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" +#endif //CUTLASS_ARCH_WMMA_ENABLED + +#include "cutlass/gemm/threadblock/ell_mma_pipelined.h" +#include "cutlass/gemm/threadblock/ell_mma_multistage.h" +#include "cutlass/transform/threadblock/ell_predicated_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
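DefaultEllMma consumes one operand (A or B, selected by the is_A_sparse flag used later in this patch) stored in the Blocked-Ellpack layout: the non-zero blocks of each block-row are packed densely, and a companion index array records the block-column of every stored block, with a negative entry conventionally marking an absent, all-zero block. The sketch below is an illustrative host-side model of that layout, not the iterator API used by these kernels; the names SketchBlockedEll and block_offset are assumptions.

#include <cstdint>
#include <vector>

// Minimal Blocked-ELL container: values hold (rows / blocksize) block rows,
// each with (ell_cols / blocksize) dense blocksize x blocksize blocks.
struct SketchBlockedEll {
  int rows = 0, ell_cols = 0, blocksize = 1;
  std::vector<int> block_col_idx;   // block-column per stored block, negative if empty
  std::vector<float> values;        // packed dense blocks

  // Returns the element offset of stored block (block_row, slot), or -1 when
  // the slot holds no block, mirroring the "ell_offset >= 0" checks in the
  // ELL mainloop later in this patch.
  int64_t block_offset(int block_row, int slot) const {
    int blocks_per_row = ell_cols / blocksize;
    if (block_col_idx[block_row * blocks_per_row + slot] < 0) {
      return -1;
    }
    return int64_t(block_row * blocks_per_row + slot) * blocksize * blocksize;
  }
};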
+ bool AccumulatorsInRowMajor = false + > +struct DefaultEllMma; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass Simt) +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultEllMma { + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, + arch::OpClassSimt, 2, Operator>; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::EllPredicatedTileIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::EllPredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, + layout::RowMajor, typename MmaCore::MmaPolicy>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp) +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator + > +struct DefaultEllMma { + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, + arch::OpClassTensorOp, 2, Operator>; + + // Define iterators over tiles from the A operand + using IteratorA = + 
cutlass::transform::threadblock::EllPredicatedTileIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::EllPredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, + layout::RowMajor, typename MmaCore::MmaPolicy>; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator + > +struct DefaultEllMma { + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, float, LayoutA, float, + LayoutB, float, layout::RowMajor, arch::OpClassTensorOp, 2, + arch::OpMultiplyAddFastF16>; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::EllPredicatedTileIterator< + cutlass::MatrixShape, + float, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::EllPredicatedTileIterator< + cutlass::MatrixShape, + float, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, float, + layout::RowMajor, typename MmaCore::MmaPolicy>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for column-major-interleaved output +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: 
GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Number of Interleaved K + int InterleavedK> +struct DefaultEllMma, OperatorClass, + ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, + Operator, true> { + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, + true>; + + static_assert(kAlignmentA == 128 / sizeof_bits::value, + "Alignment must match thread data map's vector length"); + + static_assert(kAlignmentB ==128 / sizeof_bits::value, + "Alignment must match thread data map's vector length"); + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::EllPredicatedTileIterator< + cutlass::MatrixShape, ElementA, + LayoutA, 1, typename MmaCore::IteratorThreadMapA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::EllPredicatedTileIterator< + cutlass::MatrixShape, ElementB, + LayoutB, 0, typename MmaCore::IteratorThreadMapB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, + layout::ColumnMajorInterleaved, + typename MmaCore::MmaPolicy>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator + > +struct DefaultEllMma { + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, Operator>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::EllPredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::EllPredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define the 
threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::EllMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, + typename MmaCore::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp) +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator + > +struct DefaultEllMma { + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::EllPredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::EllPredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::EllMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, + typename MmaCore::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for column-major-interleaved output +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, 
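In the TensorOp multistage specialization above, the cp.async cache operation is derived from each operand's access width: a full 128-bit access maps to the cache-global variant (cp.async.cg, which bypasses L1), while narrower accesses fall back to the default cp.async.ca behavior. A small sketch of that selection rule, with an enum standing in for cutlass::arch::CacheOperation::Kind:

// Illustrative stand-in for cutlass::arch::CacheOperation::Kind.
enum class SketchCacheOp { Always, Global };

// Mirrors: (sizeof_bits of the element * kAlignment) == 128 ? Global : Always
constexpr SketchCacheOp sketch_select_cache_op(int element_bits,
                                               int alignment_elements) {
  return (element_bits * alignment_elements == 128) ? SketchCacheOp::Global
                                                    : SketchCacheOp::Always;
}

static_assert(sketch_select_cache_op(16, 8) == SketchCacheOp::Global,
              "16b x 8 elements = 128b accesses may bypass L1");
static_assert(sketch_select_cache_op(32, 2) == SketchCacheOp::Always,
              "64b accesses use the default cache operation");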
+ /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Number of Interleaved K + int InterleavedK> +struct DefaultEllMma, OperatorClass, + ArchTag, ThreadblockShape, WarpShape, InstructionShape, + Stages, Operator, true> { + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, Stages, + Operator, true>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::EllPredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::EllPredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::EllMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, + typename MmaCore::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for SIMT IDP4A Kernels +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Operation performed by GEMM + typename Operator, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape> +struct DefaultEllMma, 2, + Operator, false> { + using InstructionShape = GemmShape<1, 1, 4>; + using ElementA = int8_t; + using ElementB = int8_t; + using OperatorClass = arch::OpClassSimt; + + static const bool transposeA = cutlass::platform::is_same< LayoutA, layout::ColumnMajor >::value; + static const bool transposeB = cutlass::platform::is_same< LayoutB, layout::RowMajor >::value; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, 
layout::RowMajor, + OperatorClass, 2, Operator>; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, + layout::RowMajor, typename MmaCore::MmaPolicy>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) +/// Specialization for Wmma TensorOp operator with 2 staged pipeline +template < + ///< Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultEllMma { + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, LayoutC, + arch::OpClassWmmaTensorOp, 2, Operator>; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::EllPredicatedTileIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::EllPredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::EllMmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, + LayoutC, typename MmaCore::MmaPolicy>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for Wmma TensorOp operator with 1 staged pipeline +template < + ///< Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity 
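The SIMT IDP4A specialization above fixes the instruction shape to GemmShape<1, 1, 4> with int8_t operands, so each math instruction contracts four 8-bit values from A and B into a 32-bit accumulator (the dp4a path available from SM61 onward). A scalar reference of that contraction, purely for illustration:

#include <cstdint>

// One IDP4A-style step: c += dot(a[0..3], b[0..3]) with 8-bit inputs and a
// 32-bit accumulator. Hardware performs this in a single instruction; this
// reference loop only documents the arithmetic.
static int32_t sketch_idp4a(const int8_t a[4], const int8_t b[4], int32_t c) {
  for (int k = 0; k < 4; ++k) {
    c += int32_t(a[k]) * int32_t(b[k]);
  }
  return c;
}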
of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultEllMma { + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, LayoutC, + arch::OpClassWmmaTensorOp, 1, Operator>; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::EllPredicatedTileIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::EllPredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>; + + // Define the threadblock-scoped singlestage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaSingleStage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator, + LayoutC, typename MmaCore::MmaPolicy>; +}; + +//////////////////////////////////////////////////////////////////////////////// +#endif //CUTLASS_ARCH_WMMA_ENABLED + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/default_mma_core_sm80.h b/include/cutlass/gemm/threadblock/default_mma_core_sm80.h index eb44b113..a31dc899 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_sm80.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_sm80.h @@ -36,6 +36,9 @@ Partial specializations for threadblock::Mma operations targeting TensorOp instructions. + + SM80 Multi stage kernel expects stage number to be larger or equal to 3 + to use asyncronous copy. 
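The note above, that the SM80 multistage mainloop wants at least three stages to use asynchronous copies, has a direct shared-memory cost: every stage buffers one threadblock tile of A and one of B. The sketch below sizes that buffer for a hypothetical 128x128x32 half-precision tile; the numbers are illustrative and the usable per-SM capacity depends on the target GPU.

#include <cstddef>
#include <cstdio>

// Shared memory consumed by the A/B staging buffers of a multistage mainloop.
static size_t sketch_stage_smem_bytes(int kM, int kN, int kK,
                                      size_t elem_bytes, int stages) {
  size_t tile_a = size_t(kM) * kK * elem_bytes;  // A tile per stage
  size_t tile_b = size_t(kK) * kN * elem_bytes;  // B tile per stage
  return size_t(stages) * (tile_a + tile_b);
}

static void sketch_print_stage_budget() {
  for (int stages = 3; stages <= 6; ++stages) {
    std::printf("stages=%d -> %zu KiB of shared memory\n", stages,
                sketch_stage_smem_bytes(128, 128, 32, 2, stages) / 1024);
  }
}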
*/ #pragma once diff --git a/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h b/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h index 0fc68359..2e05e20b 100644 --- a/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h +++ b/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h @@ -600,7 +600,7 @@ struct DefaultMultistageMmaComplexCore< /// A: column-major /// B: column-major /// Operator: arch::OpMultiplyAddComplex -/// Math Instruction: MMA.1688.F32.TF32 +/// Math Instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 /// /// This uses the default warp-level operator given tile sizes template < @@ -730,7 +730,7 @@ struct DefaultMultistageMmaComplexCore< /// A: column-major /// B: row-major /// Operator: arch::OpMultiplyAddComplex -/// Math Instruction: MMA.1688.F32.TF32 +/// Math Instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 /// /// This uses the default warp-level operator given tile sizes template < @@ -861,7 +861,7 @@ struct DefaultMultistageMmaComplexCore< /// A: row-major /// B: column-major /// Operator: arch::OpMultiplyAddComplex -/// Math Instruction: MMA.1688.F32.TF32 +/// Math Instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 /// /// This uses the default warp-level operator given tile sizes template < @@ -992,7 +992,7 @@ struct DefaultMultistageMmaComplexCore< /// A: row-major /// B: row-major /// Operator: arch::OpMultiplyAddComplex -/// Math Instruction: MMA.1688.F32.TF32 +/// Math Instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 /// /// This uses the default warp-level operator given tile sizes template < @@ -1118,10 +1118,10 @@ struct DefaultMultistageMmaComplexCore< //////////////////////////////////////////////////////////////////////////////// -/// Partial specialization for complex double-precision +/// Partial specialization for complex SIMT operation /// /// A: column-major -/// B: row-major +/// B: column-major /// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex /// /// This uses the default warp-level operator given tile sizes @@ -1267,15 +1267,18 @@ struct DefaultMultistageMmaComplexCore< >; using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - >; /// Used for partial specialization + WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 + ElementA, /// Data type of A elements + SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) + ElementB, /// Data type of B elements + SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) + ElementC, /// Element type of C matrix + LayoutC, /// Layout of C matrix (concept: MatrixLayout) + Policy, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + 1, /// 1 partition along K dimension + kTransformA, /// Transform for A + kTransformB /// Transform for B + >; /// Used for partial specialization /// Policy used to define MmaPipelined using MmaPolicy = MmaPolicy< @@ -1285,7 +1288,7 @@ struct DefaultMultistageMmaComplexCore< WarpCount::kK>; }; -/// Partial 
specialization for complex double-precision +/// Partial specialization for complex SIMT operation /// /// A: column-major /// B: row-major @@ -1431,15 +1434,18 @@ struct DefaultMultistageMmaComplexCore< >; using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - >; /// Used for partial specialization + WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 + ElementA, /// Data type of A elements + SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) + ElementB, /// Data type of B elements + SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) + ElementC, /// Element type of C matrix + LayoutC, /// Layout of C matrix (concept: MatrixLayout) + Policy, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + 1, /// 1 partition along K dimension + kTransformA, /// Transform for A + kTransformB /// Transform for B + >; /// Used for partial specialization /// Policy used to define MmaPipelined using MmaPolicy = MmaPolicy< @@ -1449,10 +1455,10 @@ struct DefaultMultistageMmaComplexCore< WarpCount::kK>; }; -/// Partial specialization for complex double-precision +/// Partial specialization for complex SIMT operation /// -/// A: column-major -/// B: row-major +/// A: row-major +/// B: column-major /// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex /// /// This uses the default warp-level operator given tile sizes @@ -1601,15 +1607,18 @@ struct DefaultMultistageMmaComplexCore< >; using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - >; /// Used for partial specialization + WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 + ElementA, /// Data type of A elements + SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) + ElementB, /// Data type of B elements + SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) + ElementC, /// Element type of C matrix + LayoutC, /// Layout of C matrix (concept: MatrixLayout) + Policy, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + 1, /// 1 partition along K dimension + kTransformA, /// Transform for A + kTransformB /// Transform for B + >; /// Used for partial specialization /// Policy used to define MmaPipelined using MmaPolicy = MmaPolicy< @@ -1619,9 +1628,9 @@ struct DefaultMultistageMmaComplexCore< WarpCount::kK>; }; -/// Partial specialization for complex double-precision +/// Partial specialization for complex SIMT operation /// -/// A: column-major +/// A: row-major /// B: row-major /// Operator: arch::OpMultiplyAddComplex or arch::OpMultiplyGaussianComplex /// @@ -1768,15 +1777,18 @@ 
struct DefaultMultistageMmaComplexCore< >; using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< - WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 - ElementA, /// Data type of A elements - SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) - ElementB, /// Data type of B elements - SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) - ElementC, /// Element type of C matrix - LayoutC, /// Layout of C matrix (concept: MatrixLayout) - Policy /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - >; /// Used for partial specialization + WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> 128, 128, 8 + ElementA, /// Data type of A elements + SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) + ElementB, /// Data type of B elements + SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) + ElementC, /// Element type of C matrix + LayoutC, /// Layout of C matrix (concept: MatrixLayout) + Policy, /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + 1, /// 1 partition along K dimension + kTransformA, /// Transform for A + kTransformB /// Transform for B + >; /// Used for partial specialization /// Policy used to define MmaPipelined using MmaPolicy = MmaPolicy< diff --git a/include/cutlass/gemm/threadblock/ell_mma_multistage.h b/include/cutlass/gemm/threadblock/ell_mma_multistage.h new file mode 100644 index 00000000..15155fe5 --- /dev/null +++ b/include/cutlass/gemm/threadblock/ell_mma_multistage.h @@ -0,0 +1,642 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multistage threadblock-scoped Blocked-Ell MMA. 
+*/ + +#pragma once + + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class EllMmaMultistage : + public MmaBase { +public: + ///< Base class + using Base = MmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. 
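The Detail bookkeeping defined further below spreads the cp.async traffic of one stage evenly across the warp-level MMA iterations: each group issues ceil(copy-iterations-per-stage / warp-GEMM-iterations) copies, and each copy moves access-width / accesses-per-vector bytes. A standalone restatement of those two formulas with made-up parameter values:

// Illustrative recomputation of Detail::kAccessesPerGroupA and of the
// kSrcBytes expression used by copy_tiles_and_advance(); all numbers below
// are assumptions for the example.
constexpr int sketch_accesses_per_group(int copy_iters_per_stage,
                                        int warp_gemm_iters) {
  return (copy_iters_per_stage + warp_gemm_iters - 1) / warp_gemm_iters;
}

constexpr int sketch_src_bytes(int element_bits, int elements_per_access,
                               int accesses_per_vector) {
  return element_bits * elements_per_access / accesses_per_vector / 8;
}

static_assert(sketch_accesses_per_group(8, 4) == 2,
              "8 copy iterations spread over 4 warp MMA iterations");
static_assert(sketch_src_bytes(16, 8, 1) == 16,
              "8 half-precision elements per access is one 16-byte cp.async");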
+ struct Detail { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + EllMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + template + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, EllIterator &ell_iter, + int group_start_A = 0, int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value 
* + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + bool is_valid = iterator_A.valid(); + + if (!is_A_sparse){ + if (is_offset_constant){ + auto ell_offset = ell_iter.get_offset_fast(); + is_valid = is_valid && (ell_offset >= 0); + gmem_ptr += ell_offset * sizeof(IteratorA::Element) / kSrcBytes; + } else { + int k_offset = iterator_A.get_k(); + auto ell_offset = ell_iter.get_offset(k_offset); + is_valid = is_valid && (ell_offset >= 0); + gmem_ptr += (ell_offset * sizeof(IteratorA::Element)) / kSrcBytes; + } + } + + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, is_valid); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + bool is_valid = iterator_B.valid(); + + if (is_A_sparse){ + if (is_offset_constant){ + auto ell_offset = ell_iter.get_offset_fast(); + is_valid = is_valid && (ell_offset >= 0); + gmem_ptr += ell_offset * sizeof(IteratorB::Element) / kSrcBytes; + } else { + int k_offset = iterator_B.get_k(); + auto ell_offset = ell_iter.get_offset(k_offset); + is_valid = is_valid && (ell_offset >= 0); + gmem_ptr += ( ell_offset * sizeof(IteratorB::Element)) / kSrcBytes; + } + } + + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, is_valid); + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + + /// Perform a threadblock-scoped matrix multiply-accumulate + template + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const &src_accum, + EllIterator &ell_iterator + ) { + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + auto gmem_ptr = iterator_A.get(); + bool is_valid = iterator_A.valid(); + + if (!is_A_sparse){ + if (is_offset_constant){ + auto ell_offset = ell_iterator.get_offset_fast(); + is_valid = is_valid && (ell_offset >= 
0); + gmem_ptr += ell_offset * sizeof(IteratorA::Element) / kSrcBytes; + } else { + int k_offset = iterator_A.get_k(); + auto ell_offset = ell_iterator.get_offset(k_offset); + is_valid = is_valid && (ell_offset >= 0); + gmem_ptr += (ell_offset * sizeof(IteratorA::Element)) / kSrcBytes; + } + } + + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, is_valid); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + auto gmem_ptr = iterator_B.get(); + bool is_valid = iterator_B.valid(); + + if (is_A_sparse){ + if (is_offset_constant){ + auto ell_offset = ell_iterator.get_offset_fast(); + is_valid = is_valid && (ell_offset >= 0); + gmem_ptr += ell_offset * sizeof(IteratorB::Element) / kSrcBytes; + } else { + int k_offset = iterator_B.get_k(); + auto ell_offset = ell_iterator.get_offset(k_offset); + is_valid = is_valid && (ell_offset >= 0); + gmem_ptr += ( ell_offset * sizeof(IteratorB::Element)) / kSrcBytes; + } + } + + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, is_valid); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + ++ell_iterator; + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + if (is_A_sparse){ + iterator_A.ell_add_mask(ell_iterator.get_blocksize()); + } + else { + iterator_B.ell_add_mask(ell_iterator.get_blocksize()); + } + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], + warp_loaded_frag_A[0], warp_loaded_frag_B[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. 
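
The staged-accumulation idea described in the comment above can be illustrated with a small host-side sketch (illustrative only; the loop bounds and names such as kGroup are invented for this example, not taken from CUTLASS): products are first summed into a freshly cleared temporary, and that temporary is folded into the running accumulator once per mainloop iteration, which keeps each individual add from pairing a tiny product with an already-large running sum.

// Minimal host-side sketch of staged accumulation (not the CUTLASS code path).
// Each "mainloop iteration" accumulates kGroup products into a cleared
// temporary, then adds that temporary to the final accumulator exactly once.
#include <cstdio>
#include <vector>

int main() {
  const int kIterations = 4;   // stand-in for gemm_k_iterations
  const int kGroup      = 8;   // stand-in for one iteration's worth of MACs

  std::vector<float> a(kIterations * kGroup, 1.25f);
  std::vector<float> b(kIterations * kGroup, 0.75f);

  float accum = 0.0f;                    // final accumulator ("accum")
  for (int it = 0; it < kIterations; ++it) {
    float tmp_accum = 0.0f;              // temporary accumulator ("tmp_accum")
    for (int k = 0; k < kGroup; ++k) {
      int idx = it * kGroup + k;
      tmp_accum += a[idx] * b[idx];      // products stay in the small temporary
    }
    accum += tmp_accum;                  // folded into the final accumulator once
  }
  std::printf("accum = %f\n", accum);
  return 0;
}

The code that follows implements the same pattern at warp-tile granularity, folding tmp_accum into accum once per set of warp-level MMAs.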
+ plus plus_accum; + + FragmentC tmp_accum; + + if (platform::is_same::value + || platform::is_same::value) { + + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + if (platform::is_same::value + || platform::is_same::value) { + + warp_mma( + tmp_accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + tmp_accum + ); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum + ); + } + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, iterator_B, ell_iterator, group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, iterator_B, ell_iterator, group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
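
For readers unfamiliar with the cp.async primitives used throughout this mainloop, the stand-alone device-side sketch below (the kernel, buffer sizes, and names are invented for illustration; only the cutlass::arch calls mirror the ones in the code above) shows the commit/wait pattern: each group of cp.async copies is closed with cp_async_fence(), and cp_async_wait<N>() blocks until at most N committed groups are still in flight, which is exactly what the wait issued next does with N = kStages - 2.

// Sketch of the cp.async commit/wait pattern (SM80+; launch with 32 threads).
// Assumed names: cp_async_pattern, smem, out. Only the cutlass::arch calls
// reflect the API used by the mainloop above.
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm80.h"

__global__ void cp_async_pattern(float const *gmem, float *out) {
  constexpr int kStages = 3;
  __shared__ alignas(16) float smem[kStages][128];

  // Prologue: issue kStages-1 fetch stages, committing each as one group.
  for (int stage = 0; stage < kStages - 1; ++stage) {
    cutlass::arch::cp_async_zfill<16>(
        &smem[stage][threadIdx.x * 4],            // 16-byte aligned smem target
        gmem + stage * 128 + threadIdx.x * 4,     // 16-byte source in global
        /*pred_guard=*/true);                     // false would zero-fill instead
    cutlass::arch::cp_async_fence();              // end of one fetch stage
  }

  // Block until at most kStages-2 committed groups remain in flight,
  // i.e. the oldest stage has landed in shared memory.
  cutlass::arch::cp_async_wait<kStages - 2>();
  __syncthreads();

  out[threadIdx.x] = smem[0][threadIdx.x * 4];
}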
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + ++ell_iterator; + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + + } + + if (platform::is_same::value + || platform::is_same::value) { + accum = plus_accum(accum, tmp_accum); + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/ell_mma_pipelined.h b/include/cutlass/gemm/threadblock/ell_mma_pipelined.h new file mode 100644 index 00000000..7bffa4ac --- /dev/null +++ b/include/cutlass/gemm/threadblock/ell_mma_pipelined.h @@ -0,0 +1,376 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped Blocked-Ell MMA. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_ = NumericArrayConverter< + typename SmemIteratorA_::Element, + typename IteratorA_::Element, + IteratorA_::Fragment::kElements>, + /// + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool +> +class EllMmaPipelined : public MmaBase { +public: + + ///< Base class + using Base = MmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from 
global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for EllMmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages==2), "EllMmaPipelined requires kStages set to value 2"); + +private: + + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + +protected: + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + EllMmaPipelined( + typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + template + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC &accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const &src_accum, ///< source accumulator tile + EllIterator &ell_iterator, + TransformA transform_A = TransformA(), ///< transformation applied to A fragment + TransformB transform_B = TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // load sparse matrix + if (is_A_sparse){ + 
iterator_A.load(tb_frag_A); + } else { + iterator_B.load(tb_frag_B); + } + + // load dense matrix + if (is_offset_constant){ + if (is_A_sparse){ + iterator_B.load_with_ell_index_fast(tb_frag_B, ell_iterator); + } else { + iterator_A.load_with_ell_index_fast(tb_frag_A, ell_iterator); + } + } else { + if (is_A_sparse){ + iterator_B.load_with_ell_index(tb_frag_B, ell_iterator); + } else { + iterator_A.load_with_ell_index(tb_frag_A, ell_iterator); + } + } + + ++iterator_A; + ++iterator_B; + ++ell_iterator; + + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + if (is_A_sparse){ + iterator_A.ell_add_mask(ell_iterator.get_blocksize()); + } + else { + iterator_B.ell_add_mask(ell_iterator.get_blocksize()); + } + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. 
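
As a conceptual CPU analogue of the load_with_ell_index calls in the prologue above (this is not the CUTLASS iterator API; the array layout and names below are assumptions for illustration): the Blocked-ELL metadata supplies, for each block-row of the sparse operand, the column-block indices of its stored blocks, and the dense operand is gathered through that indirection. A padding slot contributes zeros, the same role the ell_offset >= 0 predicate and zero-fill play in the multistage variant.

// CPU analogue of the indexed dense-operand loads (illustrative only). A is
// Blocked-ELL sparse, B is dense; the stored block of A at (row-block rb,
// slot s) covers column-block cb = ell_col_idx[rb][s], so the matching rows
// of B are gathered through that index; -1 marks an empty (padding) slot.
#include <cstdio>
#include <vector>

int main() {
  const int kBlock   = 2;   // ELL block size
  const int kN       = 3;   // columns of the dense operand B
  const int kKBlocks = 3;   // K extent of B, in blocks

  std::vector<std::vector<int>> ell_col_idx = {{0, 2}, {1, -1}};  // per row-block
  std::vector<float> B(kKBlocks * kBlock * kN, 1.0f);             // dense operand

  for (int rb = 0; rb < 2; ++rb) {
    for (std::size_t s = 0; s < ell_col_idx[rb].size(); ++s) {
      int cb = ell_col_idx[rb][s];
      for (int r = 0; r < kBlock; ++r) {
        for (int n = 0; n < kN; ++n) {
          // Padding slots behave like zeros (compare: is_valid && ell_offset >= 0).
          float b = (cb < 0) ? 0.0f : B[(cb * kBlock + r) * kN + n];
          std::printf("rb=%d slot=%zu B[%d][%d] -> %.1f\n", rb, s, r, n, b);
        }
      }
    }
  }
  return 0;
}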
+ + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + // load sparse matrix + if (is_A_sparse){ + iterator_A.load(tb_frag_A); + } else { + iterator_B.load(tb_frag_B); + } + + // load dense matrix + if (is_offset_constant){ + if (is_A_sparse){ + iterator_B.load_with_ell_index_fast(tb_frag_B, ell_iterator); + } else { + iterator_A.load_with_ell_index_fast(tb_frag_A, ell_iterator); + } + } else { + if (is_A_sparse){ + iterator_B.load_with_ell_index(tb_frag_B, ell_iterator); + } else { + iterator_A.load_with_ell_index(tb_frag_A, ell_iterator); + } + } + + ++iterator_A; + ++iterator_B; + ++ell_iterator; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + warp_mma(accum, warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], accum); + } + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/index_remat.h b/include/cutlass/gemm/threadblock/index_remat.h new file mode 100644 index 00000000..e92bffdf --- /dev/null +++ b/include/cutlass/gemm/threadblock/index_remat.h @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Helpers for rematerializing indices/dimensions in the thread hierarchy from special registers +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to rematerialize block Idx. Reduces register liveness. +CUTLASS_DEVICE +int RematerializeThreadIdxX() { + return threadIdx.x; +} + +/// Helper to rematerialize block Idx. Reduces register liveness. +CUTLASS_DEVICE +int RematerializeThreadIdxY() { + return threadIdx.y; +} + +/// Helper to rematerialize block Idx. Reduces register liveness. +CUTLASS_DEVICE +int RematerializeThreadIdxZ() { + return threadIdx.z; +} + +/// Helper to rematerialize block Idx. Reduces register liveness. +CUTLASS_DEVICE +int RematerializeBlockIdxX() { + return blockIdx.x; +} + +/// Helper to rematerialize block Idx. Reduces register liveness. +CUTLASS_DEVICE +int RematerializeBlockIdxY() { + return blockIdx.y; +} + +/// Helper to rematerialize block Idx. Reduces register liveness. +CUTLASS_DEVICE +int RematerializeBlockIdxZ() { + return blockIdx.z; +} + +/// Helper to rematerialize block Dim. Reduces register liveness. +CUTLASS_DEVICE +int RematerializeBlockDimX() { + return blockDim.x; +} + +/// Helper to rematerialize block Dim. Reduces register liveness. +CUTLASS_DEVICE +int RematerializeBlockDimY() { + return blockDim.y; +} + +/// Helper to rematerialize block Dim. Reduces register liveness. 
+CUTLASS_DEVICE +int RematerializeBlockDimZ() { + return blockDim.z; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + + diff --git a/include/cutlass/gemm/threadblock/mma_multistage.h b/include/cutlass/gemm/threadblock/mma_multistage.h index d920e3b5..317f7c19 100644 --- a/include/cutlass/gemm/threadblock/mma_multistage.h +++ b/include/cutlass/gemm/threadblock/mma_multistage.h @@ -34,6 +34,7 @@ #pragma once + #include "cutlass/aligned_buffer.h" #include "cutlass/arch/memory.h" #include "cutlass/array.h" @@ -123,7 +124,7 @@ public: /// Minimum architecture is Sm80 to support cp.async using ArchTag = arch::Sm80; - + /// Complex transform on A operand static ComplexTransform const kTransformA = Operator::kTransformA; @@ -151,14 +152,40 @@ public: /// Number of cp.async instructions to load on group of operand B static int const kAccessesPerGroupB = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical + // accuracy, where each mainloop iteration first accumulates into a temporary + // set of freshly-cleared accumulators, which are subsequently added to the + // final accumulator set. + static bool const kStagedAccumulation = + platform::is_same::value || + platform::is_same::value; + }; private: - using WarpLoadedFragmentA = typename Operator::FragmentA; - using WarpLoadedFragmentB = typename Operator::FragmentB; - using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; - using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + // Structure encapsulating pipeline state live from one iteration to the next + struct PipeState { + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + /// Temporary accumulator to facilitate staged-accumulation + FragmentC tmp_accum_; + + /// Pair of A fragments used to overlap shared memory loads and math instructions + WarpLoadedFragmentA warp_loaded_frag_A_[2]; + WarpTransformedFragmentA warp_transformed_frag_A_[2]; + + /// Pair of B fragments used to overlap shared memory loads and math instructions + WarpLoadedFragmentB warp_loaded_frag_B_[2]; + WarpTransformedFragmentB warp_transformed_frag_B_[2]; + }; + private: @@ -166,12 +193,22 @@ public: // Data members // + /// Warp-level MMA operator + Operator warp_mma_; + /// Iterator to write threadblock-scoped tile of A operand to shared memory SmemIteratorA smem_iterator_A_; /// Iterator to write threadblock-scoped tile of B operand to shared memory SmemIteratorB smem_iterator_B_; + /// Shared memory write stage index + int smem_write_stage_idx_; + + /// Shared memory read stage index + int smem_read_stage_idx_; + + public: /// Construct from tensor references @@ -188,7 +225,9 @@ public: ): Base(shared_storage, thread_idx, warp_idx, lane_idx), smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) { // Compute warp location within threadblock tile by mapping the warp_id to // three coordinates: @@ -209,6 +248,45 @@ public: 
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); } + /// Advance shared memory read-iterators to the next stage + CUTLASS_DEVICE + void advance_smem_read_stage() + { + ++smem_read_stage_idx_; + + if (smem_read_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + smem_read_stage_idx_ = 0; + } + } + + /// Advance global memory read-iterators and shared memory write-iterators to the stage + CUTLASS_DEVICE + void advance_smem_write_stage( + IteratorA &iterator_A, + IteratorB &iterator_B) + { + // Advance global iterators + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + // Advance shared iterators + smem_iterator_A_.add_tile_offset({0, 1}); + smem_iterator_B_.add_tile_offset({1, 0}); + + // Increment shared memory write stage index + ++smem_write_stage_idx_; + + if (smem_write_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx_ = 0; + } + } + CUTLASS_DEVICE void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, int group_start_A = 0, int group_start_B = 0) { @@ -282,29 +360,19 @@ public: } } - /// Perform a threadblock-scoped matrix multiply-accumulate + /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching + /// the global fragments needed by the first kStages-1 threadblock mainloop iterations CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC &accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< initial value of accumulator - FragmentC const &src_accum) { - - // - // Prologue - // - + void prologue( + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { // Issue several complete stages CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; - ++stage, --gemm_k_iterations) { + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + // Disable global fetching if done with global fetch iterations iterator_A.clear_mask(gemm_k_iterations == 0); iterator_B.clear_mask(gemm_k_iterations == 0); @@ -362,33 +430,22 @@ public: ++this->smem_iterator_B_; } - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); + // Move to the next write stage + advance_smem_write_stage(iterator_A, iterator_B); // Defines the boundary of a stage of cp.async. cutlass::arch::cp_async_fence(); } - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // - // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels - // so that all accumulator elements outside the GEMM footprint are zero. - // - + // Optionally clear the remaining stages of SMEM. 
This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint are zero. if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { /// Iterator to write threadblock-scoped tile of A operand to shared memory SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - typename IteratorA::AccessType zero_A; - zero_A.clear(); + zero_A.clear(); last_smem_iterator_A.set_iteration_index(0); // Async Copy for operand A @@ -424,201 +481,259 @@ public: ++last_smem_iterator_B; } } + } - // Waits until stages up to the previous (kStages-2)th stage have committed. + + /// Wait until we have at least one completed global fetch stage + CUTLASS_DEVICE + void gmem_wait() + { + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) cutlass::arch::cp_async_wait(); __syncthreads(); + } - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpLoadedFragmentA warp_loaded_frag_A[2]; - WarpLoadedFragmentB warp_loaded_frag_B[2]; - WarpTransformedFragmentA warp_transformed_frag_A[2]; - WarpTransformedFragmentB warp_transformed_frag_B[2]; - Operator warp_mma; + /// Perform a threadblock mainloop iteration of matrix multiply-accumulate + CUTLASS_DEVICE + void mac_loop_iter( + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); + // Load the next warp-tile's A fragment from shared memory + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + // Load the next warp-tile's B fragment from shared memory + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_B_; - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; + // Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary + if (warp_mma_k > 0) { + warp_mma_.transform( + pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]); + } - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); + // Execute the current warp-tile of MMA operations + if (Detail::kStagedAccumulation) { + warp_mma_( + pipe_state.tmp_accum_, + pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_ + ); - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - 
warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], - warp_loaded_frag_A[0], warp_loaded_frag_B[0]); - - // tf32x3 kernels use staging accumulation. warp_mma uses a temporary - // accumulator and this temporary accumulator is added to the final - // accumulator once in every mainloop iteration. - plus plus_accum; - - FragmentC tmp_accum; - - if (platform::is_same::value - || platform::is_same::value) { - - tmp_accum.clear(); - } - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; - ++warp_mma_k) { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - - this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); - this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - if (warp_mma_k > 0) - warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - warp_loaded_frag_A[warp_mma_k % 2], - warp_loaded_frag_B[warp_mma_k % 2]); - - if (platform::is_same::value - || platform::is_same::value) { - - warp_mma( - tmp_accum, - warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - tmp_accum - ); - - if (warp_mma_k == 0) { - accum = plus_accum(accum, tmp_accum); - tmp_accum.clear(); - } - } else { - warp_mma( - accum, - warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B[warp_mma_k % 2], - accum - ); + if (warp_mma_k == 0) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + pipe_state.tmp_accum_.clear(); } + } else { + warp_mma_( + accum, + pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum + ); + } - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) { - int group_start_iteration_A, group_start_iteration_B; + // Except for the last warp-tile, all warp-tiles issue their share of + // global->shared fragment copies + if (warp_mma_k < Base::kWarpGemmIterations - 1) { - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, - group_start_iteration_B); - } + copy_tiles_and_advance( + iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } - if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = - (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = - (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + // The second-to-last warp-tile also: + // - performs the last warp-tile's share of global->shared fragment copies + // 
- moves to the next global fetch stage + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, - group_start_iteration_B); + // Performs the last warp-tile's share of global->shared fragment copies + int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); + copy_tiles_and_advance( + iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); - // Waits until stages up to the previous (kStages-2)th stage have committed. - arch::cp_async_wait(); - __syncthreads(); + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); + // Wait until we have at least one completed global fetch stage + gmem_wait(); - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); + // Move to the next global fetch stage + advance_smem_write_stage(iterator_A, iterator_B); + advance_smem_read_stage(); - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } else { - ++smem_write_stage_idx; - } + // Disable global fetching when done with global fetch iterations + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } - if (smem_read_stage_idx == (Base::kStages - 1)) { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * - Base::kWarpGemmIterations, - 0}); - smem_read_stage_idx = 0; - } else { - ++smem_read_stage_idx; - } + // The last warp-tile also converts the shared memory fragments used by + // the first warp-tile of the next iteration, if necessary (so we can + // immediately start issuing MMA instructions at the top of the loop ) + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - - // Do any conversions feeding the first stage at the end of the loop so - // we can start right away on mma instructions - if (warp_mma_k + 1 == Base::kWarpGemmIterations) - warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], - warp_transformed_frag_B[(warp_mma_k + 1) % 2], - warp_loaded_frag_A[(warp_mma_k + 1) % 2], - warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + warp_mma_.transform( + pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2], + pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], + pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2], + pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); } } + } - if (platform::is_same::value - || platform::is_same::value) { - accum = plus_accum(accum, tmp_accum); + + /// Perform the specified number of threadblock mainloop iterations of matrix + /// multiply-accumulate. Assumes prologue has been initiated. 
+ CUTLASS_DEVICE + void gemm_iters( + int gemm_k_iterations, ///< number of threadblock mainloop iterations + FragmentC &accum, ///< [in|out] accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B) ///< [in|out] iterator over B operand in global memory + { + PipeState pipe_state; + + // Disable global fetching if done with global fetch iterations + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + // Load first warp-tile's A fragment from shared memory + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); + ++this->warp_tile_iterator_A_; + + // Load first warp-tile's B fragment from shared memory + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]); + ++this->warp_tile_iterator_B_; + + // Transform, if necessary, the first warp-tile's shared memory fragments + warp_mma_.transform( + pipe_state.warp_transformed_frag_A_[0], + pipe_state.warp_transformed_frag_B_[0], + pipe_state.warp_loaded_frag_A_[0], + pipe_state.warp_loaded_frag_B_[0]); + + if (Detail::kStagedAccumulation) { + pipe_state.tmp_accum_.clear(); } - + + // Mainloop + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + mac_loop_iter( + pipe_state, + accum, + iterator_A, + iterator_B, + gemm_k_iterations); + } + + if (Detail::kStagedAccumulation) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + } + + // Optionally commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop cutlass::arch::cp_async_fence(); cutlass::arch::cp_async_wait<0>(); __syncthreads(); } } + + + /// Prepares the class for another prologue. + CUTLASS_DEVICE + void wind_down() + { + // Catch-up the smem-read iterator to the smem-write iterator (so this class can be reused for another tile's prologue) + + // First, increment remaining warp tiles to get to the next full stage. (Ideally we would + // just decrement one tile, but not all iterators implement --() decrement.) 
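
The stage bookkeeping that gemm_iters and wind_down manipulate reduces to two indices walking the same circular buffer (the real class advances whole iterators alongside these indices). A small host-side model, with names invented for this sketch, makes the invariant visible: the write stage advances once per global fetch, the read stage advances once per consumed stage, both wrap at kStages, and in steady state the read stage trails the write stage by kStages - 1.

// Host-side model of the circular smem stage indices (illustrative only).
#include <cstdio>

int main() {
  const int kStages = 4;
  int write_stage = 0;
  int read_stage  = 0;

  // Prologue: kStages-1 stages are fetched before any math is issued.
  for (int s = 0; s < kStages - 1; ++s) {
    write_stage = (write_stage + 1) % kStages;   // advance_smem_write_stage
  }

  // Mainloop: each iteration retires one read stage and starts one write stage.
  for (int iter = 0; iter < 6; ++iter) {
    std::printf("iter %d: read=%d write=%d (gap=%d)\n",
                iter, read_stage, write_stage,
                (write_stage - read_stage + kStages) % kStages);
    write_stage = (write_stage + 1) % kStages;   // advance_smem_write_stage
    read_stage  = (read_stage + 1) % kStages;    // advance_smem_read_stage
  }
  return 0;
}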
+ #pragma unroll + for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k); + this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + } + smem_read_stage_idx_++; + + // Then wrap back two full stages (one for the tile advancing we just did, and one to catch the write iterators) + static const int kStageIters = Policy::kPartitionsK * Base::kWarpGemmIterations; + if (smem_read_stage_idx_ > 1) + { + this->warp_tile_iterator_A_.add_tile_offset({0, (-2 * kStageIters)}); + this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0}); + } + else + { + this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)}); + this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0}); + } + smem_read_stage_idx_ = smem_write_stage_idx_; + } + + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const &src_accum) { + + // Prologue (start fetching iterations of global fragments into shared memory) + prologue(iterator_A, iterator_B, gemm_k_iterations); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + + // Initialize destination accumulators with source accumulators + accum = src_accum; + + // Perform the MAC-iterations + gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -628,3 +743,4 @@ public: } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/gemm/threadblock/mma_pipelined.h b/include/cutlass/gemm/threadblock/mma_pipelined.h index 05ab53e2..494bd476 100644 --- a/include/cutlass/gemm/threadblock/mma_pipelined.h +++ b/include/cutlass/gemm/threadblock/mma_pipelined.h @@ -136,19 +136,30 @@ public: // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); -private: - - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - protected: + // + // Data members + // + + /// Warp-level MMA operator + Operator warp_mma; + /// Iterator to write threadblock-scoped tile of A operand to shared memory SmemIteratorA smem_iterator_A_; /// Iterator to write threadblock-scoped tile of B operand to shared memory SmemIteratorB smem_iterator_B_; + ///< transformation applied to A fragment + TransformA transform_A_; + + ///< transformation applied to B fragment + TransformB transform_B_; + + /// Shared memory write stage index + int smem_write_stage_idx; + public: /// Construct from tensor references @@ -157,11 +168,17 @@ public: typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM int thread_idx, ///< ID within the threadblock int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp + int lane_idx, ///< ID of each thread within a warp + TransformA transform_A = TransformA(), ///< 
transformation applied to A fragment + TransformB transform_B = TransformB() ///< transformation applied to B fragment ): Base(shared_storage, thread_idx, warp_idx, lane_idx), smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), - smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + transform_A_(transform_A), + transform_B_(transform_B), + smem_write_stage_idx(0) + { // Compute warp location within threadblock tile by mapping the warp_id to // three coordinates: @@ -180,69 +197,121 @@ public: this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); } - /// Perform a threadblock-scoped matrix multiply-accumulate + + /// Advance shared memory write-iterators to the next stage CUTLASS_DEVICE - void operator()( - int gemm_k_iterations, ///< number of iterations of the mainloop - FragmentC &accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - FragmentC const &src_accum, ///< source accumulator tile - TransformA transform_A = TransformA(), ///< transformation applied to A fragment - TransformB transform_B = TransformB()) { ///< transformation applied to B fragment - - // - // Prologue - // - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - FragmentA tb_frag_A; - FragmentB tb_frag_B; - - tb_frag_A.clear(); - tb_frag_B.clear(); - - // The last kblock is loaded in the prolog - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - - this->smem_iterator_A_.store(transform_A(tb_frag_A)); - this->smem_iterator_B_.store(transform_B(tb_frag_B)); - + void advance_smem_write_stage() + { ++this->smem_iterator_A_; ++this->smem_iterator_B_; + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } + + smem_write_stage_idx ^= 1; + } + + /// Advance shared memory read- and write-iterators to the next stage + CUTLASS_DEVICE + void advance_smem_stages() + { + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + // wrap write stage + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } + else + { + // wrap read stage + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + } + + smem_write_stage_idx ^= 1; + } + + + /// GEMM prologue. 
Bootstrap the global->shared memory pipeline by fetching + /// the global fragments needed by the first kStages-1 threadblock mainloop iterations + CUTLASS_DEVICE + void prologue( + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // The last kblock is loaded in the prolog + + // Load A fragment from global A + FragmentA tb_frag_A; + tb_frag_A.clear(); + iterator_A.load(tb_frag_A); + ++iterator_A; + + // Load B fragment from global B + FragmentB tb_frag_B; + tb_frag_B.clear(); + iterator_B.load(tb_frag_B); + ++iterator_B; + + // Store A and B fragments to shared + this->smem_iterator_A_.store(transform_A_(tb_frag_A)); + this->smem_iterator_B_.store(transform_B_(tb_frag_B)); + + // Advance write stage + advance_smem_write_stage(); + } + + /// Wait until we have at least one completed global fetch stage + CUTLASS_DEVICE + void gmem_wait() + { __syncthreads(); + } + + + /// Perform the specified number of threadblock mainloop iterations of matrix + /// multiply-accumulate. Assumes prologue has been initiated. + CUTLASS_DEVICE + void gemm_iters( + int gemm_k_iterations, ///< number of threadblock mainloop iterations + FragmentC &accum, ///< [in|out] accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B) ///< [in|out] iterator over B operand in global memory + { + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; // Pair of fragments used to overlap shared memory loads and math instructions WarpFragmentA warp_frag_A[2]; WarpFragmentB warp_frag_B[2]; + // Load A fragment from shared A this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - ++this->warp_tile_iterator_A_; + + // Load B fragment from shared B + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); ++this->warp_tile_iterator_B_; - Operator warp_mma; - - int smem_write_stage_idx = 1; + // Pair of fragments used to overlap global memory loads and math instructions; + FragmentA tb_frag_A; + FragmentB tb_frag_B; // Avoid reading out of bounds iterator_A.clear_mask(gemm_k_iterations <= 1); iterator_B.clear_mask(gemm_k_iterations <= 1); - // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tighest latency requirement). 
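
The interleaving that the comment above alludes to can be summarized as a per-k-group schedule. The throwaway host-side sketch below (purely illustrative; it only prints the schedule) mirrors the branches of the mainloop body that follows: the last k-group stores the staged global fragments to shared memory and swaps smem stages, k-group 0 issues the next global loads, and every k-group prefetches the following warp tile from shared memory before its MMA.

// Prints the per-k-group work schedule of one 2-stage (double-buffered)
// mainloop iteration (illustrative summary of the branches in the loop below).
#include <cstdio>

int main() {
  const int kWarpGemmIterations = 4;

  for (int warp_mma_k = 0; warp_mma_k < kWarpGemmIterations; ++warp_mma_k) {
    std::printf("warp_mma_k=%d:", warp_mma_k);

    if (warp_mma_k == kWarpGemmIterations - 1) {
      // Matches: store transform_A_/transform_B_ fragments, gmem_wait(),
      // advance_smem_stages().
      std::printf(" [store staged tile to smem, sync, swap smem stages]");
    }

    // Every k-group prefetches the next warp tile from shared memory.
    std::printf(" [load warp tile %d from smem]",
                (warp_mma_k + 1) % kWarpGemmIterations);

    if (warp_mma_k == 0) {
      // Matches: iterator_A.load / iterator_B.load of the next threadblock tile.
      std::printf(" [issue next global loads]");
    }

    std::printf(" [warp MMA on tile %d]\n", warp_mma_k);
  }
  return 0;
}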
- // // Mainloop // @@ -263,34 +332,20 @@ public: if (warp_mma_k == Base::kWarpGemmIterations - 1) { // Write fragments to shared memory - this->smem_iterator_A_.store(transform_A(tb_frag_A)); + this->smem_iterator_A_.store(transform_A_(tb_frag_A)); - this->smem_iterator_B_.store(transform_B(tb_frag_B)); + this->smem_iterator_B_.store(transform_B_(tb_frag_B)); - __syncthreads(); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; + // Wait until we have at least one completed global fetch stage + gmem_wait(); - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - } - else { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, - 0}); - } - - smem_write_stage_idx ^= 1; + // Advance smem read and write stages + advance_smem_stages(); } this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); @@ -299,10 +354,14 @@ public: if (warp_mma_k == 0) { + // Load fragment from global A + tb_frag_A.clear(); iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - ++iterator_A; + + // Load fragment from global B + tb_frag_B.clear(); + iterator_B.load(tb_frag_B); ++iterator_B; // Avoid reading out of bounds if this was the last loop iteration @@ -310,12 +369,65 @@ public: iterator_B.clear_mask(gemm_k_iterations <= 2); } - warp_mma(accum, warp_frag_A[warp_mma_k % 2], - warp_frag_B[warp_mma_k % 2], accum); + warp_mma( + accum, + warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], + accum); } } } + + + /// Prepares the class for another prologue. + CUTLASS_DEVICE + void wind_down() + { + // First, increment remaining warp tiles to catch it up with the write stage. 
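
Taken together, the prologue / gmem_wait / gemm_iters split introduced in this hunk, plus the wind_down defined here, turn the mainloop class into a reusable pipeline object. The sketch below is a hypothetical driver, not code from this patch: it only assumes a type exposing those four members with the signatures shown in this file, and illustrates how one instance could walk several output tiles by re-priming the pipeline between tiles.

// Hypothetical driver (not part of this patch): processes several tiles with
// one mainloop instance by re-running the prologue after wind_down().
#include "cutlass/cutlass.h"

template <typename Mma, typename IteratorA, typename IteratorB, typename FragmentC>
CUTLASS_DEVICE void run_tiles(
    Mma &mma, int tiles, int iters_per_tile,
    IteratorA iterators_A[], IteratorB iterators_B[], FragmentC accum[]) {
  for (int t = 0; t < tiles; ++t) {
    IteratorA iterator_A = iterators_A[t];
    IteratorB iterator_B = iterators_B[t];
    int gemm_k_iterations = iters_per_tile;

    mma.prologue(iterator_A, iterator_B, gemm_k_iterations);  // start fetching
    mma.gmem_wait();                      // first stage resident in shared memory
    accum[t].clear();
    mma.gemm_iters(gemm_k_iterations, accum[t], iterator_A, iterator_B);
    mma.wind_down();                      // realign read/write stages for reuse
  }
}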
+ #pragma unroll + for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k); + this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + } + + // If we bumped the read iterators to the end of the circular buffer, wrap them around to + // align them with the write iterators + if (smem_write_stage_idx == 0) + { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC &accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const &src_accum) ///< source accumulator tile + { + // Prologue + prologue(iterator_A, iterator_B, gemm_k_iterations); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // Perform the MAC-iterations + gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B); + } + }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/threadblock_swizzle.h b/include/cutlass/gemm/threadblock/threadblock_swizzle.h index 7f47a08b..b83280b5 100644 --- a/include/cutlass/gemm/threadblock/threadblock_swizzle.h +++ b/include/cutlass/gemm/threadblock/threadblock_swizzle.h @@ -41,6 +41,8 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/conv/conv2d_problem_size.h" #include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/gemm/threadblock/index_remat.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle_streamk.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -50,62 +52,6 @@ namespace threadblock { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Helper to rematerialize block Idx. Reduces register liveness. -CUTLASS_DEVICE -int RematerializeThreadIdxX() { - return threadIdx.x; -} - -/// Helper to rematerialize block Idx. Reduces register liveness. -CUTLASS_DEVICE -int RematerializeThreadIdxY() { - return threadIdx.y; -} - -/// Helper to rematerialize block Idx. Reduces register liveness. -CUTLASS_DEVICE -int RematerializeThreadIdxZ() { - return threadIdx.z; -} - -/// Helper to rematerialize block Idx. Reduces register liveness. -CUTLASS_DEVICE -int RematerializeBlockIdxX() { - return blockIdx.x; -} - -/// Helper to rematerialize block Idx. Reduces register liveness. -CUTLASS_DEVICE -int RematerializeBlockIdxY() { - return blockIdx.y; -} - -/// Helper to rematerialize block Idx. Reduces register liveness. -CUTLASS_DEVICE -int RematerializeBlockIdxZ() { - return blockIdx.z; -} - -/// Helper to rematerialize block Dim. Reduces register liveness. -CUTLASS_DEVICE -int RematerializeBlockDimX() { - return blockDim.x; -} - -/// Helper to rematerialize block Dim. Reduces register liveness. -CUTLASS_DEVICE -int RematerializeBlockDimY() { - return blockDim.y; -} - -/// Helper to rematerialize block Dim. Reduces register liveness. 
-CUTLASS_DEVICE -int RematerializeBlockDimZ() { - return blockDim.z; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - /// Threadblock swizzling function for GEMMs template struct GemmIdentityThreadblockSwizzle { diff --git a/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h b/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h new file mode 100644 index 00000000..e31d7c7f --- /dev/null +++ b/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h @@ -0,0 +1,778 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements streamk threadblock mapping blockIdx to GEMM problems. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/platform/platform.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/gemm/threadblock/index_remat.h" + +#include +#include "cutlass/core_io.h" +#include "cutlass/trace.h" + + + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Threadblock mapping control for GEMMs +struct ThreadblockSwizzleStreamK { + + /// Advertise StreamkFeature + using StreamkFeature = void; + + + /// Kernel traits + template + struct KernelTraits {}; + + + /// Reduction strategy + enum ReductionStrategy + { + kNone, // Data-parallel strategy (no seams, fixup, etc.) 
+ + kAtomic, // Non-deterministic reduction of SK-block partials using atomic aggregation in L2 + + kMixed, // Deterministic reduction of SK-block partials employing either: + // (a) A separate wave of reduction thread blocks" (for scenarios with lots of + // SK-blocks per SK-tile) + // (b) Turnstile-ordered atomic aggregation in L2 (for scenarios with few + // SK-blocks per SK-tile) + }; + + static ReductionStrategy const kReductionStrategy = kMixed; + + + // + // Heuristics + // + + /// Data-parallel wave-quantization efficiency threshold (above which we go data-parallel) + static float constexpr kDpEfficiencyThreshold = 0.92f; + + /// Minimum number of MAC-iterations per streamk block + static int const kMinItersPerSkBlock = 2; + + /// Height in CTAs of a grid rasterization cohort + static int const kCohortCtasM = 8; + + /// Width in CTAs of a grid rasterization cohort + static int const kCohortCtasN = 4; + + /// Number of CTAs per cohort + static int const kCtasPerCohort = kCohortCtasN * kCohortCtasM; + + /// Cost-equivalent number of SM-iterations for fixup I/O + static int const kFixupStartupIterEquiv = 10; + static int const kFixupPeerIterEquiv = 3; + + + // + // Member state + // + + /// The 3D value-extents of the GEMM computation volume (m,n,k) + GemmCoord problem_size; + + /// The 2D tile-extents of the output matrix (m,n) + GemmCoord tiled_shape; + + /// Number of iterations per output tile + int iters_per_tile; + + /// Number of reduction blocks in the grid + int reduction_blocks; + + int dp_blocks; /// Number of data-parallel thread blocks in the grid + int dp_first_wave_tiles; /// Number of output tiles each CTA in the first DP wave will produce + + int sk_tiles; + int sk_regions; + int sk_blocks_per_region; + int sk_big_blocks_per_region; + int sk_iters_per_region; + int sk_iters_per_normal_block; /// Number of iterations for normal SK-blocks + int sk_waves; /// Number of SK waves in the grid + + /// CTA occupancy per SM + int sm_occupancy; + + /// Number of SMs for dispatch heuristics to load-balance using Stream-K CTAs (wave size) + int avail_sms; + + /// Whether to perform cohort CTA rasterization + bool cohort_raster; + + /// Div/mod accelerators + struct + { + FastDivmod tiled_shape_m; + FastDivmod tiled_shape_n; + FastDivmod tiled_cohort_shape_n; + FastDivmod iters_per_tile; + FastDivmod sk_iters_per_normal_block; + FastDivmod sk_iters_per_big_block; + FastDivmod sk_iters_per_region; + FastDivmod sk_blocks_per_region; + FastDivmod sm_occupancy; + + } div_mod; + + + // + // Host+device interface + // + + /// Constructor + CUTLASS_HOST_DEVICE + ThreadblockSwizzleStreamK() {} + + + + // + // Host-side interface + // + + /// Debug print + void Print() + { +#ifndef __CUDA_ARCH__ + int tiles = tiled_shape.m() * tiled_shape.n(); + std::cout << + "problem_size: (" << problem_size.m() << "," << problem_size.n() << ")" << + ", reduction_blocks: " << reduction_blocks << + ", dp_blocks: " << dp_blocks << + ", sk_blocks_per_region: " << sk_blocks_per_region << + ", sk_regions: " << sk_regions << + ", sk_iters_per_normal_block: " << sk_iters_per_normal_block << + ", sk_big_blocks_per_region: " << sk_big_blocks_per_region << + ", dp_first_wave_tiles: " << dp_first_wave_tiles << + ", tiled_shape: (" << tiled_shape.m() << "," << tiled_shape.n() << ")" << + ", tiles: " << tiles << + ", iters_per_tile: " << iters_per_tile << + ", dp_tiles: " << tiles - sk_tiles << + ", sk_tiles: " << sk_tiles << + ", avail_sms: " << avail_sms << + ", sm_occupancy: " << sm_occupancy << + ", avail_sms: " 
<< avail_sms << + ", cohort_raster: " << cohort_raster << + "\n\n"; +#endif + } + + + // Compute sk_blocks to dispatch for a given number of sk_tiles + static void get_sk_blocks( + int &sk_blocks, /// [out] + int &savings_iters, /// [out] + int sk_tiles, + int iters_per_tile, + int avail_sms, + int max_sk_occupancy, + bool allow_partial_wave) + { + savings_iters = INT_MIN; + sk_blocks = 0; + + if (sk_tiles == 0) { + return; + } + + int sk_iters = sk_tiles * iters_per_tile; + + int dp_equiv_waves = (sk_tiles + avail_sms - 1) / avail_sms; + int dp_equiv_iters = iters_per_tile * dp_equiv_waves; + + int min_sk_blocks = (allow_partial_wave) ? fast_min(avail_sms, sk_tiles + 1) : avail_sms; + int max_sk_blocks = fast_min(avail_sms * max_sk_occupancy, sk_iters / kMinItersPerSkBlock); + + for (int trial_sk_blocks = min_sk_blocks; trial_sk_blocks <= max_sk_blocks; ++trial_sk_blocks) + { + int sk_waves = (trial_sk_blocks + avail_sms - 1) / avail_sms; + int max_sk_iters_per_block = (sk_iters + trial_sk_blocks - 1) / trial_sk_blocks; + int sk_iter_equiv = max_sk_iters_per_block * sk_waves; + + int num_peers = ((trial_sk_blocks + sk_tiles - 1) / sk_tiles) + 1; // add one for alignment skew + + float iter_cost = 0.02f * float(num_peers) * float(sk_iter_equiv); + + if (trial_sk_blocks % sk_tiles == 0) + { + // aligned + num_peers = (trial_sk_blocks / sk_tiles); + + iter_cost = 0.0f; + } + + float peer_cost = 2.0f * float(num_peers); + + float base_cost = 2.0f * float(sk_waves); + + int fixup_iter_equiv = int(base_cost + iter_cost + peer_cost); + + int trial_savings_iters = dp_equiv_iters - sk_iter_equiv - fixup_iter_equiv; + + if (trial_savings_iters >= savings_iters) { + savings_iters = trial_savings_iters; + sk_blocks = trial_sk_blocks; + } + } + } + + + /// Determine the populations of DP and SK blocks to invoke for the given number of output tiles + static void get_blocks( + int &dp_tiles, /// [out] + int &sk_blocks, /// [out] + int output_tiles, + int iters_per_tile, + int avail_sms, + int sm_occupancy) + { + int full_waves = output_tiles / avail_sms; + int full_wave_tiles = full_waves * avail_sms; + int partial_wave_tiles = output_tiles - full_wave_tiles; + + int score = -1; + dp_tiles = output_tiles; + sk_blocks = 0; + + if (partial_wave_tiles == 0) + { + // Perfect quantization + return; + } + + if (full_waves < sm_occupancy) + { + // We're less than full GPU occupancy + + // Form the SK wave from the partial wave to get us up to full GPU occupancy + int max_sk_occupancy = sm_occupancy - full_waves; + + dp_tiles = full_wave_tiles; + + get_sk_blocks( + sk_blocks, + score, + partial_wave_tiles, + iters_per_tile, + avail_sms, + max_sk_occupancy, + true); // we can run with less than a full wave of SK-blocks + + if (score < 0) { + // not profitable + sk_blocks = 0; + dp_tiles = output_tiles; + } + + return; + } + + // We're at (or greater) than GPU occupancy + + if (full_waves % sm_occupancy == sm_occupancy - 1) + { + // Form the SK wave from the partial wave to get us to full GPU occupancy + int max_sk_occupancy = 1; + + dp_tiles = full_wave_tiles; + + get_sk_blocks( + sk_blocks, + score, + partial_wave_tiles, + iters_per_tile, + avail_sms, + max_sk_occupancy, + true); // we can run with less than a full wave of SK-blocks + + if (score >= 0) { + return; + } + } + + // Form the SK wave by combining the last full wave and the partial wave + // We're less than full GPU occupancy + dp_tiles = full_wave_tiles - avail_sms; + + int max_sk_occupancy = sm_occupancy - ((full_waves - 1) % sm_occupancy); + + 
get_sk_blocks( + sk_blocks, + score, + partial_wave_tiles + avail_sms, + iters_per_tile, + avail_sms, + max_sk_occupancy, + false); // we cannot run with less than a full wave of SK-blocks + + if (score < 0) { + // not profitable + sk_blocks = 0; + dp_tiles = output_tiles; + } + + } + + /// Constructor: *Gemm* problem size (m, n, k) + template + ThreadblockSwizzleStreamK( + KernelTraits const kernel_traits_, + GemmUniversalMode const mode_, + GemmCoord const problem_size_, + GemmCoord const tile_size_, + int const batch_count_, /// Batch count (when mode_ == GemmUniversalMode::kBatched) or split-K-override splitting factor (when mode_ == GemmUniversalMode::kGemm) + int const sm_occupancy_, + int const avail_sms_) + : + problem_size(problem_size_), + tiled_shape( + (problem_size.m() + tile_size_.m() - 1) / tile_size_.m(), + (problem_size.n() + tile_size_.n() - 1) / tile_size_.n(), + (mode_ == GemmUniversalMode::kBatched) ? batch_count_ : 1), + iters_per_tile((problem_size.k() + tile_size_.k() - 1) / tile_size_.k()), + reduction_blocks(0), + dp_blocks(0), + dp_first_wave_tiles(1), // Default: one tile per DP-block in the first wave of DP blocks + sk_tiles(0), + sk_regions(1), // Default: a single region of iteration space (across all SK tiles) + sk_blocks_per_region(0), + sk_big_blocks_per_region(0), + sk_iters_per_region(0), + sk_iters_per_normal_block(0), + sk_waves(0), + sm_occupancy(sm_occupancy_), + avail_sms(fast_max(1, avail_sms_)), + cohort_raster(false) + { + size_t problem_bytes = + (sizeof(typename GemmKernel::ElementC) * problem_size.m() * problem_size.n()) + + (sizeof(typename GemmKernel::ElementA) * problem_size.m() * problem_size.k()) + + (sizeof(typename GemmKernel::ElementB) * problem_size.k() * problem_size.n()); + + size_t problem_flops = size_t(problem_size.m()) * size_t(problem_size.n()) * size_t(problem_size.k()) * 2; + + float flops_per_byte = float(problem_flops) / float(problem_bytes); + + int gpu_occupancy = avail_sms * sm_occupancy; + int output_tiles = tiled_shape.m() * tiled_shape.n(); + int waves = (output_tiles + avail_sms - 1) / avail_sms; + float dp_efficiency = float(output_tiles) / float(waves * avail_sms); + + // + // Determine dispatch composition of DP-tiles and SK-blocks + // + + // Start with a DP-only configuration + int dp_tiles = output_tiles; // Number of data-parallel tiles + int sk_blocks = 0; // Number of thread blocks to produce the remaining SK tiles + + // kGemm mode allows for SK load balancing + if (mode_ == GemmUniversalMode::kGemm) + { + if (batch_count_ > 1) + { + // Split-K override + dp_tiles = 0; + sk_blocks = output_tiles * batch_count_; + } + else if ((kReductionStrategy != kNone) && // Load-balancing strategy statically enabled + (avail_sms > 1)) // Plurality of SMs to load balance across + { + // Use heuristics + get_blocks( + dp_tiles, /// [out] + sk_blocks, /// [out] + output_tiles, + iters_per_tile, + avail_sms, + sm_occupancy); + } + } + + sk_tiles = output_tiles - dp_tiles; + + + // Compute SK block iteration details + if (sk_blocks > 0) + { + sk_waves = (sk_blocks + avail_sms - 1) / avail_sms; + + int sk_iters = sk_tiles * iters_per_tile; + sk_blocks = fast_min(sk_blocks, sk_iters); + + sk_iters_per_normal_block = sk_iters / sk_blocks; + int extra_sk_iters = sk_iters - (sk_iters_per_normal_block * sk_blocks); + int sk_big_blocks = extra_sk_iters; + + if ((sk_blocks > sk_tiles) && (sk_blocks % sk_tiles == 0)) + { + // Split-K decomposition + sk_regions = sk_tiles; + } + + sk_blocks_per_region = sk_blocks / sk_regions; + 
sk_big_blocks_per_region = sk_big_blocks / sk_regions; + sk_iters_per_region = sk_iters / sk_regions; + + div_mod.sk_iters_per_normal_block = FastDivmod(sk_iters_per_normal_block); + div_mod.sk_iters_per_big_block = FastDivmod(sk_iters_per_normal_block + 1); + div_mod.sk_iters_per_region = FastDivmod(sk_iters_per_region); + div_mod.sk_blocks_per_region = FastDivmod(sk_blocks_per_region); + + // Separate reduction heuristic + if ((kReductionStrategy == kMixed) && + (sk_blocks > 2 * sk_tiles)) // Use a separate reduction wave whenever we would have more than three + // peers working on an SK tile. (This occurs when the ratio of SK-blocks + // to SK-tiles > 2, as a single tile may be covered by four SK-blocks, + // e.g.:[partial-block | block | block | partial-block] ). With three or + // less peers, the two non-finishing SK-blocks are not expexted to contend. + { + // Launch a reduction block every accumulator fragment in each SK-tile + static const int kAccumulatorFragments = GemmKernel::Epilogue::kAccumulatorFragments; + reduction_blocks = sk_tiles * kAccumulatorFragments; + + } + } + + // + // Compute DP blocks + // + + dp_blocks = dp_tiles; + + cutlass::gemm::GemmCoord tiled_cohort_shape( + (tiled_shape.m() + kCohortCtasM - 1) / kCohortCtasM, + (tiled_shape.n() + kCohortCtasN - 1) / kCohortCtasN, + batch_count_); + int cohort_blocks = (tiled_cohort_shape.m() * tiled_cohort_shape.n()) * kCtasPerCohort; + float cohort_efficiency = float(dp_blocks) / float(cohort_blocks); + + // Check if the SK tiles would be in cohorts that are in-bounds + bool sk_in_range = true; + if (sk_tiles > 0) + { + int last_sk_tile = sk_tiles - 1; + int cohort_tile_idx = last_sk_tile / kCtasPerCohort; + int cohort_grid_m = cohort_tile_idx / tiled_cohort_shape.n(); + int cohort_grid_n = (cohort_grid_m > 0) ? 
+ tiled_cohort_shape.n() - 1 : + cohort_tile_idx % tiled_cohort_shape.n(); + + if ((((cohort_grid_m + 1) * kCohortCtasM) >= tiled_shape.m()) || + (((cohort_grid_n + 1) * kCohortCtasN) >= tiled_shape.n())) + { + sk_in_range = false; + } + } + + // Decide if we're going to be doing cohort raster + if (sk_in_range && + (dp_blocks >= gpu_occupancy) && + (cohort_efficiency > 0.85f)) + { + cohort_raster = true; + dp_blocks = cohort_blocks; + } + else if (sk_waves > 0) + { + // Update semi-persistence of first DP wave to ensure full grid wavesets + // (Only applies when there's an SK component and we're not doing blocked cohort rasterization) + int dp_tile_waves = (dp_tiles + avail_sms - 1) / avail_sms; + int full_dp_tile_waves = dp_tiles / avail_sms; + int waveset_excess = (sk_waves + dp_tile_waves) % sm_occupancy; + + if (dp_first_wave_tiles + waveset_excess <= full_dp_tile_waves) + { + dp_first_wave_tiles += waveset_excess; + dp_blocks -= (waveset_excess * avail_sms); + } + + } + + // Setup fast-div/mod for device-side usage + div_mod.tiled_shape_m = FastDivmod(tiled_shape.m()); + div_mod.tiled_shape_n = FastDivmod(tiled_shape.n()); + div_mod.tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n()); + div_mod.iters_per_tile = FastDivmod(iters_per_tile); + div_mod.sm_occupancy = FastDivmod(sm_occupancy); + } + + + /// Constructor: *ImplicitGemm* Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) + template + ThreadblockSwizzleStreamK( + KernelTraits kernel_traits_, + GemmUniversalMode mode_, + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem_size_, + GemmCoord tile_size_, + int batch_count_, + int sm_occupancy_, + int avail_sms_, /// When the below are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance + int dp_tiles_ = -1, /// Dispatch override: number of output tiles to assign to independent, data-parallel CTAs + int sk_blocks_ = -1) /// Dispatch override: number of Stream-K CTAs for cooperatively processing the remaining output tiles + : + ThreadblockSwizzleStreamK( + kernel_traits_, + mode_, + cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size_), + tile_size_, + batch_count_, + sm_occupancy_, + avail_sms_, + dp_tiles_, + sk_blocks_) + {} + + + /// Constructor: *ImplicitGemm* Conv3d problem size: conv_operator(NZPQK, NDHWC, KTRSC) + template + ThreadblockSwizzleStreamK( + KernelTraits kernel_traits_, + GemmUniversalMode mode_, + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv3dProblemSize const &problem_size_, + GemmCoord tile_size_, + int batch_count_, + int sm_occupancy_, + int avail_sms_, /// When the below are defaulted, the number of SMs that dispatch heuristics will attempt to load-balance + int dp_tiles_ = -1, /// Dispatch override: number of output tiles to assign to independent, data-parallel CTAs + int sk_blocks_ = -1) /// Dispatch override: number of Stream-K CTAs for cooperatively processing the remaining output tiles + : + ThreadblockSwizzleStreamK( + kernel_traits_, + mode_, + cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size_), + tile_size_, + batch_count_, + sm_occupancy_, + avail_sms_, + dp_tiles_, + sk_blocks_) + {} + + + /// Obtains number of threadblocks per GEMM + int get_num_blocks() const + { +// int reduction_waves = (reduction_blocks + avail_sms - 1) / avail_sms; +// return ((sk_waves + reduction_waves) * avail_sms) + dp_blocks; + + + int work_blocks = (sk_waves * avail_sms) + dp_blocks + reduction_blocks; + + if (work_blocks < avail_sms) + { 
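      // [Editorial aside, not part of the patch] This branch covers the case where the kernel
      // needs fewer blocks than there are SMs: the grid is launched at exactly work_blocks.
      // Otherwise, the code below rounds the launch up to a whole number of GPU "wavesets"
      // (sm_occupancy * avail_sms blocks); a likely motivation is to keep the grid a clean
      // multiple of per-SM occupancy for the block-index remapping in get_block_idx() further
      // down, though the patch does not state this explicitly. Worked example with assumed
      // values (avail_sms = 108, sm_occupancy = 2, so gpu_occupancy = 216): a request for
      // work_blocks = 250 is rounded up to ((250 + 216 - 1) / 216) * 216 = 2 * 216 = 432
      // launched blocks.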
+ return work_blocks; + } + + int gpu_occupancy = sm_occupancy * avail_sms; + int gpu_wavesets = (work_blocks + gpu_occupancy - 1) / gpu_occupancy; + return gpu_wavesets * gpu_occupancy; + + } + + + /// Obtains grid extents in CTAs + dim3 get_grid_dims() const + { + return dim3(get_num_blocks(), 1, tiled_shape.k()); + } + + + // + // Device-side interface + // + + /// Obtains number of threadblocks per GEMM + CUTLASS_DEVICE + int device_num_blocks() const + { + return gridDim.x; + } + + /// Obtains tile index for the given sk iteration + CUTLASS_DEVICE + int get_sk_tile_idx(int iter) const + { + return div_mod.iters_per_tile.div(iter); + } + + + /// Obtains the calling threadblock's tiled coordinates for the given tile index + CUTLASS_DEVICE + GemmCoord get_tile_offset(int tile_idx) const + { + int m, n; + + if (cohort_raster) + { + // tiled cohort raster + int cohort_tile_idx = tile_idx / kCtasPerCohort; + int cohort_grid_m, cohort_grid_n; + div_mod.tiled_cohort_shape_n(cohort_grid_m, cohort_grid_n, cohort_tile_idx); + + int block_idx_cohort = tile_idx % kCtasPerCohort; + int block_cohort_m = block_idx_cohort / kCohortCtasN; + int block_cohort_n = block_idx_cohort % kCohortCtasN; + + m = (cohort_grid_m * kCohortCtasM) + block_cohort_m; + n = (cohort_grid_n * kCohortCtasN) + block_cohort_n; + } + else if (tiled_shape.m() < tiled_shape.n()) + { + // column-major raster + div_mod.tiled_shape_m(n, m, tile_idx); + } + else + { + // row-major raster + div_mod.tiled_shape_n(m, n, tile_idx); + } + + int block_idx_k = RematerializeBlockIdxZ(); + return GemmCoord{m, n, block_idx_k}; + } + + + /// Obtains calling threadblock's linear threadblock index + CUTLASS_DEVICE + int get_block_idx() const + { + int block_idx = RematerializeBlockIdxX(); + + int gpu_occupancy = avail_sms * sm_occupancy; + int num_blocks = device_num_blocks(); + int dest_sm, dest_wave; + + div_mod.sm_occupancy(dest_sm, dest_wave, block_idx); + + int remapped_block_idx = dest_sm + (dest_wave * avail_sms); + + // remapping the first gpu_occupancy blocks + if ((block_idx < gpu_occupancy) && (num_blocks > gpu_occupancy)) + { + block_idx = remapped_block_idx; + } + + // Block-index is blockIdx.x for DP blocks + return block_idx; + } + + + /// Obtains calling linear threadblock index of the first block to work on the given tile + CUTLASS_DEVICE + int get_sk_block_idx(int iter) const + { + int region_idx; + int iter_in_region; + div_mod.sk_iters_per_region(region_idx, iter_in_region, iter); + + int big_block_iters = (sk_big_blocks_per_region * sk_iters_per_normal_block) + sk_big_blocks_per_region; // number of iterations in the region's big blocks + int normal_block_iters = iter_in_region - big_block_iters; // number of iterations in the region's normal bocks + + int big_block_idx_in_region = div_mod.sk_iters_per_big_block.div(iter_in_region); + int normal_block_idx_in_region = sk_big_blocks_per_region + div_mod.sk_iters_per_normal_block.div(normal_block_iters); + + int block_idx_in_region = (big_block_idx_in_region < sk_big_blocks_per_region) ? 
+ big_block_idx_in_region : + normal_block_idx_in_region; + + return (sk_blocks_per_region * region_idx) + block_idx_in_region; + } + + /// Obtains iteration extends for the given SK block index + CUTLASS_DEVICE + void get_iter_extents( + int sk_block_idx, + int &block_iter_begin, + int &block_iter_end) const + { + int region_idx; + int block_idx_in_region; + div_mod.sk_blocks_per_region(region_idx, block_idx_in_region, sk_block_idx); + + block_iter_begin = (region_idx * sk_iters_per_region) + (block_idx_in_region * sk_iters_per_normal_block); + + // Adjust extents for the first "num_big_blocks" blocks that get one extra iteration + int block_iters = sk_iters_per_normal_block; + if (block_idx_in_region < sk_big_blocks_per_region) { + // This is a +1 iteration block + block_iter_begin += block_idx_in_region; + block_iters++; + } else { + // This is a regular block + block_iter_begin += sk_big_blocks_per_region; + } + block_iter_end = block_iter_begin + block_iters; + } + + + /// Obtains calling linear threadblock index of the first block to work on the given tile + CUTLASS_DEVICE + int get_first_block_idx(int tile_idx, int block_idx) const + { + if (tile_idx >= sk_tiles) { + // DP tile + return block_idx; + } + + int iter = tile_idx * iters_per_tile; + return get_sk_block_idx(iter); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + diff --git a/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h b/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h index c57053cc..57640ccb 100644 --- a/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h +++ b/include/cutlass/gemm/warp/default_mma_complex_tensor_op.h @@ -214,7 +214,7 @@ struct DefaultMmaComplexTensorOp< ///////////////////////////////////////////////////////////////////////////////////////////////// /// Partial specialization - input and output types are complex*complex // Use TF32 tensor operation internally -// 4 real-valued MMA.1688.F32.TF32 operations on TF32 +// 4 real-valued mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 operations on TF32 // A = (ar + j ai), B (br +j bi), D = AB // D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -246,7 +246,7 @@ struct DefaultMmaComplexTensorOp< TransformB, arch::OpMultiplyAddComplex> { - // Complex floating point tensor operation use MMA.1688.F32.TF32 mma instruction + // Complex floating point tensor operation use mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 mma instruction using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< cutlass::arch::Mma< InstructionShape_, @@ -278,7 +278,7 @@ struct DefaultMmaComplexTensorOp< ///////////////////////////////////////////////////////////////////////////////////////////////// /// Partial specialization - input and output types are complex*complex // Use BF16 tensor operation internally -// 4 real-valued MMA.1688.F32.BF16 operations on BF16 +// 4 real-valued mma.sync.aligned.m16n8k8.f32.bf16.bf16.f32 operations on BF16 // A = (ar + j ai), B (br +j bi), D = AB // D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -310,7 +310,7 @@ struct DefaultMmaComplexTensorOp< TransformB, arch::OpMultiplyAddFastBF16> { - // Complex floating point tensor operation use MMA.1688.F32.BF16 mma instruction + // Complex floating 
point tensor operation use mma.sync.aligned.m16n8k8.f32.bf16.bf16.f32 mma instruction using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< cutlass::arch::Mma< InstructionShape_, @@ -342,7 +342,7 @@ struct DefaultMmaComplexTensorOp< ///////////////////////////////////////////////////////////////////////////////////////////////// /// Partial specialization - input and output types are complex*complex // Use F16 tensor operation internally -// 4 real-valued MMA.1688.F32.F16 operations on F16 +// 4 real-valued mma.sync.aligned.m16n8k8.f32.f16.f16.f32 operations on F16 // A = (ar + j ai), B (br +j bi), D = AB // D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -374,7 +374,7 @@ struct DefaultMmaComplexTensorOp< TransformB, arch::OpMultiplyAddFastF16> { - // Complex floating point tensor operation use MMA.1688.F32.F16 mma instruction + // Complex floating point tensor operation use mma.sync.aligned.m16n8k8.f32.f16.f16.f32 mma instruction using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< cutlass::arch::Mma< InstructionShape_, @@ -407,7 +407,7 @@ struct DefaultMmaComplexTensorOp< /// 3xTF32 or 4xTF32 (fast and accurate complex operation) /// Partial specialization - input and output types are complex * complex // Use 3xTF32 or 4xTF32 tensor operation internally -// 4 real-valued MMA.1688.F32.TF32 operations on TF32 +// 4 real-valued mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 operations on TF32 // A = (ar + j ai), B (br +j bi), D = AB // D = dr + j di = 3x[(ar*br - ai*bi) + j (ar*bi + ai*br)] ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -441,7 +441,7 @@ struct DefaultMmaComplexTensorOp< TransformB, arch::OpMultiplyAddComplexFastF32> { - // Complex floating point tensor operation use MMA.1688.F32.TF32 mma instruction + // Complex floating point tensor operation use mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 mma instruction using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< cutlass::arch::Mma< InstructionShape_, @@ -470,6 +470,143 @@ struct DefaultMmaComplexTensorOp< TransformB>; }; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for complex*complex case +// 4 real-valued mma.sync.aligned.m16n8k4.f64.f64.f64.f64 operations +// A = (ar + j ai), B (br +j bi), D = AB +// D = dr + j di = (ar*br - ai*bi) + j (ar*bi + ai*br) +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename WarpShape_, + /// Real-valued underlying type of complex-valued A operand + typename RealElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Real-valued underlying type of complex-valued B operand + typename RealElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Real-valued underlying type of complex-valued C operand + typename RealElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Complex transform on A operand + ComplexTransform TransformA, + /// Complex transform on B operand + ComplexTransform TransformB> +struct DefaultMmaComplexTensorOp< + WarpShape_, + GemmShape<16, 8, 4>, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + TransformA, + TransformB, + arch::OpMultiplyAddComplex> { + + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< 
+ cutlass::arch::Mma< + GemmShape<16, 8, 4>, + 32, + RealElementA, + cutlass::layout::RowMajor, + RealElementB, + cutlass::layout::ColumnMajor, + RealElementC, + cutlass::layout::RowMajor, + arch::OpMultiplyAdd>, + cutlass::MatrixShape<1, 1> + >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaComplexTensorOp< + WarpShape_, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + Policy, + TransformA, + TransformB, + true>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for complex*complex case using GaussianComplex operation +// 3 real-valued mma.sync.aligned.m16n8k4.f64.f64.f64.f64 operations +// A = (ar + j ai), B = (br +j bi), D = AB +// P1 = (ar + ai) * br, P2 = - ar * (br - bi), P3 = ai * (br + bi) +// D = dr + j di = (P1 - P3) + j (P1 + P2) +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename WarpShape_, + /// Real-valued underlying type of complex-valued A operand + typename RealElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Real-valued underlying type of complex-valued B operand + typename RealElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Real-valued underlying type of complex-valued C operand + typename RealElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Complex transform on A operand + ComplexTransform TransformA, + /// Complex transform on B operand + ComplexTransform TransformB> +struct DefaultMmaComplexTensorOp< + WarpShape_, + GemmShape<16, 8, 4>, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + TransformA, + TransformB, + arch::OpMultiplyAddGaussianComplex> { + + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma< + GemmShape<16, 8, 4>, + 32, + RealElementA, + cutlass::layout::RowMajor, + RealElementB, + cutlass::layout::ColumnMajor, + RealElementC, + cutlass::layout::RowMajor, + arch::OpMultiplyAdd>, + cutlass::MatrixShape<1, 1> + >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaGaussianComplexTensorOp< + WarpShape_, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + Policy, + TransformA, + TransformB, + true>; +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace warp } // namespace gemm } // namespace cutlass diff --git a/include/cutlass/gemm/warp/mma_complex_tensor_op.h b/include/cutlass/gemm/warp/mma_complex_tensor_op.h index 5054ddaf..2741ab58 100644 --- a/include/cutlass/gemm/warp/mma_complex_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_complex_tensor_op.h @@ -46,6 +46,8 @@ #include "cutlass/arch/memory_sm75.h" #include "cutlass/arch/mma_sm75.h" #include "cutlass/arch/mma_sm80.h" +#include "cutlass/arch/mma_sm90.h" + #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/warp/mma.h" @@ -545,7 +547,7 @@ public: /// Partial specialization for complex*complex+complex => complex: // Operands data type: complex // Rounding: float -> tfloat32_t (round half_ulp_truncate nearest) -// Math instruction: MMA.1688.F32.TF32 +// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 // Output data type: complex // ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -733,7 +735,7 @@ public: using 
MmaOperandC = typename ArchMmaOperator::FragmentC; static_assert(platform::is_same, typename ArchMmaOperator::Shape>::value, - "This implementation only supports MMA.1688 math instructions."); + "This implementation only supports mma.m16n8k8 math instructions."); static_assert(InstMmaOperandA::kElements == 4, "This implementation only supports math instructions in which exactly four element is needed for the A operand." @@ -846,6 +848,312 @@ public: } }; +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for complex*complex+complex => complex: +// Operands data type: complex +// Math instruction: mma.sync.aligned.m16n8k4.f64.f64.f64.f64 +// Output data type: complex +// +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Complex transform on A operand + ComplexTransform TransformA, + /// Complex transform on B operand + ComplexTransform TransformB +> +class MmaComplexTensorOp< + Shape_, + complex, + LayoutA_, + complex, + LayoutB_, + complex, + LayoutC_, + Policy_, + TransformA, + TransformB, + true> { +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of members of complex multiplicand A + using RealElementA = double; + + /// Data type of multiplicand A + using ElementA = complex; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of members of complex multiplicand B + using RealElementB = double; + + /// Data type of multiplicand B + using ElementB = complex; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of members of complex accumulator matrix C + using RealElementC = double; + + /// Data type of accumulator matrix C + using ElementC = complex; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicyTensorOp) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Underlying arch tag + using ArchTag = typename ArchMmaOperator::ArchTag; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Indicates math operator + using MathOperator = typename arch::OpMultiplyAddComplex; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = TransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = TransformB; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + +public: + + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kA, + ElementA, + LayoutA, + MatrixShape, + Policy::OpDelta::kRow, + 32, + 1 + >; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = 
FragmentA; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kB, + ElementB, + LayoutB, + MatrixShape, + Policy::OpDelta::kColumn, + 32, + 1 + >; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = FragmentB; + + static_assert( + !(Shape::kM % ArchMmaOperator::Shape::kM) && + !(Shape::kN % ArchMmaOperator::Shape::kN), + "Shape of warp-level Mma must be divisible by operator shape."); + + /// Number of mma operations performed + using MmaIterations = MatrixShape< + Shape::kM / ArchMmaOperator::Shape::kM, + Shape::kN / ArchMmaOperator::Shape::kN + >; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator< + MatrixShape, + ElementC, + LayoutC, + typename ArchMmaOperator::Shape, + typename Policy::OpDelta>; + + /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this + /// storage arrangement is to be considered 'planar complex' in the sense that all real-valued + /// parts are stored consecutively followed by all imaginary parts. This matches the structure + /// of Tensor Cores which are always real-valued matrix multiplies. + using FragmentC = typename IteratorC::Fragment; + + static_assert( + FragmentC::kElements == 2 * MmaIterations::kCount * ArchMmaOperator::FragmentC::kElements, + "Unexpected planar complex fragment length."); + +private: + + // + // Data members + // + + /// Underlying real-valued matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaComplexTensorOp() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + FragmentA const &A, + FragmentB const &B, + FragmentC const &C + ) const { + + // Alias types for underlying real-valued matrix multiply operator + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + D = C; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + // mma(accum.real(), a.real(), b.real(), accum.real()); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + // Pack operands together. This may result in actual MOVs + MmaOperandA operand_A; + MmaOperandB operand_B; + + CUTLASS_PRAGMA_UNROLL + for (int mk = 0; mk < MmaOperandA::kElements; ++mk) + operand_A[mk] = A[m*MmaOperandA::kElements + mk].real(); + + CUTLASS_PRAGMA_UNROLL + for (int nk = 0; nk < MmaOperandB::kElements; ++nk) + operand_B[nk] = B[n*MmaOperandB::kElements + nk].real(); + + // Real-valued accumulator part + MmaOperandC *accum = reinterpret_cast(&D) + + (m + n * MmaIterations::kRow); + + mma(*accum, operand_A, operand_B, *accum); + } + + // mma(accum.imag(), a.real(), b.imag(), accum.imag()); + CUTLASS_PRAGMA_UNROLL + for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { + + // Pack operands together. This may result in actual MOVs + MmaOperandA operand_A; + MmaOperandB operand_B; + + CUTLASS_PRAGMA_UNROLL + for (int mk = 0; mk < MmaOperandA::kElements; ++mk) + operand_A[mk] = A[m*MmaOperandA::kElements + mk].real(); + + CUTLASS_PRAGMA_UNROLL + for (int nk = 0; nk < MmaOperandB::kElements; ++nk) + operand_B[nk] = (kTransformB == ComplexTransform::kConjugate ? 
+ -B[n*MmaOperandB::kElements + nk].imag() : B[n*MmaOperandB::kElements + nk].imag()); + + // Complex-valued accumulator part + MmaOperandC *accum = reinterpret_cast(&D) + + (m + n * MmaIterations::kRow) + MmaIterations::kCount; + + mma(*accum, operand_A, operand_B, *accum); + } + + // mma(accum.real(), -a.imag(), b.imag(), accum.real()) + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + // Pack operands together. This may result in actual MOVs + MmaOperandA operand_A; + MmaOperandB operand_B; + + // A imaginary part is intentionally negated + CUTLASS_PRAGMA_UNROLL + for (int mk = 0; mk < MmaOperandA::kElements; ++mk) + operand_A[mk] = (kTransformA == ComplexTransform::kConjugate ? + A[m*MmaOperandA::kElements + mk].imag() : -A[m*MmaOperandA::kElements + mk].imag()); + + CUTLASS_PRAGMA_UNROLL + for (int nk = 0; nk < MmaOperandB::kElements; ++nk) + operand_B[nk] = (kTransformB == ComplexTransform::kConjugate ? + -B[n*MmaOperandB::kElements + nk].imag() : B[n*MmaOperandB::kElements + nk].imag()); + + // Real-valued accumulator part + MmaOperandC *accum = reinterpret_cast(&D) + + (m + n * MmaIterations::kRow); + + mma(*accum, operand_A, operand_B, *accum); + } + + // mma(accum.imag(), a.imag(), b.real(), accum.imag()) + CUTLASS_PRAGMA_UNROLL + for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { + + // Pack operands together. This may result in actual MOVs + MmaOperandA operand_A; + MmaOperandB operand_B; + + CUTLASS_PRAGMA_UNROLL + for (int mk = 0; mk < MmaOperandA::kElements; ++mk) + operand_A[mk] = (kTransformA == ComplexTransform::kConjugate ? + -A[m*MmaOperandA::kElements + mk].imag() : A[m*MmaOperandA::kElements + mk].imag()); + + CUTLASS_PRAGMA_UNROLL + for (int nk = 0; nk < MmaOperandB::kElements; ++nk) + operand_B[nk] = B[n*MmaOperandB::kElements + nk].real(); + + // Complex-valued accumulator part + MmaOperandC *accum = reinterpret_cast(&D) + + (m + n * MmaIterations::kRow) + MmaIterations::kCount; + + mma(*accum, operand_A, operand_B, *accum); + } + } + } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const { + dst_A = A; + dst_B = B; + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// // TODO - partial specializations of real*complex and complex*real diff --git a/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h b/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h index 99ae1964..20b78782 100644 --- a/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h +++ b/include/cutlass/gemm/warp/mma_complex_tensor_op_fast_f32.h @@ -301,7 +301,7 @@ class MmaComplexTensorOpFastF32; /// Partial specialization for complex*complex+complex => complex: // Operands data type: complex // Rounding: float -> tfloat32_t (round half_ulp_truncate nearest) -// Math instruction: MMA.1688.F32.TF32 +// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 // Output data type: complex // ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -497,7 +497,7 @@ public: using MmaOperandC = typename ArchMmaOperator::FragmentC; static_assert(platform::is_same, typename ArchMmaOperator::Shape>::value, - "This implementation only supports MMA.1688 math instructions."); + "This implementation only supports mma.m16n8k8 math instructions."); static_assert(InstMmaOperandA::kElements == 4, "This implementation only 
supports math instructions in which exactly four element is needed for the A operand." diff --git a/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h b/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h index 4638fecc..7be564b1 100644 --- a/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h @@ -84,6 +84,8 @@ template < ComplexTransform TransformA = ComplexTransform::kNone, /// Complex transform on B operand ComplexTransform TransformB = ComplexTransform::kNone, + /// Do source operands need more than one elements + bool GeneralizedOperatorElements = false, /// Used for partial specialization typename Enable = bool > @@ -112,9 +114,7 @@ template < /// Complex transform on A operand ComplexTransform TransformA, /// Complex transform on B operand - ComplexTransform TransformB, - /// Used for partial specialization - typename Enable + ComplexTransform TransformB > class MmaGaussianComplexTensorOp< Shape_, @@ -126,8 +126,7 @@ class MmaGaussianComplexTensorOp< LayoutC_, Policy_, TransformA, - TransformB, - Enable> { + TransformB> { public: /// Shape of warp-level matrix operation (concept: GemmShape) using Shape = Shape_; @@ -359,6 +358,282 @@ public: ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for complex*complex+complex => complex using real-valued TensorOps +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename RealElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename RealElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename RealElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Complex transform on A operand + ComplexTransform TransformA, + /// Complex transform on B operand + ComplexTransform TransformB +> +class MmaGaussianComplexTensorOp< + Shape_, + complex, + LayoutA_, + complex, + LayoutB_, + complex, + LayoutC_, + Policy_, + TransformA, + TransformB, + true> { +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = complex; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = complex; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = complex; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Underlying arch tag + using ArchTag = typename ArchMmaOperator::ArchTag; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Indicates math operator + using MathOperator = arch::OpMultiplyAddGaussianComplex; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = TransformA; + + /// Complex transform on B operand + static ComplexTransform const 
kTransformB = TransformB; + + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + +public: + + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kA, + ElementA, + LayoutA, + MatrixShape, + Policy::OpDelta::kRow, + 32, + 1 + >; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = FragmentA; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kB, + ElementB, + LayoutB, + MatrixShape, + Policy::OpDelta::kColumn, + 32, + 1 + >; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = FragmentB; + + static_assert( + !(Shape::kM % ArchMmaOperator::Shape::kM) && + !(Shape::kN % ArchMmaOperator::Shape::kN), + "Shape of warp-level Mma must be divisible by operator shape."); + + /// Number of mma operations performed + using MmaIterations = MatrixShape< + Shape::kM / ArchMmaOperator::Shape::kM, + Shape::kN / ArchMmaOperator::Shape::kN + >; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpGaussianComplexAccumulatorTileIterator< + MatrixShape, + ElementC, + LayoutC, + typename ArchMmaOperator::Shape, + typename Policy::OpDelta>; + + /// Storage for C tile, the accumulator. Note, regardless of multiplicand type, this + /// storage arrangement is to be considered 'gaussian complex' in the sense that the accumulation is + /// done in three parts namely part1, part2, and part3. The parts 1, 2, and 3 are stored consecutively + /// in InteratorC::Frament. This matches the structure of Tensor Cores which are always real-valued matrix multiplies. + using FragmentC = typename IteratorC::Fragment; + + static_assert( + FragmentC::kElements == 3 * MmaIterations::kCount * ArchMmaOperator::FragmentC::kElements, + "Unexpected gaussian complex fragment length."); + +private: + + // + // Data members + // + + /// Underlying real-valued matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaGaussianComplexTensorOp() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + FragmentA const &A, + FragmentB const &B, + FragmentC const &C + ) const { + + // Alias types for underlying real-valued matrix multiply operator + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + D = C; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + // mma(accum.part1(), (a.real() + a.imag()), b.real(), accum.part1()); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + // Pack operands together. This may result in actual MOVs + MmaOperandA operand_Asum; + MmaOperandB operand_Br; + + CUTLASS_PRAGMA_UNROLL + for (int mk = 0; mk < MmaOperandA::kElements; ++mk) + operand_Asum[mk] = A[m*MmaOperandA::kElements + mk].real() + ((kTransformA == ComplexTransform::kConjugate) ? 
+ -A[m*MmaOperandA::kElements + mk].imag() : +A[m*MmaOperandA::kElements + mk].imag()); + + CUTLASS_PRAGMA_UNROLL + for (int nk = 0; nk < MmaOperandB::kElements; ++nk) + operand_Br[nk] = B[n*MmaOperandB::kElements + nk].real(); + + // accumulator part1 + MmaOperandC *accum = reinterpret_cast(&D) + + (m + n * MmaIterations::kRow); + + mma(*accum, operand_Asum, operand_Br, *accum); + } + + // mma(accum.part2(), -a.real(), (b.real() - b.imag()), accum.part2()); + CUTLASS_PRAGMA_UNROLL + for (int n = MmaIterations::kColumn - 1; n >= 0; --n) { + + // Pack operands together. This may result in actual MOVs + MmaOperandA operand_Ar; + MmaOperandB operand_Bdiff; + + CUTLASS_PRAGMA_UNROLL + for (int mk = 0; mk < MmaOperandA::kElements; ++mk) + operand_Ar[mk] = -A[m*MmaOperandA::kElements + mk].real(); + + CUTLASS_PRAGMA_UNROLL + for (int nk = 0; nk < MmaOperandB::kElements; ++nk) + operand_Bdiff[nk] = B[n*MmaOperandB::kElements + nk].real() - ((kTransformB == ComplexTransform::kConjugate) ? + -B[n*MmaOperandB::kElements + nk].imag() : +B[n*MmaOperandB::kElements + nk].imag()); + + // accumulator part2 + MmaOperandC *accum = reinterpret_cast(&D) + + (m + n * MmaIterations::kRow) + MmaIterations::kCount; + + mma(*accum, operand_Ar, operand_Bdiff, *accum); + } + + // mma(accum.part3(), a.imag(), (b.real() + b.imag()), accum.part3()) + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + // Pack operands together. This may result in actual MOVs + MmaOperandA operand_Ai; + MmaOperandB operand_Bsum; + + CUTLASS_PRAGMA_UNROLL + for (int mk = 0; mk < MmaOperandA::kElements; ++mk) + operand_Ai[mk] = (kTransformA == ComplexTransform::kConjugate) ? + -A[m*MmaOperandA::kElements + mk].imag() : +A[m*MmaOperandA::kElements + mk].imag(); + + CUTLASS_PRAGMA_UNROLL + for (int nk = 0; nk < MmaOperandB::kElements; ++nk) + operand_Bsum[nk] = B[n*MmaOperandB::kElements + nk].real() + ((kTransformB == ComplexTransform::kConjugate) ? + -B[n*MmaOperandB::kElements + nk].imag() : +B[n*MmaOperandB::kElements + nk].imag()); + + // accumulator part3 + MmaOperandC *accum = reinterpret_cast(&D) + + (m + n * MmaIterations::kRow) + 2 * MmaIterations::kCount; + + mma(*accum, operand_Ai, operand_Bsum, *accum); + } + } + } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const { + dst_A = A; + dst_B = B; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace warp diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h index c26d4781..0b7ce5b3 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h @@ -493,7 +493,7 @@ public: Index elements_offset = layout_({WmmaShape::kRow, 0}); - byte_offset_ -= (elements_offset + sizeof_bits::value) / 8; + byte_offset_ -= (elements_offset * sizeof_bits::value) / 8; return *this; } diff --git a/include/cutlass/half.h b/include/cutlass/half.h index 13d7146f..d1d70232 100644 --- a/include/cutlass/half.h +++ b/include/cutlass/half.h @@ -40,25 +40,8 @@ #endif #if defined(__CUDACC_RTC__) -/* All floating-point numbers can be put in one of these categories. 
*/ -enum - { - FP_NAN = -# define FP_NAN 0 - FP_NAN, - FP_INFINITE = -# define FP_INFINITE 1 - FP_INFINITE, - FP_ZERO = -# define FP_ZERO 2 - FP_ZERO, - FP_SUBNORMAL = -# define FP_SUBNORMAL 3 - FP_SUBNORMAL, - FP_NORMAL = -# define FP_NORMAL 4 - FP_NORMAL - }; + +#include "cutlass/floating_point_nvrtc.h" // F16C extensions are not meaningful when compiling for NVRTC which only accommodates device code. #undef CUTLASS_ENABLE_F16C @@ -79,6 +62,7 @@ enum #include #include "cutlass/cutlass.h" +#include "cutlass/float8.h" #include "cutlass/platform/platform.h" /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -372,8 +356,7 @@ struct alignas(2) half_t { // /// Default constructor - CUTLASS_HOST_DEVICE - half_t() : storage(0) { } + half_t() = default; /// Reinterpret cast from CUDA's half type CUTLASS_HOST_DEVICE @@ -398,6 +381,18 @@ struct alignas(2) half_t { } + /// float_e4m3_t conversion + CUTLASS_HOST_DEVICE + explicit half_t(float_e4m3_t x): half_t(float(x)) { + + } + + /// float_e5m2_t conversion + CUTLASS_HOST_DEVICE + explicit half_t(float_e5m2_t x): half_t(float(x)) { + + } + /// Integer conversion - round to nearest even CUTLASS_HOST_DEVICE explicit half_t(int x) { @@ -618,19 +613,19 @@ struct numeric_limits { /// Returns smallest finite value static cutlass::half_t epsilon() { return cutlass::half_t::bitcast(0x1800); } - /// Returns smallest finite value + /// Returns maximum rounding error static cutlass::half_t round_error() { return cutlass::half_t(0.5f); } - /// Returns smallest finite value + /// Returns positive infinity value static cutlass::half_t infinity() { return cutlass::half_t::bitcast(0x7c00); } - /// Returns smallest finite value + /// Returns quiet NaN value static cutlass::half_t quiet_NaN() { return cutlass::half_t::bitcast(0x7fff); } - /// Returns smallest finite value + /// Returns signaling NaN value static cutlass::half_t signaling_NaN() { return cutlass::half_t::bitcast(0x7fff); } - /// Returns smallest finite value + /// Returns smallest positive subnormal value static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); } }; } // namespace std @@ -680,23 +675,23 @@ struct numeric_limits { CUTLASS_HOST_DEVICE static cutlass::half_t epsilon() { return cutlass::half_t::bitcast(0x1800); } - /// Returns smallest finite value + /// Returns maximum rounding error CUTLASS_HOST_DEVICE static cutlass::half_t round_error() { return cutlass::half_t(0.5f); } - /// Returns smallest finite value + /// Returns positive infinity value CUTLASS_HOST_DEVICE static cutlass::half_t infinity() { return cutlass::half_t::bitcast(0x7c00); } - /// Returns smallest finite value + /// Returns quiet NaN value CUTLASS_HOST_DEVICE static cutlass::half_t quiet_NaN() { return cutlass::half_t::bitcast(0x7fff); } - /// Returns smallest finite value + /// Returns signaling NaN value CUTLASS_HOST_DEVICE static cutlass::half_t signaling_NaN() { return cutlass::half_t::bitcast(0x7fff); } - /// Returns smallest finite value + /// Returns smallest positive subnormal value CUTLASS_HOST_DEVICE static cutlass::half_t denorm_min() { return cutlass::half_t::bitcast(0x0001); } }; diff --git a/include/cutlass/integer_subbyte.h b/include/cutlass/integer_subbyte.h index 981c4cbd..a24389f8 100644 --- a/include/cutlass/integer_subbyte.h +++ b/include/cutlass/integer_subbyte.h @@ -77,8 +77,7 @@ struct integer_subbyte { // /// No operation - CUTLASS_HOST_DEVICE - integer_subbyte() { } + integer_subbyte() = default; /// Conversion from 
integer type CUTLASS_HOST_DEVICE diff --git a/include/cutlass/layout/permute.h b/include/cutlass/layout/permute.h index c7b01530..95fce480 100644 --- a/include/cutlass/layout/permute.h +++ b/include/cutlass/layout/permute.h @@ -72,9 +72,6 @@ private: Index stride_permute_; - Index col_permute_; - Index row_permute_; - public: // // Methods @@ -140,7 +137,7 @@ public: /// Computes the address offset after Permute Op in Bytes CUTLASS_HOST_DEVICE LongIndex operator()(MatrixCoord offset_init) { - // Permute as torch.permute(X1, [0, 2, 1, 3]) -> 4D Tensor indices as [i, j, k, l], the dimension of X + // Permute as torch.permute(X1, [0, 2, 1, 3]) -> 4D Tensor indices as [i,j,k,l], the dimension of X // is [D0, D1, D2, D3], after permutation the dim of X1 is [D0, D2, D1, D3]. assert(extent_.row() % D1 == 0); assert(extent_.column() % D2 == 0); diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index f708f2f6..32cdcb31 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -974,7 +974,7 @@ struct NumericArrayConverter <= Array +/// Partial specialization for Array <= Array template < int N, FloatRoundStyle Round @@ -1271,6 +1271,849 @@ struct NumericArrayConverter { #endif +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for Array <=> Array +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float; + using source_element = float_e4m3_t; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint32_t out_fp16[2]; + uint32_t const& src_packed = reinterpret_cast(source); + + asm volatile( \ + "{\n" \ + ".reg .b16 lo, hi;\n" \ + "mov.b32 {lo, hi}, %2;\n" \ + "cvt.rn.f16x2.e4m3x2 %0, lo;\n" \ + "cvt.rn.f16x2.e4m3x2 %1, hi;\n" \ + "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "r"(src_packed)); + + float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); + float2 res1 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[1])); + + result_type out; + out[0] = res0.x; + out[1] = res0.y; + out[2] = res1.x; + out[3] = res1.y; + return out; + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float_e4m3_t; + using source_element = float; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint32_t out; + + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e4m3x2.f32 lo, %2, %1;\n" \ + "cvt.rn.satfinite.e4m3x2.f32 hi, %4, %3;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); + + return 
reinterpret_cast(out); + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float; + using source_element = float_e5m2_t; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint32_t out_fp16[2]; + uint32_t const& src_packed = reinterpret_cast(source); + + asm volatile( \ + "{\n" \ + ".reg .b16 lo, hi;\n" \ + "mov.b32 {lo, hi}, %2;\n" \ + "cvt.rn.f16x2.e5m2x2 %0, lo;\n" \ + "cvt.rn.f16x2.e5m2x2 %1, hi;\n" \ + "}\n" : "=r"(out_fp16[0]), "=r"(out_fp16[1]) : "r"(src_packed)); + + float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[0])); + float2 res1 = __half22float2(reinterpret_cast<__half2 &>(out_fp16[1])); + + result_type out; + out[0] = res0.x; + out[1] = res0.y; + out[2] = res1.x; + out[3] = res1.y; + return out; + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float_e5m2_t; + using source_element = float; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint32_t out; + + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e5m2x2.f32 lo, %2, %1;\n" \ + "cvt.rn.satfinite.e5m2x2.f32 hi, %4, %3;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); + + return reinterpret_cast(out); + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for Array <=> Array +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = half_t; + using source_element = float_e4m3_t; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint32_t out[2]; + uint32_t const& src_packed = reinterpret_cast(source); + asm volatile( \ + "{\n" \ + ".reg .b16 lo, hi;\n" \ + "mov.b32 {lo, hi}, %2;\n" \ + "cvt.rn.f16x2.e4m3x2 %0, lo;\n" \ + "cvt.rn.f16x2.e4m3x2 %1, 
hi;\n" \ + "}\n" : "=r"(out[0]), "=r"(out[1]) : "r"(src_packed)); + return reinterpret_cast(out); + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float_e4m3_t; + using source_element = half_t; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint32_t out; + uint32_t const* src_packed = reinterpret_cast(&source); + + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e4m3x2.f16x2 lo, %1;\n" \ + "cvt.rn.satfinite.e4m3x2.f16x2 hi, %2;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(out) : "r"(src_packed[0]), "r"(src_packed[1])); + + return reinterpret_cast(out); + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = half_t; + using source_element = float_e5m2_t; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint32_t out[2]; + uint32_t const& src_packed = reinterpret_cast(source); + asm volatile( \ + "{\n" \ + ".reg .b16 lo, hi;\n" \ + "mov.b32 {lo, hi}, %2;\n" \ + "cvt.rn.f16x2.e5m2x2 %0, lo;\n" \ + "cvt.rn.f16x2.e5m2x2 %1, hi;\n" \ + "}\n" : "=r"(out[0]), "=r"(out[1]) : "r"(src_packed)); + return reinterpret_cast(out); + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float_e5m2_t; + using source_element = half_t; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint32_t out; + uint32_t const* src_packed = reinterpret_cast(&source); + + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e5m2x2.f16x2 lo, %1;\n" \ + "cvt.rn.satfinite.e5m2x2.f16x2 hi, %2;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(out) : "r"(src_packed[0]), "r"(src_packed[1])); + + return reinterpret_cast(out); + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type 
operator()(source_type const &s) { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for Array <=> Array +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = bfloat16_t; + using source_element = float_e4m3_t; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + // Convert f8 to float + NumericArrayConverter src2float; + Array tmp_floats = src2float(source); + + // Convert float to bf16 + result_type out; + Array* packed_tmp = reinterpret_cast*>(&tmp_floats); + Array* packed_out = reinterpret_cast*>(&out); + NumericArrayConverter float2result; + packed_out[0] = float2result(packed_tmp[0]); + packed_out[1] = float2result(packed_tmp[1]); + + return out; + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float_e4m3_t; + using source_element = bfloat16_t; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + // Convert bf16 to float + Array tmp; + Array* packed_tmp = reinterpret_cast*>(&tmp); + Array const* packed_source = reinterpret_cast const*>(&source); + NumericArrayConverter src2float; + packed_tmp[0] = src2float(packed_source[0]); + packed_tmp[1] = src2float(packed_source[1]); + + // Convert float to f8 + NumericArrayConverter float2result; + return float2result(tmp); + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = bfloat16_t; + using source_element = float_e5m2_t; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + // Convert f8 to float + NumericArrayConverter src2float; + Array tmp_floats = src2float(source); + + // Convert float to bf16 + result_type out; + Array* packed_tmp = reinterpret_cast*>(&tmp_floats); + Array* packed_out = reinterpret_cast*>(&out); + NumericArrayConverter float2result; + packed_out[0] = float2result(packed_tmp[0]); + packed_out[1] = float2result(packed_tmp[1]); + + return out; + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + 
+ CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float_e5m2_t; + using source_element = bfloat16_t; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + // Convert bf16 to float + Array tmp; + Array* packed_tmp = reinterpret_cast*>(&tmp); + Array const* packed_source = reinterpret_cast const*>(&source); + NumericArrayConverter src2float; + packed_tmp[0] = src2float(packed_source[0]); + packed_tmp[1] = src2float(packed_source[1]); + + // Convert float to f8 + NumericArrayConverter float2result; + return float2result(tmp); + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for Array <=> Array +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float_e4m3_t; + using source_element = float_e5m2_t; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float_e5m2_t; + using source_element = float_e4m3_t; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + result[i] = converter(source[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for: +// Array <=> Array +// Array <=> Array +// +// These are needed to avoid multiple-matching-template compilation errors (e.g., when +// compiling float_e4m3_t <=> float_e4m3_t, which among T <= float_e4m3_t and float_e4m3_t <= T +// should be used?) 
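// Concretely: without the identity specializations that follow, a converter
// from a 4-element float_e4m3_t array to a 4-element float_e4m3_t array would
// be matched equally well by the generic "T <= float_e4m3_t" and
// "float_e4m3_t <= S" packed specializations defined further down in this
// file, and overload resolution would be ambiguous.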
+// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float_e4m3_t; + using source_element = float_e4m3_t; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return s; + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float_e5m2_t; + using source_element = float_e5m2_t; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return s; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specialziations for: +// Array <=> Array +// Array <=> Array +// using packed converter under the hood +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename S, + int N, + FloatRoundStyle Round +> +struct PackedNumericArrayConverter { + using result_element = T; + using source_element = S; + + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using packed_result_type = Array; + using packed_source_type = Array; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + result_type result; + packed_result_type* packed_result = reinterpret_cast(&result); + const packed_source_type* packed_source = reinterpret_cast(&source); + + NumericArrayConverter packed_converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 4; ++i) { + packed_result[i] = packed_converter(packed_source[i]); + } + + // Handle leftovers + NumericConverter converter; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N % 4; ++i) { + int idx = ((N / 4) * 4) + i; + result[idx] = converter(source[idx]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + typename T, + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + typename T, + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + typename S, + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + typename S, + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public 
PackedNumericArrayConverter {}; + +/// Partial specialization for Array <= Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter : + public PackedNumericArrayConverter {}; + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Partial specialization for Array <= Array diff --git a/include/cutlass/numeric_types.h b/include/cutlass/numeric_types.h index c8ec08b7..04466ca7 100644 --- a/include/cutlass/numeric_types.h +++ b/include/cutlass/numeric_types.h @@ -88,6 +88,7 @@ using make_index_sequence = typename index_sequence_helper::type; #include "cutlass/half.h" #include "cutlass/bfloat16.h" #include "cutlass/tfloat32.h" +#include "cutlass/float8.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index ced1bef2..b2ec5370 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -574,7 +574,6 @@ using std::is_trivially_copyable; #endif - //----------------------------------------------------------------------------- // bit_cast //----------------------------------------------------------------------------- diff --git a/include/cutlass/quaternion.h b/include/cutlass/quaternion.h index c9af74ca..0f8a501a 100644 --- a/include/cutlass/quaternion.h +++ b/include/cutlass/quaternion.h @@ -35,6 +35,7 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/functional.h" #include "cutlass/array.h" #include "cutlass/real.h" #include "cutlass/coord.h" @@ -651,7 +652,10 @@ CUTLASS_HOST_DEVICE } }; -////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Factories +//////////////////////////////////////////////////////////////////////////////////////////////////// template <> CUTLASS_HOST_DEVICE @@ -671,6 +675,77 @@ cutlass::Quaternion from_real >(double r) { return cutlass::Quaternion(r); } +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// functional.h numeric specializations +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct multiplies> { + CUTLASS_HOST_DEVICE + Quaternion operator()(Quaternion lhs, Quaternion const &rhs) const { + lhs = lhs * rhs; + return lhs; + } +}; + +/// Squares with optional conversion +template +struct magnitude_squared, Output> { + CUTLASS_HOST_DEVICE + Output operator()(Quaternion lhs) const { + multiplies mul_op; + + Output y_w = Output(lhs.w()); + Output y_x = Output(lhs.x()); + Output y_y = Output(lhs.y()); + Output y_z = Output(lhs.z()); + + return mul_op(y_w, y_w) + mul_op(y_x, y_x) + mul_op(y_y, y_y) + \ + mul_op(y_z, y_z); + } +}; + +template +struct multiply_add, Quaternion, Quaternion> { + CUTLASS_HOST_DEVICE + Quaternion operator()( + Quaternion const &a, + Quaternion const &b, + Quaternion const &c) const { + + T x = c.x(); + T y = c.y(); + T z = c.z(); + T w = c.w(); + + x += a.w() * b.x(); + x += b.w() * a.x(); + x += a.y() * b.z(); + x += -a.z() * b.y(), + + y += a.w() * b.y(); + y += b.w() * a.y(); + y += a.z() * b.x(); + y += -a.x() * b.z(); + + z += a.w() * b.z(); + z += b.w() * a.z(); + z += a.x() * b.y(); + z += -a.y() * b.x(); + + w += 
a.w() * b.w(); + w += -a.x() * b.x(); + w += -a.y() * b.y(); + w += -a.z() * b.z(); + + return cutlass::make_Quaternion(x, y, z, w); + } +}; + + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/include/cutlass/semaphore.h b/include/cutlass/semaphore.h index 48f5b01a..7ba75ba8 100644 --- a/include/cutlass/semaphore.h +++ b/include/cutlass/semaphore.h @@ -36,7 +36,6 @@ #include "cutlass/cutlass.h" -#include "cutlass/aligned_buffer.h" #include "cutlass/array.h" #include "cutlass/numeric_types.h" diff --git a/include/cutlass/tfloat32.h b/include/cutlass/tfloat32.h index b522b600..684236fb 100644 --- a/include/cutlass/tfloat32.h +++ b/include/cutlass/tfloat32.h @@ -34,7 +34,9 @@ */ #pragma once -#if !defined(__CUDACC_RTC__) +#if defined(__CUDACC_RTC__) +#include "cutlass/floating_point_nvrtc.h" +#else #include #include #include @@ -87,22 +89,23 @@ struct alignas(4) tfloat32_t { } /// Default constructor - CUTLASS_HOST_DEVICE - tfloat32_t() : storage(0) { } + tfloat32_t() = default; /// Floating-point conversion - round toward nearest even CUTLASS_HOST_DEVICE - explicit tfloat32_t(float x): storage(round_half_ulp_truncate(x).storage) { } +// explicit tfloat32_t(float x): storage(round_half_ulp_truncate(x).storage) { } + tfloat32_t(float x): storage(round_half_ulp_truncate(x).storage) { } /// Floating-point conversion - round toward nearest even CUTLASS_HOST_DEVICE - explicit tfloat32_t(double x): tfloat32_t(float(x)) { - +// explicit tfloat32_t(double x): tfloat32_t(float(x)) { + tfloat32_t(double x): tfloat32_t(float(x)) { } /// Integer conversion - round toward zero CUTLASS_HOST_DEVICE - explicit tfloat32_t(int x) { +// explicit tfloat32_t(int x) { + tfloat32_t(int x) { float flt = static_cast(x); #if defined(__CUDA_ARCH__) storage = reinterpret_cast(flt); diff --git a/include/cutlass/transform/pitch_linear_thread_map.h b/include/cutlass/transform/pitch_linear_thread_map.h index 803df22a..800f633b 100644 --- a/include/cutlass/transform/pitch_linear_thread_map.h +++ b/include/cutlass/transform/pitch_linear_thread_map.h @@ -88,20 +88,15 @@ struct PitchLinearStripminedThreadMap { static_assert(!(Shape::kContiguous % kElementsPerAccess), ""); - static_assert(!((Shape::kContiguous * Shape::kStrided) % (kThreads * kElementsPerAccess)), - "Shape must be divisible thread count."); - /// Shape of the tile in units of vectors using ShapeVec = layout::PitchLinearShape< Shape::kContiguous / kElementsPerAccess, Shape::kStrided >; - static_assert( - (Threads < ShapeVec::kContiguous && !(ShapeVec::kContiguous % kThreads)) || - (!(kThreads % ShapeVec::kContiguous) && !(ShapeVec::kStrided % (kThreads / ShapeVec::kContiguous))), - "Shape must be divisible by number of iterations of each thread." - ); + static_assert((Threads < ShapeVec::kContiguous && !(ShapeVec::kContiguous % kThreads)) || + (!(kThreads % ShapeVec::kContiguous)), + "Shape must be divisible by number of iterations of each thread."); }; /// Number of iterations by each thread @@ -112,11 +107,12 @@ struct PitchLinearStripminedThreadMap { // Redo the comparison here to work around divide by zero compiler // error. The compiler evaluates both path of platform::conditional. (Threads >= Detail::ShapeVec::kContiguous - ? Detail::ShapeVec::kStrided / + ? 
(Detail::ShapeVec::kStrided + (kThreads / Detail::ShapeVec::kContiguous - 1)) / (kThreads / Detail::ShapeVec::kContiguous) : 0)>, layout::PitchLinearShape>::type; + /// Interval between accesses along each dimension of the tensor's logical coordinate space /// (in units of Elements) @@ -132,6 +128,13 @@ struct PitchLinearStripminedThreadMap { > >::type; + /// Shape of the tile in units of vectors + using StorageShape = typename platform::conditional< + Threads >= Detail::ShapeVec::kContiguous, + layout::PitchLinearShape, + layout::PitchLinearShape>::type; + /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space /// (in units of Elements) CUTLASS_HOST_DEVICE diff --git a/include/cutlass/transform/threadblock/ell_iterator.h b/include/cutlass/transform/threadblock/ell_iterator.h new file mode 100644 index 00000000..3557cd86 --- /dev/null +++ b/include/cutlass/transform/threadblock/ell_iterator.h @@ -0,0 +1,199 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Ell iterator for matrix of indices (ellColInd matrix) +*/ + +#pragma once + +namespace cutlass { +namespace transform { +namespace threadblock { + +namespace ell{ + +constexpr unsigned int SmemPow = 8; +constexpr unsigned int SmemStages = 2; +constexpr unsigned int SmemSize = 1 << SmemPow; +constexpr unsigned int SmemMask = (SmemSize*SmemStages-1); + +class SharedStorage{ + public: + Array array; +}; + +class Iterator{ + public: + using Layout = layout::PitchLinear; + using LongIndex = typename Layout::LongIndex; + + private: + const int *gmem_col_idx_; + int *smem_col_idx_; + const int block_size_; + const int base_idx_; + const int k_shape_; + const int ell_increment_; + const int array_length_; + int col_idx_base_; + int residue_; + int counter_; + + int pow2_; + int residue_shape_; + + int smem_offset_; + int smem_stage_; + int gmem_offset_; + + int lane_; + + bool is_pow2_; + bool is_residue_tile_; + + public: + CUTLASS_DEVICE + void load_ell_indices(){ + for(int i=threadIdx.x; i= 0) ? gmem_col_idx : -1; + } + gmem_offset_ += SmemSize; + smem_stage_ ^= 1; + } + + CUTLASS_DEVICE + Iterator( + SharedStorage& shared_storage_base, + const int* col_idx, + const int& block_size, + const int& base_idx, + const int k_shape, + const int& problem_size_k, + const int& ell_stride, + const int& thread_idx) + : residue_(0), + counter_(0), + smem_offset_(0), + smem_stage_(0), + gmem_offset_(0), + block_size_(block_size), + base_idx_(base_idx), + k_shape_(k_shape), + ell_increment_(ell_stride * block_size), + array_length_((problem_size_k + block_size_ - 1) / block_size_), + residue_shape_(problem_size_k % k_shape_), + is_residue_tile_(residue_shape_ != 0), + smem_col_idx_(reinterpret_cast(&shared_storage_base.array)), + gmem_col_idx_(const_cast(col_idx)), + lane_(thread_idx % 32) { + + load_ell_indices(); + __syncthreads(); + + is_pow2_ = ((block_size_ & (block_size_ - 1)) == 0); + if( is_pow2_ && k_shape <= block_size_ ) lane_ = 0; + + col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_; + + pow2_ = 0; + while(block_size_ >> (pow2_ + 1)) ++pow2_; + } + + CUTLASS_DEVICE + int get_blocksize(){ + return block_size_; + } + + CUTLASS_DEVICE + Iterator &operator++(){ + if(is_residue_tile_){ + residue_ += residue_shape_; + is_residue_tile_ = false; + } else { + residue_ += k_shape_; + } + + if(residue_ < block_size_){ + return *this; + } + + if((array_length_ > SmemSize) && (((smem_offset_ >> SmemPow) & 1) != smem_stage_)) + load_ell_indices(); + + if(residue_ == block_size_){ + ++smem_offset_; + counter_ += ell_increment_; + residue_ = 0; + col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_ - counter_; + return *this; + } + + if(is_pow2_){ + smem_offset_ += residue_ >> pow2_; + counter_ += (residue_ >> pow2_) * ell_increment_; + residue_ = residue_ & ((1 << pow2_) - 1); + } + else { + smem_offset_ += residue_ / block_size_; + counter_ += (residue_ / block_size_) * ell_increment_; + residue_ %= block_size_; + } + + col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_ - counter_; + + return *this; + } + + CUTLASS_DEVICE + LongIndex get_offset(const int& idx) { + int num_jump_tiles; + if(is_pow2_) + num_jump_tiles = (idx + residue_) >> pow2_; + else + num_jump_tiles = (idx + residue_) / block_size_; + + int tmp = __shfl_sync(0xffffffff, col_idx_base_, num_jump_tiles); + return tmp - num_jump_tiles * ell_increment_; + } + + CUTLASS_DEVICE + LongIndex get_offset_fast() { + return col_idx_base_; + } +}; + 
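The Iterator class above walks the Blocked-Ell column-index array (ellColInd) for one row of blocks, staging SmemSize indices at a time through a two-stage shared-memory ring and turning each stored block index into an offset into the dense operand. As a rough illustration of the indexing scheme only, here is a simplified host-side sketch that ignores the shared-memory staging, residue tiles, and power-of-two fast paths; the function and parameter names are hypothetical and not part of this patch.

#include <cstdint>
#include <vector>

// Simplified model of Blocked-Ell column lookup: ell_col_ind lists, for one
// row of blocks, the dense column-block index of each stored block (negative
// entries denote padding). Element column k of the compressed row maps to a
// dense column, or -1 when it falls inside a padded block.
inline int64_t blocked_ell_dense_column(const std::vector<int>& ell_col_ind,
                                        int block_size, int k) {
  int stored_block = k / block_size;            // which stored block holds k
  int col_block = ell_col_ind[stored_block];    // its dense column-block index
  if (col_block < 0) {
    return -1;                                  // padded block contributes nothing
  }
  return int64_t(col_block) * block_size + (k % block_size);
}

The device iterator performs the same block-index lookup, but returns offsets relative to its running position (see get_offset / get_offset_fast) and broadcasts the cached index across the warp with __shfl_sync.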
+} +} +} +} diff --git a/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h b/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h new file mode 100644 index 00000000..11bf4ded --- /dev/null +++ b/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h @@ -0,0 +1,1350 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Ell iterator for Blocked-Ell matrix (ellValue matrix) used with EllMmaMultistage +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// EllPredicatedTileAccessIterator +/// +template +class EllPredicatedTileAccessIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data. 
+/// +template +class EllPredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static int const kPredicatesPerByte = 4; + static int const kPredicatesPerWord = 4 * kPredicatesPerByte; + + static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector; + + /// Number of 32b words containing predicates + static int const kPredicateByteCount = + (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte; + static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; + + static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; + + static_assert(kPredicateWordCount <= 4, "Too many predicates."); + + /// Predicate vector stores mask to guard accesses + using Mask = Array; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend EllPredicatedTileAccessIterator; + + private: + /// stride of pitch-linear layout (units of Element) + LongIndex stride_; + /// amount (in byte) to increment pointer to move to next access along + /// strided dimension + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + + // Default ctor + CUTLASS_HOST_DEVICE + Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : stride_(layout.stride(0)) { + inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = + Shape::kStrided * LongIndex(stride_) * sizeof_bits::value / 8; + } else { + // advance along contiguous dimension + inc_advance_ = Shape::kContiguous * sizeof_bits::value / 8; + } + + inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * + ThreadMap::Delta::kStrided * LongIndex(stride_) * + sizeof_bits::value / 8; + }; + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const ¶ms_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Guard predicates + uint32_t predicates_[kPredicateWordCount]; + + /// Size of tensor + TensorCoord extent_; + + /// Initial offset for each thread + TensorCoord 
thread_offset_; + + /// Offset to the first steady-state tile + TensorCoord residue_offset_; + + /// Initial offset to define ELL block + TensorCoord ell_offset_; + + /// Used for out-of-order visitation + bool is_residue_tile_; + + /// Iteration along vectors implied by the thread map + int iteration_vector_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0u; + } + + CUTLASS_PRAGMA_UNROLL + for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { + + int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int c = access_residual / kAccessesPerVector; + int v = access_residual % kAccessesPerVector; + + TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, + s * ThreadMap::Delta::kStrided); + + TensorCoord coord = thread_offset_ + iteration_coord; + + bool guard; + + if (is_steady_state) { + if (kAdvanceRank == 0) { + guard = (coord.strided() < extent.strided()); + } else { + guard = (coord.contiguous() < extent.contiguous()); + } + } else { + guard = (coord.strided() < extent.strided() && + coord.contiguous() < extent.contiguous()); + } + + int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); + + } + + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + extent_(extent), + is_residue_tile_(true) { + + TensorCoord residue_extent; + if (kAdvanceRank) { + + typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided; + if (!residue_size) { + residue_size = Shape::kStrided; + } + + residue_offset_ = make_Coord(0, residue_size); + residue_extent = make_Coord( + extent_.contiguous(), + min(threadblock_offset.strided() + residue_size, extent_.strided()) + ); + } else { + + typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous; + if (!residue_size) { + residue_size = Shape::kContiguous; + } + + residue_offset_ = make_Coord(residue_size, 0); + + residue_extent = make_Coord( + min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size), + extent_.strided() + ); + } + + // Per-thread offset in logical 
coordinates of tensor + ell_offset_ = ThreadMap::initial_offset(thread_id); + thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(thread_offset_)); + + compute_predicates_(residue_extent, false); + + set_iteration_index(0); + } + + /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + if (is_residue_tile_) { + + thread_offset_ += residue_offset_; + + Layout layout(params_.stride_); + add_pointer_offset(layout(residue_offset_)); + + compute_predicates_(extent_, true); + + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } + is_residue_tile_ = false; + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast( + pointer_ + + iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + iteration_vector_; + } + + /// Returns a k_location + CUTLASS_HOST_DEVICE + int get_k() const { + if(kAdvanceRank){ //strided + return ell_offset_.strided() + iteration_strided_ * ThreadMap::Delta::kStrided; + }else{ + return ell_offset_.contiguous() + iteration_contiguous_ * ThreadMap::Delta::kContiguous + iteration_vector_ * AccessType::kElements; + } + } + + CUTLASS_HOST_DEVICE + int get_stride() const { + if(kAdvanceRank) + return params_.stride_; + else + return 1; + } + + /// Increment and return an instance to self. 
+ CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator &operator++() { + + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + + iteration_vector_ = 0; + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator operator++(int) { + EllPredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = enable ? 0u : predicates_[i]; + } + + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0xffffffff; + } + + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = mask[i]; + } + + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + mask[i] = predicates_[i]; + } + } + + /// add mask for small tiles in ELL + CUTLASS_DEVICE + void ell_add_mask(int blocksize) { + + Mask mask; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + mask[i] = 0u; + } + + CUTLASS_PRAGMA_UNROLL + for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { + + int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int c = access_residual / kAccessesPerVector; + int v = access_residual % kAccessesPerVector; + + TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, + s * ThreadMap::Delta::kStrided); + + TensorCoord coord = ell_offset_ + iteration_coord; + + bool guard; + + if (kAdvanceRank == 0) { + guard = (coord.strided() < blocksize); + } else { + guard = (coord.contiguous() < blocksize); + } + + int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + mask[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); + + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + mask[i] &= predicates_[i]; + } + set_mask(mask); + } + + /// 
Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + + int pred_idx = + iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; + return pred; + + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class EllPredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileAccessIterator< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend EllPredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each 
participating thread + ) + : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + CUTLASS_HOST_DEVICE + int get_k() const { + return iterator_.get_k(); + } + + CUTLASS_HOST_DEVICE + int get_stride() const { + return iterator_.get_stride(); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator operator++(int) { + EllPredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// add mask for small tiles in ELL + CUTLASS_DEVICE + void ell_add_mask(int blocksize) { + iterator_.ell_add_mask(blocksize); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class EllPredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileAccessIterator< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend EllPredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), 
tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + CUTLASS_HOST_DEVICE + int get_k() const { + return iterator_.get_k(); + } + + CUTLASS_HOST_DEVICE + int get_stride() const { + return iterator_.get_stride(); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator operator++(int) { + EllPredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// add mask for small tiles in ELL + CUTLASS_DEVICE + void ell_add_mask(int blocksize) { + iterator_.ell_add_mask(blocksize); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileAccessIterator for column-major interleaved data. +/// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template +class EllPredicatedTileAccessIterator, + AdvanceRank, ThreadMap_, AccessType_> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileAccessIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 
0 : 1), ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend EllPredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + CUTLASS_HOST_DEVICE + int get_k() const { + return iterator_.get_k(); + } + + CUTLASS_HOST_DEVICE + int get_stride() const { + return iterator_.get_stride(); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
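The ColumnMajorInterleaved<kInterleavedK> specialization above reduces to the pitch-linear case purely by rescaling coordinates at construction: the row extent is multiplied by kInterleavedK to form the contiguous extent, and the column extent is divided by kInterleavedK to form the strided extent (the threadblock offset is rescaled the same way). A small arithmetic sketch, using a hypothetical helper name:

    #include <cstdio>

    struct PitchLinearCoordSketch { int contiguous, strided; };

    // Hypothetical helper mirroring the coordinate mapping the
    // ColumnMajorInterleaved<K> specialization applies when constructing its
    // underlying pitch-linear iterator.
    template <int InterleavedK>
    PitchLinearCoordSketch to_pitch_linear(int row, int column) {
      return { row * InterleavedK, column / InterleavedK };
    }

    int main() {
      // A 128 x 64 extent with K = 32 maps to a 4096 x 2 pitch-linear extent.
      PitchLinearCoordSketch c = to_pitch_linear<32>(128, 64);
      std::printf("%d x %d\n", c.contiguous, c.strided);
      return 0;
    }

The row-major interleaved specialization that follows applies the same scaling with the roles of row and column swapped.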
+ CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator operator++(int) { + EllPredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// add mask for small tiles in ELL + CUTLASS_DEVICE + void ell_add_mask(int blocksize) { + iterator_.ell_add_mask(blocksize); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileAccessIterator for row-major interleaved data. +/// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class EllPredicatedTileAccessIterator, + AdvanceRank, ThreadMap_, AccessType_> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileAccessIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 
1 : 0), ThreadMap, + AccessType>; + + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend EllPredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + CUTLASS_HOST_DEVICE + int get_k() const { + return iterator_.get_k(); + } + + CUTLASS_HOST_DEVICE + int get_stride() const { + return iterator_.get_stride(); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
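Each matrix-layout wrapper in this file forwards to the pitch-linear iterator and remaps the advance dimension along with the coordinates: the column-major and column-major-interleaved specializations pass kAdvanceRank through unchanged, while the row-major variants flip it, as seen in the (kAdvanceRank == 0 ? 0 : 1) and (kAdvanceRank == 0 ? 1 : 0) expressions above. A compile-time sketch of that remapping, using a hypothetical enum in place of the layout tags:

    // Hypothetical stand-in for the layout tags used by the specializations.
    enum class MatrixLayoutSketch { kColumnMajor, kRowMajor };

    // 0 and 1 select the two logical dimensions of the wrapper; the underlying
    // pitch-linear iterator advances along the same logical dimension, so
    // row-major layouts must flip the rank.
    constexpr int underlying_advance_rank(MatrixLayoutSketch layout, int advance_rank) {
      return layout == MatrixLayoutSketch::kColumnMajor ? advance_rank
                                                        : 1 - advance_rank;
    }

    static_assert(underlying_advance_rank(MatrixLayoutSketch::kColumnMajor, 1) == 1, "");
    static_assert(underlying_advance_rank(MatrixLayoutSketch::kRowMajor, 1) == 0, "");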
+ CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator operator++(int) { + EllPredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// add mask for small tiles in ELL + CUTLASS_DEVICE + void ell_add_mask(int blocksize) { + iterator_.ell_add_mask(blocksize); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h b/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h new file mode 100644 index 00000000..ec65e373 --- /dev/null +++ b/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h @@ -0,0 +1,1315 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Ell iterator for Blocked-Ell matrix (ellValue matrix) used with EllMmaPipelined +*/ + +#pragma once + +#include "cutlass/arch/memory.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +#include "cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h" +#include "cutlass/transform/threadblock/ell_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// EllPredicatedTileIterator +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize register liveness +/// and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. +/// +/// Base pointer and tensor extents may be specified at the time the iterator is constructed. +/// Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. +/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. +/// +/// Visitation order is intended to first visit a "residual" tile that may be partially full in +/// both the advance dimension and the steady-state dimension. This is assumed to be the last +/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to +/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent +/// accesses may be performed without updating internal predicates and are efficient in terms of +/// live register state and pointer arithmetic instructions. +/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once +/// outside any looping structure to minimize integer arithmetic. +/// +/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing +/// the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. 
+// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = transform::threadblock::EllPredicatedTileIterator; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize = ThreadMap::kElementsPerAccess +> +class EllPredicatedTileIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileIterator for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class EllPredicatedTileIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + /// Type used for internal memory accesses + using AccessType = AlignedArray::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = + EllPredicatedTileAccessIterator; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Iterator for ELL storage + using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend EllPredicatedTileIterator; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : params_(layout) { } + + CUTLASS_HOST_DEVICE + Params() { } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : address_iterator_(params.params_, pointer, extent, thread_id, + threadblock_offset) {} + + /// Construct a EllPredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params 
const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : EllPredicatedTileIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator &operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator operator++(int) { + EllPredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Returns a stride + CUTLASS_HOST_DEVICE + int get_stride() const { return address_iterator_.get_stride(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { address_iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } + + /// add mask for small tiles in ELL + CUTLASS_HOST_DEVICE + void ell_add_mask(int blocksize) { address_iterator_.ell_add_mask(blocksize); } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + + AccessType const *access_ptr = reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_byte_offset(frag, 0); } + + CUTLASS_DEVICE + void load_with_ell_index(Fragment &frag, EllIterator &ell_iter) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + 
CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + address_iterator_.set_iteration_index(idx); + LongIndex ell_offset = 0; + + int k_offset = address_iterator_.get_k(); + ell_offset = ell_iter.get_offset(k_offset) * sizeof(Element); + + char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + ell_offset; + + AccessType const *access_ptr = reinterpret_cast(byte_ptr); + + bool is_valid = address_iterator_.valid(); + is_valid = is_valid && (ell_offset >= 0); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, is_valid); + + ++address_iterator_; + } + } + } + } + + CUTLASS_DEVICE + void load_with_ell_index_fast(Fragment &frag, EllIterator &ell_iter) { + + LongIndex ell_offset = ell_iter.get_offset_fast() * sizeof(Element); + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + ell_offset; + + AccessType const *access_ptr = reinterpret_cast(byte_ptr); + + bool is_valid = address_iterator_.valid(); + is_valid = is_valid && (ell_offset >= 0); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, is_valid); + + ++address_iterator_; + } + } + } + } + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType *access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileIterator for pitch-linear data. 
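The class comment above describes the intended visitation order for EllPredicatedTileIterator: the partially-full "residue" tile is loaded first, the first increment moves the iterator into its steady state, and clear_mask() turns the trailing loads into no-ops. A minimal device-side sketch of that pattern, assuming a concrete iterator instantiation and a tile count supplied by the caller (the function and parameter names are illustrative):

    // Sketch of the residue-then-steady-state load loop; `remaining_tiles` is
    // the number of full tiles left after the residue tile.
    template <typename Iterator>
    __device__ void consume_tiles(
        typename Iterator::Params const &params,
        typename Iterator::Element *ptr,
        typename Iterator::TensorCoord extent,
        int remaining_tiles) {

      typename Iterator::Fragment fragment;

      Iterator iter(params, ptr, extent, threadIdx.x);

      iter.load(fragment);   // residue tile first
      ++iter;                // advance to the first steady-state tile

      for (int i = remaining_tiles - 1; i >= 0; --i) {
        // ... consume `fragment` ...
        if (!i) {
          iter.clear_mask(); // subsequent loads become no-ops
        }
        iter.load(fragment);
        ++iter;
      }
    }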
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize +> +class EllPredicatedTileIterator { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessSize + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Iterator for ELL storage + using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend EllPredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { + + } + }; + + +private: + + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset ///< Initial offset of threadblock + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) + ) { } + + /// Construct a EllPredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): EllPredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator operator++(int) { + EllPredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Returns a stride + CUTLASS_HOST_DEVICE + int get_stride() const { return iterator_.get_stride(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// add mask for small tiles in ELL + CUTLASS_HOST_DEVICE + void ell_add_mask(int blocksize) { + iterator_.ell_add_mask(blocksize); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index(frag, ell_iter); + } + + CUTLASS_DEVICE + void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index_fast(frag, ell_iter); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileIterator for pitch-linear data. 
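load_with_ell_index(), defined on the pitch-linear specialization and merely forwarded by this wrapper, folds the Blocked-Ell structure into each access: the address iterator supplies the access's K coordinate via get_k(), the ELL iterator translates that into an element offset into the ellValue array, and a negative offset (an empty block) disables the load. An illustrative stand-in for that per-access step, assuming only the get_offset() member used in this file:

    // Combines the base access pointer with the Blocked-Ell offset and updates
    // the access predicate, mirroring the body of load_with_ell_index().
    template <typename Element, typename EllIterator>
    inline char const *ell_access_ptr(char const *base_ptr,
                                      int k_offset,
                                      EllIterator &ell_iter,
                                      bool &is_valid) {
      long long ell_offset =
          static_cast<long long>(ell_iter.get_offset(k_offset)) * sizeof(Element);
      is_valid = is_valid && (ell_offset >= 0);
      return base_ptr + ell_offset;
    }

load_with_ell_index_fast() follows the same idea but queries a single offset for the whole fragment via get_offset_fast().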
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize +> +class EllPredicatedTileIterator { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessSize + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Iterator for ELL storage + using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend EllPredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { + + }; + }; + + +private: + + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset ///< Initial offset of threadblock + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) + ) { } + + /// Construct a EllPredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): EllPredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator operator++(int) { + EllPredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Returns a stride + CUTLASS_HOST_DEVICE + int get_stride() const { return iterator_.get_stride(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// add mask for small tiles in ELL + CUTLASS_HOST_DEVICE + void ell_add_mask(int blocksize) { + iterator_.ell_add_mask(blocksize); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index(frag, ell_iter); + } + + CUTLASS_DEVICE + void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index_fast(frag, ell_iter); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileIterator for interleaved data. It is mapped +/// to the congruous layout. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template +class EllPredicatedTileIterator, + AdvanceRank, ThreadMap_, AccessSize> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessSize>; + + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Iterator for ELL storage + using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend EllPredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a EllPredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : EllPredicatedTileIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. 
+ /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator operator++(int) { + EllPredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Returns a stride + CUTLASS_HOST_DEVICE + int get_stride() const { return iterator_.get_stride(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// add mask for small tiles in ELL + CUTLASS_HOST_DEVICE + void ell_add_mask(int blocksize) { iterator_.ell_add_mask(blocksize); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + CUTLASS_DEVICE + void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index(frag, ell_iter); + } + + CUTLASS_DEVICE + void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index_fast(frag, ell_iter); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileIterator for interleaved-32 data. It is +/// mapped to the congruous layout. 
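Across these tile iterators, a thread's Fragment is sized to cover all of its accesses: ThreadMap::Iterations::kContiguous x ThreadMap::Iterations::kStrided logical accesses of ThreadMap::kElementsPerAccess elements each, with the AccessSize parameter (defaulting to kElementsPerAccess) allowing each logical access to be split into kAccessesPerVector physical vectors, which is what the innermost `v` loop in load_with_byte_offset() iterates over. A small sizing sketch with hypothetical thread-map values:

    // Hypothetical per-thread iteration shape; real values come from ThreadMap.
    constexpr int kIterationsContiguous = 4;
    constexpr int kIterationsStrided    = 2;
    constexpr int kElementsPerAccess    = 8;   // e.g. 8 x half = 16 bytes

    // Total elements held in one Fragment for this thread.
    constexpr int kFragmentElements =
        kIterationsContiguous * kIterationsStrided * kElementsPerAccess;

    static_assert(kFragmentElements == 64, "8 accesses of 8 elements each");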
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class EllPredicatedTileIterator, + AdvanceRank, ThreadMap_, AccessSize> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessSize>; + + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend EllPredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a EllPredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : EllPredicatedTileIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. 
+ /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator operator++(int) { + EllPredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Returns a stride + CUTLASS_HOST_DEVICE + int get_stride() const { return iterator_.get_stride(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// add mask for small tiles in ELL + CUTLASS_HOST_DEVICE + void ell_add_mask(int blocksize) { iterator_.ell_add_mask(blocksize); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h b/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h index d45c4441..16a765dc 100644 --- a/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h +++ b/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h @@ -221,7 +221,10 @@ class PredicatedTileAccessIteratorPredicates { set_iteration_index(0); } - /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// Default constructor + PredicatedTileAccessIteratorPredicates() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, /// and thread ID CUTLASS_HOST_DEVICE PredicatedTileAccessIteratorPredicates( @@ -360,9 +363,8 @@ class PredicatedTileAccessIterator, // /// Parameters object with precomputed internal state - Params const ¶ms_; + Params params_; /// Internal pointer to first access of tile BytePointer pointer_; @@ -1144,6 +1155,10 @@ class PredicatedTileAccessIterator, } public: + + /// Default constructor + PredicatedTileAccessIterator() = default; + /// Constructs a TileIterator from its precomputed state, threadblock offset, /// and thread ID CUTLASS_HOST_DEVICE @@ -1371,9 +1386,8 @@ class 
PredicatedTileAccessIterator tensor's layout CUTLASS_HOST_DEVICE @@ -1390,6 +1404,10 @@ class PredicatedTileAccessIterator tensor's layout CUTLASS_HOST_DEVICE @@ -1569,6 +1586,10 @@ class PredicatedTileAccessIterator, AdvanceRa public: /// Construct the Params object given a pitch-linear tensor's layout CUTLASS_HOST_DEVICE - Params(Layout const &layout) : params_(layout) { } - - CUTLASS_HOST_DEVICE - Params() { } + Params(Layout const &layout) : params_(layout) {} + + /// Default constructor + Params() = default; }; private: @@ -897,6 +906,10 @@ class PredicatedTileIterator, AdvanceRa TileAccessIterator address_iterator_; public: + + /// Default constructor + PredicatedTileIterator() = default; + /// Constructs a TileIterator from its precomputed state, threadblock offset, /// and thread ID CUTLASS_HOST_DEVICE @@ -1123,15 +1136,14 @@ public: typename UnderlyingIterator::Params params_; public: - - CUTLASS_HOST_DEVICE - Params() { } + + /// Default constructor + Params() = default; /// Construct the Params object given an AffineRankN<2> tensor's layout CUTLASS_HOST_DEVICE - Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) { - - } + Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) + {} }; private: @@ -1145,6 +1157,9 @@ private: public: + /// Default constructor + PredicatedTileIterator() = default; + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID CUTLASS_HOST_DEVICE PredicatedTileIterator( @@ -1329,9 +1344,9 @@ public: typename UnderlyingIterator::Params params_; public: - - CUTLASS_HOST_DEVICE - Params() { } + + /// Default constructor + Params() = default; /// Construct the Params object given an AffineRankN<2> tensor's layout CUTLASS_HOST_DEVICE @@ -1350,6 +1365,9 @@ private: public: + /// Default constructor + PredicatedTileIterator() = default; + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID CUTLASS_HOST_DEVICE PredicatedTileIterator( @@ -1530,8 +1548,9 @@ class PredicatedTileIterator; @@ -1718,8 +1741,9 @@ class PredicatedTileIterator::value* ThreadMap::kElementsPerAccess / 8 + > +class RegularTileAccessIteratorDirectConv; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps with dynamic_iterations OFF +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIteratorDirectConv< + Shape_, Element_, + layout::PitchLinear, + AdvanceRank, ThreadMap_, false, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first 
access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // initialize pointer + pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_base)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_num(int num) { + //Do nothing + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + + AccessType *access_ptr = pointer_; + + int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv operator++(int) { + RegularTileAccessIteratorDirectConv prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset in the unit of tile. 
+ CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset(coord.contiguous() * Shape::kContiguous + + coord.strided() * ThreadMap::Iterations::kStrided * + ThreadMap::Delta::kStrided * stride_ * ThreadMap::kElementsPerAccess); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps with dynamic_iterations ON +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIteratorDirectConv< + Shape_, Element_, + layout::PitchLinear, + AdvanceRank, ThreadMap_,true, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + /// Total iterattions in the strided dimension: Dynamic value + int total_iteration_strided_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // initialize pointer + pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_base)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_num(int num) { + total_iteration_strided_ = num; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + + AccessType *access_ptr = pointer_; + + int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < total_iteration_strided_) { + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv operator++(int) { + RegularTileAccessIteratorDirectConv prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset in the unit of tile. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset(coord.contiguous() * Shape::kContiguous + + coord.strided() * total_iteration_strided_ * ThreadMap::Delta::kStrided * stride_ * + ThreadMap::kElementsPerAccess); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for column major layouts +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIteratorDirectConv< + Shape_, Element_, + layout::ColumnMajor, + AdvanceRank, ThreadMap_, Dynamic_iterations , Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIteratorDirectConv< + layout::PitchLinearShape, Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap_, + Dynamic_iterations>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_num(int num) { + iterator_.set_iteration_num(num); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv operator++(int) { + RegularTileAccessIteratorDirectConv prev(*this); + ++iterator_; + + return prev; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for row major layouts +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIteratorDirectConv< + Shape_, Element_, + layout::RowMajor, + AdvanceRank, ThreadMap_, Dynamic_iterations, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIteratorDirectConv< + layout::PitchLinearShape, Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap_, + Dynamic_iterations>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_num(int num) { + iterator_.set_iteration_num(num); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv operator++(int) { + RegularTileAccessIteratorDirectConv prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/uint128.h b/include/cutlass/uint128.h index ce3b6c47..116b9502 100644 --- a/include/cutlass/uint128.h +++ b/include/cutlass/uint128.h @@ -77,6 +77,8 @@ struct uint128_t { uint64_t lo; uint64_t hi; + hilo() = default; + CUTLASS_HOST_DEVICE hilo(uint64_t lo_, uint64_t hi_):lo(lo_), hi(hi_) {} }; @@ -94,8 +96,7 @@ struct uint128_t { // /// Default ctor - CUTLASS_HOST_DEVICE - uint128_t(): hilo_(0, 0) { } + uint128_t() = default; /// Constructor from uint64 CUTLASS_HOST_DEVICE @@ -222,6 +223,7 @@ struct uint128_t { quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #else // TODO - not implemented + CUTLASS_UNUSED(remainder); CUTLASS_UNUSED(divisor); exception(); #endif diff --git a/include/cutlass/wmma_array.h b/include/cutlass/wmma_array.h index fede386c..46d787b8 100644 --- a/include/cutlass/wmma_array.h +++ b/include/cutlass/wmma_array.h @@ -41,6 +41,7 @@ #include "cutlass/cutlass.h" #include "cutlass/array.h" +#include "cutlass/functional.h" namespace cutlass { @@ -55,17 +56,31 @@ template < > class WmmaFragmentArray: public Array { public: + /// Efficient clear method (override Array::clear()) CUTLASS_HOST_DEVICE - void clear() { - - for(int i=0; i::kElements; i++) { - + void clear() + { + for(int i = 0; i < Array::kElements; i++) + { nvcuda::wmma::fill_fragment((*this)[i], (typename T::element_type)0); + } + } + CUTLASS_HOST_DEVICE + WmmaFragmentArray& operator+=(const WmmaFragmentArray& rhs) + { + using element_type = typename T::element_type; + plus add; + + for (int i = 0; i < Array::kElements; i++) + { + (*this)[i] = add((*this)[i], rhs[i]); } + return *this; } + }; 
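The `operator+=` added above accumulates one fragment array into another element by element through the `plus<>` functor from `cutlass/functional.h`. A minimal host-side sketch of the same accumulate pattern on a plain `cutlass::Array` (an illustrative analogue only, assuming just `cutlass::Array` and `cutlass::plus`; the WMMA fragment path itself must run on the device through `nvcuda::wmma`):

```c++
#include "cutlass/array.h"       // cutlass::Array
#include "cutlass/functional.h"  // cutlass::plus

// Accumulate rhs into acc element-wise, mirroring the fragment-array
// operator+= above but on ordinary host-visible data.
template <typename T, int N>
void accumulate(cutlass::Array<T, N> &acc, cutlass::Array<T, N> const &rhs) {
  cutlass::plus<T> add;
  for (int i = 0; i < N; ++i) {
    acc[i] = add(acc[i], rhs[i]);
  }
}
```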
//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/media/docs/efficient_gemm.md b/media/docs/efficient_gemm.md index 203f6810..a36bf3b3 100644 --- a/media/docs/efficient_gemm.md +++ b/media/docs/efficient_gemm.md @@ -4,21 +4,21 @@ # Efficient GEMM in CUDA -CUTLASS implements the hierarchically blocked structure described in +CUTLASS implements the hierarchically blocked structure described in [CUTLASS: Fast Linear Algebra in CUDA C++](https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/) and the [CUTLASS GTC2018 talk](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf). ## Hierarchical Structure The basic triple loop nest computing matrix multiply may be blocked and tiled to match -concurrency in hardware, memory locality, and parallel programming models. In CUTLASS, +concurrency in hardware, memory locality, and parallel programming models. In CUTLASS, GEMM is mapped to NVIDIA GPUs with the structure illustrated by the following loop nest. ```c++ for (int cta_n = 0; cta_n < GemmN; cta_n += CtaTileN) { // for each threadblock_y } threadblock-level concurrency for (int cta_m = 0; cta_m < GemmM; cta_m += CtaTileM) { // for each threadblock_x } - for (int cta_k = 0; cta_k < GemmK; cta_k += CtaTileK) { // "GEMM mainloop" - no unrolling + for (int cta_k = 0; cta_k < GemmK; cta_k += CtaTileK) { // "GEMM mainloop" - no unrolling // - one iteration of this loop is one "stage" // for (int warp_n = 0; warp_n < CtaTileN; warp_n += WarpTileN) { // for each warp_y } warp-level parallelism @@ -30,7 +30,7 @@ for (int cta_n = 0; cta_n < GemmN; cta_n += CtaTileN) { // f for (int mma_k = 0; mma_k < WarpTileK; mma_k += MmaK) { // for each mma instruction } instruction-level parallelism for (int mma_n = 0; mma_n < WarpTileN; mma_n += MmaN) { // for each mma instruction } for (int mma_m = 0; mma_m < WarpTileM; mma_m += MmaM) { // for each mma instruction } - // + // mma_instruction(d, a, b, c); // TensorCore matrix computation } // for mma_m @@ -47,17 +47,17 @@ for (int cta_n = 0; cta_n < GemmN; cta_n += CtaTileN) { // f ``` This tiled loop nest targets concurrency among -- threadblocks -- warps -- CUDA and Tensor Cores +- threadblocks, +- warps, and +- CUDA and Tensor Cores. -and takes advantage of memory locality within -- shared memory -- registers +It takes advantage of memory locality within +- shared memory and +- registers. -The flow of data within this structure is illustrated below. -This is the hierarchical GEMM computation embodied by CUTLASS. Each stage depicts a -nested level of tiling which corresponds to a layer of concurrency within the CUDA execution model and to a +The figure below illustrates the flow of data within this structure. +This is the hierarchical GEMM computation embodied by CUTLASS. Each stage depicts a +nested level of tiling which corresponds to a layer of concurrency within the CUDA execution model and to a level within the memory hierarchy, becoming increasingly finer moving left to right. ![ALT](/media/images/gemm-hierarchy-with-epilogue.png "Hierarchical GEMM in CUDA") @@ -66,20 +66,19 @@ level within the memory hierarchy, becoming increasingly finer moving left to ri ### Threadblock-level GEMM Each threadblock computes its portion of the output GEMM by iteratively loading tiles of input -matrices and computing an accumulated matrix product. At the threadblock level, data is loaded from -global memory. 
The blocking strategy in general is key to achieving efficiency. However, there are -multiple conflicting goals that a programmer aims to achieve to strike a reasonable compromise. A +matrices and computing an accumulated matrix product. At the threadblock level, data are loaded from +global memory. The blocking strategy in general is key to achieving efficiency. However, the programmer +must balance multiple conflicting goals. A larger threadblock means fewer fetches from global memory, thereby ensuring that DRAM bandwidth -does not become a bottleneck. - +does not become a bottleneck. However, large threadblock tiles may not match the dimensions of the problem well. If either the GEMM _M_ or _N_ dimension is small, some threads within the threadblock may not perform meaningful work, as the threadblock may be partially outside the bounds of the problem. If both _M_ and _N_ are small while _K_ is large, this scheme may launch relatively few threadblocks and fail to -fully utilize all multiprocessors within the GPU. Strategies to optimize performance for this case -are described in the section [Parallelized Reductions](efficient_gemm.md#parallelized-reductions) -which partition the GEMM K dimension across multiple threadblocks or multiple warps. These compute -matrix products in parallel which is then reduced to compute the result. +make full use of all multiprocessors within the GPU. Strategies to optimize performance for this case, +as described in the section [Parallelized Reductions](efficient_gemm.md#parallelized-reductions), +partition the GEMM K dimension across multiple threadblocks or multiple warps. These threadblocks +or warps compute matrix products in parallel; the products are then reduced to compute the result. In CUTLASS, the dimensions of the threadblock tile are specified as `ThreadblockShape::{kM, kN, kK}` and may be tuned to specialize the GEMM computation for the target processor and dimensions of @@ -90,10 +89,10 @@ the GEMM problem. The warp-level GEMM maps to the warp-level parallelism within the CUDA execution model. Multiple warps within a threadblock fetch data from shared memory into registers and perform computations. -Warp-level GEMMs may be implemented either by TensorCores issuing -[mma.sync](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma) -or [wmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-wmma-mma) -instructions or by thread-level matrix computations issued to CUDA cores. +Warp-level GEMMs may be implemented either by TensorCores issuing +[mma.sync](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma) +or [wmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-wmma-mma) +instructions, or by thread-level matrix computations issued to CUDA cores. For maximum performance, access to shared memory should be bank conflict free. To maximize data reuse within the warp, a large warp-level GEMM tile should be chosen. @@ -101,8 +100,8 @@ reuse within the warp, a large warp-level GEMM tile should be chosen. ### Thread-level GEMM At the lowest level of blocking, each thread is responsible for processing a certain number of -elements. Threads cannot access each other's registers so we choose an organization that enables -values held in registers to be reused for multiple math instructions. This results in a 2D tiled +elements. 
Threads cannot access each other's registers, so we choose an organization that enables +reuse of values held in registers for multiple math instructions. This results in a 2D tiled structure within a thread, in which each thread issues a sequence of independent math instructions to the CUDA cores and computes an accumulated outer product. @@ -127,31 +126,33 @@ but other device-side function call operators may be used to perform custom oper ## Optimizations -The hierarchical structure described above yields an efficient mapping to the CUDA execution model and +The hierarchical structure described above yields an efficient mapping to the CUDA execution model and CUDA/TensorCores in NVIDIA GPUs. The following sections describe strategies for obtaining peak performance for all corners of the design space, maximizing parallelism and exploiting data locality wherever possible. ### Pipelining The blocked structure demands a large storage allocation within the registers of each CUDA thread. The -accumulator elements typically occupy at least half a thread's total register budget. Consequently, +accumulator elements typically occupy at least half a thread's total register budget. Consequently, occupancy -- the number of concurrent threads, warps, and threadblocks -- is relatively low compared -to other classes of GPU workloads. This limits the GPUs ability to hide memory latency and other stalls +to other classes of GPU workloads. This limits the GPU's ability to hide memory latency and other stalls by context switching to other concurrent threads within an SM. -To mitigate the effects of memory latency, *software pipelining* is used to overlap memory accesses -with other computation within a thread. In CUTLASS, this is achieved by double buffering at the -following scopes +To mitigate the effects of memory latency, CUTLASS uses *software pipelining* to overlap memory accesses +with other computation within a thread. CUTLASS accomplishes this by double buffering at the +following scopes. -- **threadblock-scoped shared memory tiles:** two tiles are allocated within shared memory; one is used - load data for the current matrix operation, while the other tile is used to buffer data loaded from - global memory for the next mainloop iteration +- **Threadblock-scoped shared memory tiles:** two tiles are allocated in shared memory. + One is used to load data for the current matrix operation, + while the other tile is used to buffer data loaded from global memory + for the next mainloop iteration. -- **warp-scoped matrix fragments:** two fragments are allocated within registers; one fragment is passed - to CUDA and TensorCores during the current matrix computation, while the other is used to receive - shared memory fetch returns for the next warp-level matrix operation +- **Warp-scoped matrix fragments:** two fragments are allocated within registers. + One fragment is passed to CUDA and TensorCores during the current matrix computation, + while the other is used to receive shared memory fetch returns + for the next warp-level matrix operation. -The efficient, pipelined mainloop body used in CUTLASS GEMMs is illustrated as follows. +The following diagram illustrates the efficient, pipelined mainloop body used in CUTLASS GEMMs. ![ALT](/media/images/software-pipeline.png "Software pipeline in CUTLASS") @@ -181,35 +182,42 @@ benefits of large threadblock-level GEMM tiles. 
CUTLASS implements parallel reductions across threadblocks by partitioning the GEMM _K_ dimension and launching an additional set of threadblocks for each partition. Consequently, we refer to -this strategy within CUTLASS as "parallel reduction splitK." The "parallel reduction splitK" in cutlass -requires the execution of 2 kernels. The first one is called partitionedK GEMM. The second one is called -batched reduction. +this strategy within CUTLASS as "parallel reduction splitK." The "parallel reduction splitK" strategy +requires the execution of 2 kernels: partitionedK GEMM, and batched reduction. -The partitionedK GEMM is very similar to one flavor of batched strided GEMM. Instead of requiring users -to specify the problem size of each batch, partitionedK GEMM asks for the overall problem size and the -number of partition that will be applied along K dimension for operand A and B. For example, parameters o -f m=128, n=128, k=4096 and partition=16 will result in 16 batched strided GEMMs with each batch of -m=128, n=128, k=256. PartitionedK also allows scenario where k is not divisible by partition count. +PartitionedK GEMM resembles one flavor of batched strided GEMM. Instead of requiring users +to specify the problem size of each batch, partitionedK GEMM asks for the overall problem size and the +number of partitions that will be applied along the K dimension for operands A and B. For example, +parameters of m=128, n=128, k=4096 and partition=16 will result in 16 batched strided GEMMs +with each batch of m=128, n=128, k=256. PartitionedK also allows scenario where k is not divisible +by the partition count. -For example, parameters of m=128, n=128, k=4096 and partition=20 will result in 20 batched strided GEMMs -with the first 19 batches of m=128, n=128, k=4096/20=204 and the last batch of m=128, n=128, k=220. +For example, parameters of m=128, n=128, k=4096 and partition=20 +will result in 20 batched strided GEMMs. +The first 19 batches will have m=128, n=128, and k=4096/20=204, +and the last batch will have m=128, n=128, and k=220. -The batched reduction kernel will further perform reduction along the K-dimension. Thus, the input of -the batched reduction kernel is the output (C) of partitionedK GEMM. An workspace memory is managed by -the users to store this intermediate results. +The batched reduction kernel takes as input the output (C) of partitionedK GEMM, +and performs a reduction along the K-dimension. +Users must manage workspace memory to store this intermediate result. **Sliced K - reduction across warps** -Similar to the split-k scenario, sliced-k aims at improving the efficiency of kernels with smaller M, N, - but large K dimensions. In general at the thread-block level, the parameters CtaTileN, CtaTileM expose parallelism -by partitioning the the work the among warps, and larger warpTiles expose better ILP (Instruction -level parallelism) and reuse, but it also limits the number of warps running per thread-block, which reduces efficiency. +Similar to the split-k scenario, sliced-k aims at improving the efficiency of kernels +with smaller M and N dimensions, but large K dimension. +At the thread-block level, the parameters CtaTileN and CtaTileM expose parallelism +by partitioning the work among warps. +Larger warpTiles expose better instruction-level parallelism (ILP) and reuse, +but also limit the number of warps running per threadblock, which reduces efficiency. 
-So in order to improve efficiency in such scenarios, partitioning the warpTiles also along ctaTileK helps improve the utilization -of the underlying hardware by allowing more warps to run concurrently in a CTA. Now, since sliced-k kernels breaks -down a thread-blocks's computation among participating warps not just among the CtaTileN, CtaTileM dimension, -but also the CtaTileK dimension it entails a small cost in form of a reduction which has to happen at the end among the -participating warps - since each warp now owns a partial sum (since they compute using only a "slice" of ctaTileK). +In order to improve efficiency in such scenarios, partitioning the warpTiles also along ctaTileK +helps use the hardware more efficiently by allowing more warps to run concurrently in a CTA. +Sliced-k kernels break down a threadblock's computation among participating warps +not just among the CtaTileN, CtaTileM dimension, but also the CtaTileK dimension. +Thus, sliced-k entails a small cost in form of a reduction +which has to happen at the end among the participating warps. +This is because each warp computes using only a "slice" of CtaTileK, +so each warp only has a partial sum before the reduction. # Resources diff --git a/test/unit/common/filter_architecture.cpp b/test/unit/common/filter_architecture.cpp index 4f024319..7810ced1 100644 --- a/test/unit/common/filter_architecture.cpp +++ b/test/unit/common/filter_architecture.cpp @@ -78,6 +78,7 @@ void FilterArchitecture() { { "SM70*", 70, 75}, { "SM75*", 75, kMaxDevice}, { "SM80*", 80, kMaxDevice}, + { "SM90*", 90, kMaxDevice}, { 0, 0, false } }; diff --git a/test/unit/conv/device/CMakeLists.txt b/test/unit/conv/device/CMakeLists.txt index 3d3cdc80..47a6acdf 100644 --- a/test/unit/conv/device/CMakeLists.txt +++ b/test/unit/conv/device/CMakeLists.txt @@ -110,7 +110,9 @@ cutlass_test_unit_add_executable( # F16 conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu - depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu + depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu + depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu + depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu ) if (CUTLASS_NVCC_MAX_ARCH GREATER_EQUAL 80) diff --git a/test/unit/conv/device/conv2d_problems.h b/test/unit/conv/device/conv2d_problems.h index 01a90910..773087ac 100644 --- a/test/unit/conv/device/conv2d_problems.h +++ b/test/unit/conv/device/conv2d_problems.h @@ -776,6 +776,29 @@ struct TestbedGroupConv2dProblemSizes { 2 // groups )); + // Larger problem sizes + + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 696}, // input size (NHWC) + {768, 3, 3, 232}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 3 // groups + )); + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 14, 14, 1392}, // input size (NHWC) + {1536, 3, 3, 232}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 3 // groups + )); + //////////////////////////////////////////////////////////////////////////////////// // One CTA calculate multiple groups: CTA::N % k_per_group = 0 
//////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/conv2d_testbed.h b/test/unit/conv/device/conv2d_testbed.h index 0e9ac9a6..23f1c5ad 100644 --- a/test/unit/conv/device/conv2d_testbed.h +++ b/test/unit/conv/device/conv2d_testbed.h @@ -192,7 +192,7 @@ public: // Determine SMEM requirements and waive if not satisfied // - int smem_size = int(sizeof(typename Conv2d::ImplicitGemmKernel::SharedStorage)); + int smem_size = int(sizeof(typename Conv2d::UnderlyingKernel::SharedStorage)); cudaDeviceProp properties; int device_idx; @@ -208,7 +208,7 @@ public: throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } @@ -305,15 +305,15 @@ public: cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), { reinterpret_cast (workspace.get()), - ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx]) + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) }, { tensor_D_computed.device_data(), - ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx]) + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) }, { tensor_C.device_data(), - ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx]) + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) }, // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C {alpha, beta} @@ -637,7 +637,7 @@ bool TestAllConv2d( // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad) && - (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kUnity)) { if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { continue; @@ -645,17 +645,17 @@ bool TestAllConv2d( } // Fixed channels algorithm requires channel count to match access size - if (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kIteratorAlgorithm == + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == cutlass::conv::IteratorAlgorithm::kFixedChannels) { - if (conv_problem.C != ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::AccessType::kElements) { + if (conv_problem.C != ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) { continue; } } // Few channels algorithm requires channel count to match access size - if (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kIteratorAlgorithm == + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == cutlass::conv::IteratorAlgorithm::kFewChannels) { - if (conv_problem.C % ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::AccessType::kElements) { + if (conv_problem.C % ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) { continue; } } @@ -665,7 +665,7 @@ bool TestAllConv2d( // to run strided dgrad for non-unity strides if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad) && - (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kStrided)) { if (((conv_problem.stride_h == 1) && (conv_problem.stride_w 
== 1))) { continue; @@ -704,14 +704,14 @@ bool TestAllConv2d( } // Small-channels convolution can't run here. - if (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kIteratorAlgorithm == + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == cutlass::conv::IteratorAlgorithm::kFixedChannels) { return true; } // Small-channels convolution can't run here. - if (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kIteratorAlgorithm == + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == cutlass::conv::IteratorAlgorithm::kFewChannels) { return true; @@ -720,7 +720,7 @@ bool TestAllConv2d( // CUTLASS DGRAD's *strided* specialization does not support split-k mode if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad) && - (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kStrided)) { passed = testbed.run( diff --git a/test/unit/conv/device/conv2d_testbed_interleaved.h b/test/unit/conv/device/conv2d_testbed_interleaved.h index 2aa60f0b..7fda5c17 100644 --- a/test/unit/conv/device/conv2d_testbed_interleaved.h +++ b/test/unit/conv/device/conv2d_testbed_interleaved.h @@ -257,15 +257,15 @@ public: cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), { reinterpret_cast (workspace.get()), - ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx]) + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) }, { tensor_D_computed.device_data(), - ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx]) + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) }, { tensor_C.device_data(), - ReductionStrideIndex(tensor_C.stride()[Conv2d::ImplicitGemmKernel::kTensorCStrideIdx]) + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) }, // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C {alpha, beta} @@ -536,7 +536,7 @@ bool TestAllInterleavedConv2d( // CUTLASS DGRAD's unity stride specialization only support stride {1, 1} if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad) && - (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kUnity)) { if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { continue; diff --git a/test/unit/conv/device/conv2d_with_broadcast_testbed.h b/test/unit/conv/device/conv2d_with_broadcast_testbed.h index dd12bf60..0fd2843b 100644 --- a/test/unit/conv/device/conv2d_with_broadcast_testbed.h +++ b/test/unit/conv/device/conv2d_with_broadcast_testbed.h @@ -253,7 +253,7 @@ public: // Determine SMEM requirements and waive if not satisfied // - int smem_size = int(sizeof(typename Conv2d::ImplicitGemmKernel::SharedStorage)); + int smem_size = int(sizeof(typename Conv2d::UnderlyingKernel::SharedStorage)); cudaDeviceProp properties; int device_idx; @@ -269,7 +269,7 @@ public: throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } @@ -557,7 +557,7 @@ bool TestAllConv2dWithBroadcast( // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} if ((ImplicitGemm::kConvolutionalOperator == 
cutlass::conv::Operator::kDgrad) && - (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kUnity)) { if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { continue; @@ -568,7 +568,7 @@ bool TestAllConv2dWithBroadcast( // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad) && - (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kStrided)) { if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { continue; @@ -605,7 +605,7 @@ bool TestAllConv2dWithBroadcast( // CUTLASS DGRAD's *strided* specialization does not support split-k mode if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad) && - (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kStrided)) { passed = testbed.run( diff --git a/test/unit/conv/device/conv2d_with_reduction_testbed.h b/test/unit/conv/device/conv2d_with_reduction_testbed.h index a147275b..479323ec 100644 --- a/test/unit/conv/device/conv2d_with_reduction_testbed.h +++ b/test/unit/conv/device/conv2d_with_reduction_testbed.h @@ -182,7 +182,7 @@ public: // Determine SMEM requirements and waive if not satisfied // - int smem_size = int(sizeof(typename Conv2d::ImplicitGemmKernel::SharedStorage)); + int smem_size = int(sizeof(typename Conv2d::UnderlyingKernel::SharedStorage)); cudaDeviceProp properties; int device_idx; @@ -198,7 +198,7 @@ public: throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } @@ -516,7 +516,7 @@ bool TestAllConv2dWithReduction( // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad) && - (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kUnity)) { if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { continue; @@ -527,7 +527,7 @@ bool TestAllConv2dWithReduction( // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad) && - (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kStrided)) { if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { continue; @@ -564,7 +564,7 @@ bool TestAllConv2dWithReduction( // CUTLASS DGRAD's *strided* specialization does not support split-k mode if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad) && - (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kStrided)) { passed = testbed.run( diff --git a/test/unit/conv/device/conv3d_testbed.h b/test/unit/conv/device/conv3d_testbed.h index f9cc3563..34dd51f6 100644 --- a/test/unit/conv/device/conv3d_testbed.h +++ b/test/unit/conv/device/conv3d_testbed.h @@ 
-184,7 +184,7 @@ public: // Determine SMEM requirements and waive if not satisfied // - int smem_size = int(sizeof(typename Conv3d::ImplicitGemmKernel::SharedStorage)); + int smem_size = int(sizeof(typename Conv3d::UnderlyingKernel::SharedStorage)); cudaDeviceProp properties; int device_idx; @@ -200,7 +200,7 @@ public: throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } @@ -294,15 +294,15 @@ public: cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), { reinterpret_cast (workspace.get()), - ReductionStrideIndex(tensor_C.stride()[Conv3d::ImplicitGemmKernel::kTensorCStrideIdx]) + ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) }, { tensor_D_computed.device_data(), - ReductionStrideIndex(tensor_C.stride()[Conv3d::ImplicitGemmKernel::kTensorCStrideIdx]) + ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) }, { tensor_C.device_data(), - ReductionStrideIndex(tensor_C.stride()[Conv3d::ImplicitGemmKernel::kTensorCStrideIdx]) + ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) }, // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C {alpha, beta} @@ -573,9 +573,9 @@ bool TestAllConv3d( // CUTLASS DGRAD's unity stride specialization only support stride {1, 1, 1} if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad) && - ((ImplicitGemm::ImplicitGemmKernel::Mma::IteratorA::kStrideSupport == + ((ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kUnity) || - (ImplicitGemm::ImplicitGemmKernel::Mma::IteratorB::kStrideSupport == + (ImplicitGemm::UnderlyingKernel::Mma::IteratorB::kStrideSupport == cutlass::conv::StrideSupport::kUnity))) { if (!((conv_problem.stride_d == 1) && (conv_problem.stride_h == 1) && diff --git a/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h b/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h new file mode 100644 index 00000000..22ee9e38 --- /dev/null +++ b/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h @@ -0,0 +1,473 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Depthwise Direct Conv testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cache_testbed_output.h" +#include "conv2d_problems.h" +#include "cutlass/conv/device/direct_convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/cutlass.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedDepthwiseDirectConv2d { + public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + + public: + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_reordered_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + int tested_problem_count; + + public: + TestbedDepthwiseDirectConv2d(cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080) + : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) {} + + /// Helper to initialize a tensor view + template + void initialize_tensor(cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + if (dist_kind == cutlass::Distribution::Uniform) { + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } else if (bits == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope = 3; + } else { + scope = 5; + } + } else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, -scope, 0); + } else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + + } else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 
0, 0.5); + } else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } else { + } + } + + void initialize(cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_reordered_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_reordered_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_reordered_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient(int smem_size) const { + // + // Determine SMEM requirements and waive if not satisfied + // + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run(cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1.5), + ElementCompute beta = ElementCompute(1)) { + // increment tested problem count run by the testbed + tested_problem_count++; + +#if 0 // display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " + << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? 
"(serial)" : "(parallel)") + << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args(problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + tensor_reordered_B.device_ref(), + split_k_mode); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.can_implement(problem_size); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + if (!sufficient(conv2d_op.get_smem_size())) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run." << std::endl; + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " << cudaGetErrorString(result); + + tensor_D_computed.sync_host(); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = + CreateCachedConv2dTestKey(kConvolutionalOperator, + problem_size, + alpha, + beta, + tensor_A.host_view(), + tensor_B.host_view(), + tensor_C.host_view()); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string conv2d_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + } + + if (!cached_result_loaded) { +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d(kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d(kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + cached_test_result.D = TensorHash(tensor_D_reference.host_view()); + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(conv2d_result_cache_name); + } + } // if (!cached_result_loaded) + + uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); + + if 
(CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + passed = (tensor_D_hash == cached_test_result.D); + + EXPECT_EQ(tensor_D_hash, cached_test_result.D) + << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; + } + else { + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + } + + EXPECT_TRUE(passed); + + std::stringstream ss_problem_size_text; + ss_problem_size_text << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_"); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_DirectConv_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + << ss_problem_size_text.str() + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n"; + + results << "\nD reference (hash: " << cached_test_result.D << ")\n"; + + if (!cached_result_loaded) { + results + << tensor_D_reference.host_view() << "\n"; + } + + results + << "\nD computed (hash: " << tensor_D_hash << ")\n" + << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestSpecificDepthwiseDirectConv2d(const Conv2dProblemVector &problem_sizes) { + bool passed = true; + + // + // Testbed object + // + TestbedDepthwiseDirectConv2d testbed; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (auto conv_problem : problem_sizes) { + // + // Test + // + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu b/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu new file mode 
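// Editor's note: this rendering of the diff drops template argument lists, so
// the operator-invocation sequence inside run() above is easy to misread.  A
// hedged sketch of that sequence with plausible arguments restored; Conv2d,
// the tensors, alpha/beta and split_k_mode are the testbed members used above,
// and the <uint8_t> on the workspace allocation is an assumption.
Conv2d conv2d_op;

typename Conv2d::Arguments conv2d_args(problem_size,
                                       tensor_A.device_ref(),
                                       tensor_B.device_ref(),
                                       tensor_C.device_ref(),
                                       tensor_D_computed.device_ref(),
                                       {alpha, beta},
                                       tensor_reordered_B.device_ref(),   // pre-reordered filter
                                       split_k_mode);

size_t workspace_size = Conv2d::get_workspace_size(conv2d_args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

cutlass::Status status = conv2d_op.can_implement(problem_size);  // waive unsupported problems
if (status == cutlass::Status::kSuccess) {
  status = conv2d_op.initialize(conv2d_args, workspace.get());
}
if (status == cutlass::Status::kSuccess) {
  status = conv2d_op();                                          // launch the device-wide conv
}
cudaError_t sync_result = cudaDeviceSynchronize();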
100644 index 00000000..61d75d52 --- /dev/null +++ b/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu @@ -0,0 +1,426 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide Depthwise Direct Conv interface +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + + +#include "cutlass/conv/kernel/default_depthwise_fprop.h" +#include "cutlass/conv/device/direct_convolution.h" + +#include "conv2d_testbed.h" +#include "depthwise_conv2d_direct_conv_testbed.h" + +std::vector DepthwiseFpropProblemSizes_filter3x3() { + std::vector problems; + + for (int channels = 16; channels <= 512; channels += 16) { + problems.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, channels}, // input size (NHWC) + {channels, 3, 3, 1}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, // Convolution mode + 16, // split_k_slices + channels // groups + )); + + // if(channels == 512 || channels == 16*14) + + problems.push_back(cutlass::conv::Conv2dProblemSize( + {1, 16, 16, channels}, // input size (NHWC) + {channels, 3, 3, 1}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {2, 2}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, // Convolution mode + 16, // split_k_slices + channels // groups + )); + } + + return problems; +} + +std::vector DepthwiseFpropProblemSizes_filter5x5() { + std::vector problems; + + for (int channels = 16; channels < 256; channels += 16) { + problems.push_back(cutlass::conv::Conv2dProblemSize( + {1, 16, 16, channels}, // input size (NHWC) + {channels, 5, 5, 1}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, // Convolution mode + 16, // split_k_slices + channels // groups + )); + + problems.push_back(cutlass::conv::Conv2dProblemSize( + {1, 112, 112, channels}, // input size (NHWC) + {channels, 5, 5, 1}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, // Convolution mode + 16, // split_k_slices + channels // groups + )); + + problems.push_back(cutlass::conv::Conv2dProblemSize( + {1, 112, 112, channels}, // input size (NHWC) + {channels, 5, 5, 1}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {2, 2}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, // Convolution mode + 16, // split_k_slices + channels // groups + )); + } + + return problems; +} + +std::vector DepthwiseFpropProblemSizes_filter5x37() { + std::vector problems; + + for (int channels = 16; channels < 256; channels += 16) { + problems.push_back(cutlass::conv::Conv2dProblemSize( + {1, 128, 128, channels}, // input size (NHWC) + {channels, 5, 37, 1}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, // Convolution mode + 108, // split_k_slices + channels // groups + )); + } + + return problems; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST( + SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_Optimized_f16nhwc_f16nhwc_f16nhwc_simt_f16, + 64x32_4_8x32_3x3) { + + using ElementInputA = cutlass::half_t; + using 
ElementInputB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementComputeEpilogue = cutlass::half_t; + + using LayoutInputA = cutlass::layout::TensorNHWC; + using LayoutInputB = cutlass::layout::TensorNHWC; + using LayoutOutput = cutlass::layout::TensorNHWC; + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using MMAOp = cutlass::arch::OpClassSimt; + + // This code section describes CUDA SM architecture number + using SmArch = cutlass::arch::Sm60; + + // This code section describes the groups a thread block will compute + constexpr int groups_per_cta = 32; + + // This code section describes the output tile a thread block will compute + using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; + + // This code section describes the filter shape + using FilterShape = cutlass::MatrixShape<3, 3>; + + // Threadblock tile shape + using ThreadblockShape = + cutlass::gemm::GemmShape; + + // This code section describes tile size a warp will computes + using WarpShape = cutlass::gemm::GemmShape<8, groups_per_cta, FilterShape::kCount>; + + // This code section describes the size of MMA op + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = + cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< + 1, + ThreadBlockOutputShape::kN, + ThreadBlockOutputShape::kH, + ThreadBlockOutputShape::kW>; + + // Number of pipelines you want to use + constexpr int NumStages = 4; + + // This code section describe iterator algorithm selected is Analytic or Optimized + static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = + cutlass::conv::IteratorAlgorithm::kOptimized; + + constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; + + // This code section describes the epilogue part of the kernel, we use default value + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // Data type of output matrix. + kEpilogueElementsPerAccess, // The number of elements per vectorized. + // memory access. This becomes the vector width of + // math instructions in the epilogue too. 
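// Editor's note: the LinearCombination alias being defined here lost its
// template argument list in this rendering of the diff.  A plausible
// reconstruction for this all-half_t test, following CUTLASS's parameter
// order (ElementOutput, Count, ElementAccumulator, ElementCompute, ScaleType):
constexpr int kEpilogueElementsPerAccessExample =
    128 / cutlass::sizeof_bits<cutlass::half_t>::value;   // 128-bit accesses -> 8 halfs

using EpilogueOpExample = cutlass::epilogue::thread::LinearCombination<
    cutlass::half_t,                    // ElementOutput
    kEpilogueElementsPerAccessExample,  // vector width of epilogue memory accesses
    cutlass::half_t,                    // ElementAccumulator
    cutlass::half_t,                    // ElementComputeEpilogue (alpha/beta)
    cutlass::epilogue::thread::ScaleType::Default>;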
+ ElementAccumulator, // Data type of accumulator + ElementComputeEpilogue, // Data type for alpha/beta in linear combination + cutlass::epilogue::thread::ScaleType::Default>; + + using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + IteratorAlgorithm, + cutlass::conv::StrideSupport::kStrided>::Kernel; + + using Direct2dConv = cutlass::conv::device::DirectConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( + DepthwiseFpropProblemSizes_filter3x3())); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST( + SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_Optimized_f16nhwc_f16nhwc_f16nhwc_simt_f16, + 64x64_3_16x64_5x5) { + + using ElementInputA = cutlass::half_t; + using ElementInputB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementComputeEpilogue = cutlass::half_t; + + using LayoutInputA = cutlass::layout::TensorNHWC; + using LayoutInputB = cutlass::layout::TensorNHWC; + using LayoutOutput = cutlass::layout::TensorNHWC; + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using MMAOp = cutlass::arch::OpClassSimt; + + // This code section describes CUDA SM architecture number + using SmArch = cutlass::arch::Sm60; + + // This code section describes the groups a thread block will compute + constexpr int groups_per_cta = 64; + + // This code section describes the output tile a thread block will compute + using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; + + // This code section describes the filter shape + using FilterShape = cutlass::MatrixShape<5, 5>; + + // Threadblock tile shape + using ThreadblockShape = + cutlass::gemm::GemmShape; + + // This code section describes tile size a warp will computes + using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>; + + // This code section describes the size of MMA op + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = + cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< + 1, + ThreadBlockOutputShape::kN, + ThreadBlockOutputShape::kH, + ThreadBlockOutputShape::kW>; + + // Number of pipelines you want to use + constexpr int NumStages = 3; + + // This code section describe iterator algorithm selected is Analytic or Optimized + static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = + cutlass::conv::IteratorAlgorithm::kOptimized; + + constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; + + // This code section describes the epilogue part of the kernel, we use default value + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // Data type of output matrix. + kEpilogueElementsPerAccess, // The number of elements per vectorized. + // memory access. This becomes the vector width of + // math instructions in the epilogue too. 
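// Editor's note: the aliases that close the 3x3 TEST above also lost their
// template arguments.  A hedged reconstruction of how the device-level
// operator wraps the kernel and is handed to the problem-sweep driver:
using Direct2dConvExample =
    cutlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;

EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d<Direct2dConvExample>(
    DepthwiseFpropProblemSizes_filter3x3()));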
+ ElementAccumulator, // Data type of accumulator + ElementComputeEpilogue, // Data type for alpha/beta in linear combination + cutlass::epilogue::thread::ScaleType::Default>; + + using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + IteratorAlgorithm, + cutlass::conv::StrideSupport::kStrided>::Kernel; + + using Direct2dConv = cutlass::conv::device::DirectConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( + DepthwiseFpropProblemSizes_filter5x5())); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST( + SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_Optimized_f16nhwc_f16nhwc_f16nhwc_simt_f16, + 64x32_3_16x32_5x37) { + + using ElementInputA = cutlass::half_t; + using ElementInputB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementComputeEpilogue = cutlass::half_t; + + using LayoutInputA = cutlass::layout::TensorNHWC; + using LayoutInputB = cutlass::layout::TensorNHWC; + using LayoutOutput = cutlass::layout::TensorNHWC; + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using MMAOp = cutlass::arch::OpClassSimt; + + // This code section describes CUDA SM architecture number + using SmArch = cutlass::arch::Sm60; + + // This code section describes the groups a thread block will compute + constexpr int groups_per_cta = 32; + + // This code section describes the output tile a thread block will compute + using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; + + // This code section describes the filter shape + using FilterShape = cutlass::MatrixShape<5, 37>; + + // Threadblock tile shape + using ThreadblockShape = + cutlass::gemm::GemmShape; + + // This code section describes tile size a warp will computes + using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>; + + // This code section describes the size of MMA op + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = + cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< + 1, + ThreadBlockOutputShape::kN, + ThreadBlockOutputShape::kH, + ThreadBlockOutputShape::kW>; + + // Number of pipelines you want to use + constexpr int NumStages = 2; + + // This code section describe iterator algorithm selected is Analytic or Optimized + static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = + cutlass::conv::IteratorAlgorithm::kOptimized; + + constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; + + // This code section describes the epilogue part of the kernel, we use default value + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // Data type of output matrix. + kEpilogueElementsPerAccess, // The number of elements per vectorized. + // memory access. This becomes the vector width of + // math instructions in the epilogue too. 
+ ElementAccumulator, // Data type of accumulator + ElementComputeEpilogue, // Data type for alpha/beta in linear combination + cutlass::epilogue::thread::ScaleType::Default>; + + using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + IteratorAlgorithm, + cutlass::conv::StrideSupport::kStrided>::Kernel; + + using Direct2dConv = cutlass::conv::device::DirectConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( + DepthwiseFpropProblemSizes_filter5x37())); +} diff --git a/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu b/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu new file mode 100644 index 00000000..080eb65b --- /dev/null +++ b/test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu @@ -0,0 +1,522 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide Depthwise Direct Conv interface +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + + +#include "cutlass/conv/kernel/default_depthwise_fprop.h" +#include "cutlass/conv/device/direct_convolution.h" + +#include "conv2d_testbed.h" +#include "depthwise_conv2d_direct_conv_testbed.h" + +std::vector DepthwiseFpropProblemSizes_filter3x3_stride1x1_dilation1x1() { + std::vector problems; + + for (int channels = 16; channels <= 512; channels += 16) { + problems.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, channels}, // input size (NHWC) + {channels, 3, 3, 1}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, // Convolution mode + 16, // split_k_slices + channels // groups + )); + } + return problems; +} + +std::vector DepthwiseFpropProblemSizes_filter3x3_stride2x2_dilation2x2() { + std::vector problems; + for (int channels = 16; channels <= 512; channels += 16) { + problems.push_back(cutlass::conv::Conv2dProblemSize( + {1, 16, 16, channels}, // input size (NHWC) + {channels, 3, 3, 1}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {2, 2}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, // Convolution mode + 16, // split_k_slices + channels // groups + )); + } + + return problems; +} + +std::vector DepthwiseFpropProblemSizes_filter5x5_stride1x1_dilation1x1() { + std::vector problems; + + for (int channels = 16; channels < 256; channels += 16) { + problems.push_back(cutlass::conv::Conv2dProblemSize( + {1, 16, 16, channels}, // input size (NHWC) + {channels, 5, 5, 1}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, // Convolution mode + 16, // split_k_slices + channels // groups + )); + } + + return problems; + +} + +std::vector DepthwiseFpropProblemSizes_filter5x5_stride2x2_dilation2x2() { + std::vector problems; + for (int channels = 16; channels < 256; channels += 16) { + problems.push_back(cutlass::conv::Conv2dProblemSize( + {1, 112, 112, channels}, // input size (NHWC) + {channels, 5, 5, 1}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {2, 2}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, // Convolution mode + 16, // split_k_slices + channels // groups + )); + } + + return problems; +} + + +//////////////////////////////////////////////////////////////////////////////// +TEST( + SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_FixedStrideDilation_f16nhwc_f16nhwc_f16nhwc_simt_f16, + 64x32_4_8x32_Filter3x3_Stride1x1_Dilation1x1) { + + using ElementInputA = cutlass::half_t; + using ElementInputB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementComputeEpilogue = cutlass::half_t; + + using LayoutInputA = cutlass::layout::TensorNHWC; + using LayoutInputB = cutlass::layout::TensorNHWC; + using LayoutOutput = cutlass::layout::TensorNHWC; + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using MMAOp = cutlass::arch::OpClassSimt; + + // This code section describes CUDA SM architecture number + using SmArch = 
cutlass::arch::Sm60; + + // This code section describes the groups a thread block will compute + constexpr int groups_per_cta = 32; + + // This code section describes the output tile a thread block will compute + using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; + + // This code section describes the filter shape + using FilterShape = cutlass::MatrixShape<3, 3>; + + // Threadblock tile shape + using ThreadblockShape = + cutlass::gemm::GemmShape; + + // This code section describes tile size a warp will computes + using WarpShape = cutlass::gemm::GemmShape<8, groups_per_cta, FilterShape::kCount>; + + // This code section describes the size of MMA op + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = + cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< + 1, + ThreadBlockOutputShape::kN, + ThreadBlockOutputShape::kH, + ThreadBlockOutputShape::kW>; + + // Number of pipelines you want to use + constexpr int NumStages = 4; + + // This code section describe iterator algorithm selected is Analytic or Optimized + static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = + cutlass::conv::IteratorAlgorithm::kFixedStrideDilation; + using StrideShape = cutlass::MatrixShape<1, 1>; + using DilationShape = cutlass::MatrixShape<1, 1>; + + constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; + + // This code section describes the epilogue part of the kernel, we use default value + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // Data type of output matrix. + kEpilogueElementsPerAccess, // The number of elements per vectorized. + // memory access. This becomes the vector width of + // math instructions in the epilogue too. 
+ ElementAccumulator, // Data type of accumulator + ElementComputeEpilogue, // Data type for alpha/beta in linear combination + cutlass::epilogue::thread::ScaleType::Default>; + + using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + IteratorAlgorithm, + cutlass::conv::StrideSupport::kFixed, + StrideShape, + DilationShape>::Kernel; + + using Direct2dConv = cutlass::conv::device::DirectConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( + DepthwiseFpropProblemSizes_filter3x3_stride1x1_dilation1x1())); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST( + SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_FixedStrideDilation_f16nhwc_f16nhwc_f16nhwc_simt_f16, + 64x32_4_8x32_Filter3x3_Stride2x2_Dilation2x2) { + + using ElementInputA = cutlass::half_t; + using ElementInputB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementComputeEpilogue = cutlass::half_t; + + using LayoutInputA = cutlass::layout::TensorNHWC; + using LayoutInputB = cutlass::layout::TensorNHWC; + using LayoutOutput = cutlass::layout::TensorNHWC; + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using MMAOp = cutlass::arch::OpClassSimt; + + // This code section describes CUDA SM architecture number + using SmArch = cutlass::arch::Sm60; + + // This code section describes the groups a thread block will compute + constexpr int groups_per_cta = 32; + + // This code section describes the output tile a thread block will compute + using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; + + // This code section describes the filter shape + using FilterShape = cutlass::MatrixShape<3, 3>; + + // Threadblock tile shape + using ThreadblockShape = + cutlass::gemm::GemmShape; + + // This code section describes tile size a warp will computes + using WarpShape = cutlass::gemm::GemmShape<8, groups_per_cta, FilterShape::kCount>; + + // This code section describes the size of MMA op + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = + cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< + 1, + ThreadBlockOutputShape::kN, + ThreadBlockOutputShape::kH, + ThreadBlockOutputShape::kW>; + + // Number of pipelines you want to use + constexpr int NumStages = 4; + + // This code section describe iterator algorithm selected is Analytic or Optimized + static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = + cutlass::conv::IteratorAlgorithm::kFixedStrideDilation; + using StrideShape = cutlass::MatrixShape<2, 2>; + using DilationShape = cutlass::MatrixShape<2, 2>; + + constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; + + // This code section describes the epilogue part of the kernel, we use default value + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // Data type of output matrix. 
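// Editor's note: the kFixedStrideDilation tests above differ from the
// kOptimized ones mainly in the trailing template arguments of
// DefaultDepthwiseDirect2dConvFprop, which this rendering dropped.  A hedged
// reconstruction of that tail for the Filter3x3_Stride1x1_Dilation1x1 variant,
// using the names defined in that TEST:
using StrideShapeExample   = cutlass::MatrixShape<1, 1>;   // compile-time stride
using DilationShapeExample = cutlass::MatrixShape<1, 1>;   // compile-time dilation

using DepthwiseDirect2dConvExample =
    typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
        ElementInputA, LayoutInputA,
        ElementInputB, LayoutInputB,
        ElementOutput, LayoutOutput,
        ElementAccumulator,
        MMAOp, SmArch,
        ThreadblockShape, ThreadBlockOutputShape, FilterShape,
        WarpShape, InstructionShape,
        EpilogueOp, SwizzleThreadBlock, NumStages,
        cutlass::arch::OpMultiplyAdd,
        cutlass::conv::IteratorAlgorithm::kFixedStrideDilation,
        cutlass::conv::StrideSupport::kFixed,    // stride/dilation fixed at compile time
        StrideShapeExample,
        DilationShapeExample>::Kernel;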
+ kEpilogueElementsPerAccess, // The number of elements per vectorized. + // memory access. This becomes the vector width of + // math instructions in the epilogue too. + ElementAccumulator, // Data type of accumulator + ElementComputeEpilogue, // Data type for alpha/beta in linear combination + cutlass::epilogue::thread::ScaleType::Default>; + + using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + IteratorAlgorithm, + cutlass::conv::StrideSupport::kFixed, + StrideShape, + DilationShape>::Kernel; + + using Direct2dConv = cutlass::conv::device::DirectConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( + DepthwiseFpropProblemSizes_filter3x3_stride2x2_dilation2x2())); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST( + SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_FixedStrideDilation_f16nhwc_f16nhwc_f16nhwc_simt_f16, + 64x64_3_16x64_Filter5x5_Stride1x1_Dilation1x1) { + + using ElementInputA = cutlass::half_t; + using ElementInputB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementComputeEpilogue = cutlass::half_t; + + using LayoutInputA = cutlass::layout::TensorNHWC; + using LayoutInputB = cutlass::layout::TensorNHWC; + using LayoutOutput = cutlass::layout::TensorNHWC; + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using MMAOp = cutlass::arch::OpClassSimt; + + // This code section describes CUDA SM architecture number + using SmArch = cutlass::arch::Sm60; + + // This code section describes the groups a thread block will compute + constexpr int groups_per_cta = 64; + + // This code section describes the output tile a thread block will compute + using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; + + // This code section describes the filter shape + using FilterShape = cutlass::MatrixShape<5, 5>; + + // Threadblock tile shape + using ThreadblockShape = + cutlass::gemm::GemmShape; + + // This code section describes tile size a warp will computes + using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>; + + // This code section describes the size of MMA op + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = + cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< + 1, + ThreadBlockOutputShape::kN, + ThreadBlockOutputShape::kH, + ThreadBlockOutputShape::kW>; + + // Number of pipelines you want to use + constexpr int NumStages = 3; + + // This code section describe iterator algorithm selected is Analytic or Optimized + static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = + cutlass::conv::IteratorAlgorithm::kFixedStrideDilation; + using StrideShape = cutlass::MatrixShape<1, 1>; + using DilationShape = cutlass::MatrixShape<1, 1>; + + constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; + + // This code section describes the epilogue part of 
the kernel, we use default value + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // Data type of output matrix. + kEpilogueElementsPerAccess, // The number of elements per vectorized. + // memory access. This becomes the vector width of + // math instructions in the epilogue too. + ElementAccumulator, // Data type of accumulator + ElementComputeEpilogue, // Data type for alpha/beta in linear combination + cutlass::epilogue::thread::ScaleType::Default>; + + using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + IteratorAlgorithm, + cutlass::conv::StrideSupport::kFixed, + StrideShape, + DilationShape>::Kernel; + + using Direct2dConv = cutlass::conv::device::DirectConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( + DepthwiseFpropProblemSizes_filter5x5_stride1x1_dilation1x1())); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST( + SM60_Device_Depthwise_conv2d_Fprop_Direct_Conv_FixedStrideDilation_f16nhwc_f16nhwc_f16nhwc_simt_f16, + 64x64_3_16x64_Filter5x5_Stride2x2_Dilation2x2) { + + using ElementInputA = cutlass::half_t; + using ElementInputB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementComputeEpilogue = cutlass::half_t; + + using LayoutInputA = cutlass::layout::TensorNHWC; + using LayoutInputB = cutlass::layout::TensorNHWC; + using LayoutOutput = cutlass::layout::TensorNHWC; + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using MMAOp = cutlass::arch::OpClassSimt; + + // This code section describes CUDA SM architecture number + using SmArch = cutlass::arch::Sm60; + + // This code section describes the groups a thread block will compute + constexpr int groups_per_cta = 32; + + // This code section describes the output tile a thread block will compute + using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; + + // This code section describes the filter shape + using FilterShape = cutlass::MatrixShape<5, 5>; + + // Threadblock tile shape + using ThreadblockShape = + cutlass::gemm::GemmShape; + + // This code section describes tile size a warp will computes + using WarpShape = cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>; + + // This code section describes the size of MMA op + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = + cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< + 1, + ThreadBlockOutputShape::kN, + ThreadBlockOutputShape::kH, + ThreadBlockOutputShape::kW>; + + // Number of pipelines you want to use + constexpr int NumStages = 3; + + // This code section describe iterator algorithm selected is Analytic or Optimized + static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = + cutlass::conv::IteratorAlgorithm::kFixedStrideDilation; + using StrideShape = cutlass::MatrixShape<2, 2>; + using DilationShape = 
cutlass::MatrixShape<2, 2>; + + constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; + + // This code section describes the epilogue part of the kernel, we use default value + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // Data type of output matrix. + kEpilogueElementsPerAccess, // The number of elements per vectorized. + // memory access. This becomes the vector width of + // math instructions in the epilogue too. + ElementAccumulator, // Data type of accumulator + ElementComputeEpilogue, // Data type for alpha/beta in linear combination + cutlass::epilogue::thread::ScaleType::Default>; + + using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + IteratorAlgorithm, + cutlass::conv::StrideSupport::kFixed, + StrideShape, + DilationShape>::Kernel; + + using Direct2dConv = cutlass::conv::device::DirectConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestSpecificDepthwiseDirectConv2d( + DepthwiseFpropProblemSizes_filter5x5_stride2x2_dilation2x2())); +} diff --git a/test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu b/test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu similarity index 99% rename from test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu rename to test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu index 16c93630..9b9e503d 100644 --- a/test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu +++ b/test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu @@ -29,7 +29,7 @@ * **************************************************************************************************/ /*! 
\file - \brief Tests for device-wide Implicit GEMM interface + \brief Tests for Depthwise Direct Conv interface */ #include "../../common/cutlass_unit_test.h" diff --git a/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu b/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu index 46b21647..1214aba3 100644 --- a/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu +++ b/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu @@ -241,6 +241,155 @@ TEST(SM80_Device_Conv2d_Group_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhw //////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Conv2d_Group_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, + SingleGroupPerCTA_128x128_64x3_64x64x64) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + /// Device-level Conv2d instance + using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::GroupMode::kSingleGroup, + cutlass::conv::IteratorAlgorithm::kOptimized + >::Kernel; + + using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run group conv unit test sizes with device-level Conv2d instance + test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( + ThreadblockShape::kN, ThreadblockShape::kK, + 128/cutlass::sizeof_bits::value + ); + EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_single_group_sizes)); +} + +//////////////////////////////////////////////////////////////////////////////// + +// Optimized multistage singleGroup kernel +TEST(SM80_Device_Conv2d_Group_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, + SingleGroupPerCTA_64x64_64x3_32x32x64) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + /// Device-level Conv2d instance + using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + 
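// Editor's note: as elsewhere in this rendering, the group-conv aliases above
// lost their template arguments.  A plausible reconstruction of the pieces
// that wrap the kernel and drive the single-group test; the element type used
// for the alignment argument is an assumption:
using Conv2dGroupFpropExample =
    cutlass::conv::device::ImplicitGemmConvolution<Conv2dGroupFpropKernel>;

test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes_example(
    ThreadblockShape::kN,
    ThreadblockShape::kK,
    128 / cutlass::sizeof_bits<ElementC>::value);   // minimum channel alignment, in elements

EXPECT_TRUE(test::conv::device::TestSpecificConv2d<Conv2dGroupFpropExample>(
    problem_sizes_example.default_single_group_sizes));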
cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::GroupMode::kSingleGroup, + cutlass::conv::IteratorAlgorithm::kOptimized + >::Kernel; + + using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run group conv unit test sizes with device-level Conv2d instance + test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( + ThreadblockShape::kN, ThreadblockShape::kK, + 128/cutlass::sizeof_bits::value + ); + EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_single_group_sizes)); +} + +//////////////////////////////////////////////////////////////////////////////// + +// Optimized 2 stage singleGroup kernel +TEST(SM80_Device_Conv2d_Group_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, + SingleGroupPerCTA_64x64_64x2_32x32x64) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + /// Device-level Conv2d instance + using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::GroupMode::kSingleGroup, + cutlass::conv::IteratorAlgorithm::kOptimized + >::Kernel; + + using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run group conv unit test sizes with device-level Conv2d instance + test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( + ThreadblockShape::kN, ThreadblockShape::kK, + 128/cutlass::sizeof_bits::value + ); + EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_single_group_sizes)); +} + +//////////////////////////////////////////////////////////////////////////////// + #endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED //////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/core/CMakeLists.txt b/test/unit/core/CMakeLists.txt index 8a2ab83a..ca6f50ff 100644 --- a/test/unit/core/CMakeLists.txt +++ b/test/unit/core/CMakeLists.txt @@ -31,6 +31,7 @@ cutlass_test_unit_add_executable( array.cu half.cu bfloat16.cu + float8.cu tfloat32.cu complex.cu quaternion.cu diff --git a/test/unit/core/float8.cu b/test/unit/core/float8.cu new file mode 100644 index 00000000..bbe17902 --- /dev/null +++ b/test/unit/core/float8.cu @@ -0,0 +1,103 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for basic float8 functionality +*/ + +#include "../common/cutlass_unit_test.h" + +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(float_e4m3_t, host_conversion) { + for (int i = -8; i < 8; ++i) { + float f = static_cast(i); + + cutlass::float_e4m3_t x = static_cast(i); + cutlass::float_e4m3_t y = static_cast(f); + + EXPECT_TRUE(static_cast(x) == i); + EXPECT_TRUE(static_cast(y) == f); + } + + // Try out default-ctor (zero initialization of primitive proxy type) + EXPECT_TRUE(cutlass::float_e4m3_t() == 0.0_fe4m3); + + // Try out user-defined literals + EXPECT_TRUE(cutlass::float_e4m3_t(7) == 7_fe4m3); + EXPECT_TRUE(7 == static_cast(7_fe4m3)); +} + +TEST(float_e5m2_t, host_conversion) { + for (int i = -8; i < 8; ++i) { + float f = static_cast(i); + + cutlass::float_e5m2_t x = static_cast(i); + cutlass::float_e5m2_t y = static_cast(f); + + EXPECT_TRUE(static_cast(x) == i); + EXPECT_TRUE(static_cast(y) == f); + } + + // Try out default-ctor (zero initialization of primitive proxy type) + EXPECT_TRUE(cutlass::float_e5m2_t() == 0.0_fe5m2); + + // Try out user-defined literals + EXPECT_TRUE(cutlass::float_e5m2_t(7) == 7_fe5m2); + EXPECT_TRUE(7 == static_cast(7_fe5m2)); +} + +TEST(float_e4m3_t, host_arithmetic) { + for (int i = -4; i < 4; ++i) { + for (int j = -4; j < 4; ++j) { + + cutlass::float_e4m3_t x = static_cast(i); + cutlass::float_e4m3_t y = static_cast(j); + + EXPECT_TRUE(static_cast(x + y) == (i + j)); + } + } +} + +TEST(float_e5m2_t, host_arithmetic) { + for (int i = -4; i < 4; ++i) { + for (int j = -4; j < 4; ++j) { + + cutlass::float_e5m2_t x = static_cast(i); + cutlass::float_e5m2_t y = static_cast(j); + + EXPECT_TRUE(static_cast(x + y) == (i + j)); + } + } +} + 
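// Editor's note: a small usage sketch (not part of the patch) of the FP8 types
// exercised above, converting through cutlass::NumericConverter.  1.5f is
// exactly representable in e4m3 (3 mantissa bits), so the round trip below is
// lossless.
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"

void fp8_round_trip_example() {
  cutlass::NumericConverter<cutlass::float_e4m3_t, float> to_e4m3;
  cutlass::NumericConverter<float, cutlass::float_e4m3_t> to_f32;

  cutlass::float_e4m3_t q = to_e4m3(1.5f);
  float back = to_f32(q);                  // back == 1.5f

  // e5m2 trades mantissa bits for exponent range; small values such as 3.0
  // still convert exactly.
  cutlass::float_e5m2_t r = cutlass::float_e5m2_t(3.0f);
  (void)back;
  (void)r;
}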
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/core/numeric_conversion.cu b/test/unit/core/numeric_conversion.cu index 6eb7fc74..697a0b29 100644 --- a/test/unit/core/numeric_conversion.cu +++ b/test/unit/core/numeric_conversion.cu @@ -47,10 +47,10 @@ namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Conversion template +/// Simple conversion function template __global__ void convert( - cutlass::Array *destination, + cutlass::Array *destination, cutlass::Array const *source) { cutlass::NumericArrayConverter convert; @@ -60,47 +60,9 @@ __global__ void convert( ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace kernel -} // namespace core -} // namespace test - -///////////////////////////////////////////////////////////////////////////////////////////////// - -TEST(NumericConversion, f32_to_f16_rn) { - - int const kN = 1; - using Source = float; - using Destination = cutlass::half_t; - - dim3 grid(1, 1); - dim3 block(1, 1); - - cutlass::HostTensor destination({1, kN}); - cutlass::HostTensor source({1, kN}); - - for (int i = 0; i < kN; ++i) { - source.host_data()[i] = float(i); - } - - source.sync_device(); - - test::core::kernel::convert<<< grid, block >>>( - reinterpret_cast *>(destination.device_data()), - reinterpret_cast const *>(source.device_data()) - ); - - destination.sync_host(); - - for (int i = 0; i < kN; ++i) { - EXPECT_TRUE(float(destination.host_data()[i]) == source.host_data()[i]); - } -} - -TEST(NumericConversion, f32x8_to_f16x8_rn) { - - int const kN = 8; - using Source = float; - using Destination = cutlass::half_t; +template +void run_test() { + const int kN = Count; dim3 grid(1, 1); dim3 block(1, 1); @@ -109,12 +71,12 @@ TEST(NumericConversion, f32x8_to_f16x8_rn) { cutlass::HostTensor source({1, kN}); for (int i = 0; i < kN; ++i) { - source.host_data()[i] = float(i); + source.host_data()[i] = Source(i % 4); } source.sync_device(); - test::core::kernel::convert<<< grid, block >>>( + convert<<< grid, block >>>( reinterpret_cast *>(destination.device_data()), reinterpret_cast const *>(source.device_data()) ); @@ -122,70 +84,247 @@ TEST(NumericConversion, f32x8_to_f16x8_rn) { destination.sync_host(); for (int i = 0; i < kN; ++i) { - EXPECT_TRUE(float(destination.host_data()[i]) == source.host_data()[i]); + EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i])); } } +} // namespace kernel +} // namespace core +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(NumericConversion, f32_to_f16_rn) { + int const kN = 1; + using Source = float; + using Destination = cutlass::half_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, f32x8_to_f16x8_rn) { + int const kN = 8; + using Source = float; + using Destination = cutlass::half_t; + test::core::kernel::run_test(); +} + ///////////////////////////////////////////////////////////////////////////////////////////////// -TEST(NumericConversion, f16_to_f32_rn) { - +TEST(NumericConversion, f16_to_f32_rn) { int const kN = 1; using Source = cutlass::half_t; using Destination = float; - - dim3 grid(1, 1); - dim3 block(1, 1); - - cutlass::HostTensor destination({1, kN}); - cutlass::HostTensor source({1, kN}); - - for (int i = 0; i < kN; ++i) { - source.host_data()[i] = Source(i); - } - - source.sync_device(); - - 
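// Editor's note: the refactor above folds each per-type conversion test into
// one templated helper, but this rendering dropped the template parameter
// lists.  A hedged reconstruction of the kernel, and of how each TEST below
// presumably invokes the helper:
template <typename Destination, typename Source, int Count>
__global__ void convert_example(
    cutlass::Array<Destination, Count> *destination,
    cutlass::Array<Source, Count> const *source) {

  cutlass::NumericArrayConverter<Destination, Source, Count> converter;

  *destination = converter(*source);
}

// Inside a TEST such as NumericConversion.f32_to_fe4m3_rn:
//   test::core::kernel::run_test<Destination, Source, kN>();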
test::core::kernel::convert<<< grid, block >>>( - reinterpret_cast *>(destination.device_data()), - reinterpret_cast const *>(source.device_data()) - ); - - destination.sync_host(); - - for (int i = 0; i < kN; ++i) { - EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i])); - } + test::core::kernel::run_test(); } TEST(NumericConversion, f16x8_to_f32x8_rn) { - int const kN = 8; using Source = cutlass::half_t; using Destination = float; + test::core::kernel::run_test(); +} - dim3 grid(1, 1); - dim3 block(1, 1); +///////////////////////////////////////////////////////////////////////////////////////////////// - cutlass::HostTensor destination({1, kN}); - cutlass::HostTensor source({1, kN}); +TEST(NumericConversion, f32_to_fe4m3_rn) { + int const kN = 1; + using Source = float; + using Destination = cutlass::float_e4m3_t; + test::core::kernel::run_test(); +} - for (int i = 0; i < kN; ++i) { - source.host_data()[i] = float(i); - } +TEST(NumericConversion, f32_to_fe4m3_rn_array) { + int const kN = 27; + using Source = float; + using Destination = cutlass::float_e4m3_t; - source.sync_device(); + test::core::kernel::run_test(); +} - test::core::kernel::convert<<< grid, block >>>( - reinterpret_cast *>(destination.device_data()), - reinterpret_cast const *>(source.device_data()) - ); +TEST(NumericConversion, f32_to_fe5m2_rn) { + int const kN = 1; + using Source = float; + using Destination = cutlass::float_e5m2_t; + test::core::kernel::run_test(); +} - destination.sync_host(); +TEST(NumericConversion, f32_to_fe5m2_rn_array) { + int const kN = 27; + using Source = float; + using Destination = cutlass::float_e5m2_t; - for (int i = 0; i < kN; ++i) { - EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i])); - } + test::core::kernel::run_test(); + +} + +TEST(NumericConversion, f16_to_fe4m3_rn) { + int const kN = 1; + using Source = cutlass::half_t; + using Destination = cutlass::float_e4m3_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, f16_to_fe4m3_rn_array) { + int const kN = 27; + using Source = cutlass::half_t; + using Destination = cutlass::float_e4m3_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, f16_to_fe5m2_rn) { + int const kN = 1; + using Source = cutlass::half_t; + using Destination = cutlass::float_e5m2_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, f16_to_fe5m2_rn_array) { + int const kN = 27; + using Source = cutlass::half_t; + using Destination = cutlass::float_e5m2_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, bf16_to_fe4m3_rn) { + int const kN = 1; + using Source = cutlass::bfloat16_t; + using Destination = cutlass::float_e4m3_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, bf16_to_fe4m3_rn_array) { + int const kN = 27; + using Source = cutlass::bfloat16_t; + using Destination = cutlass::float_e4m3_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, bf16_to_fe5m2_rn) { + int const kN = 1; + using Source = cutlass::bfloat16_t; + using Destination = cutlass::float_e5m2_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, bf16_to_fe5m2_rn_array) { + int const kN = 27; + using Source = cutlass::bfloat16_t; + using Destination = cutlass::float_e5m2_t; + test::core::kernel::run_test(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(NumericConversion, fe4m3_to_fe5m2_rn) { + int const kN = 1; + using Source = cutlass::float_e4m3_t; + using Destination = 
cutlass::float_e5m2_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, fe4m3_to_fe5m2_array) { + int const kN = 27; + using Source = cutlass::float_e4m3_t; + using Destination = cutlass::float_e5m2_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, fe5m2_to_fe4m3_rn) { + int const kN = 1; + using Source = cutlass::float_e5m2_t; + using Destination = cutlass::float_e4m3_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, fe5m2_to_fe4m3_array) { + int const kN = 27; + using Source = cutlass::float_e5m2_t; + using Destination = cutlass::float_e4m3_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, fe4m3_to_f32_rn) { + int const kN = 1; + using Source = cutlass::float_e4m3_t; + using Destination = float; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, fe4m3_to_f32_array) { + int const kN = 27; + using Source = cutlass::float_e4m3_t; + using Destination = float; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, fe5m2_to_f32_rn) { + int const kN = 1; + using Source = cutlass::float_e5m2_t; + using Destination = float; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, fe5m2_to_f32_array) { + int const kN = 27; + using Source = cutlass::float_e5m2_t; + using Destination = float; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, fe4m3_to_f16_rn) { + int const kN = 1; + using Source = cutlass::float_e4m3_t; + using Destination = cutlass::half_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, fe4m3_to_f16_array) { + int const kN = 27; + using Source = cutlass::float_e4m3_t; + using Destination = cutlass::half_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, fe5m2_to_f16_rn) { + int const kN = 1; + using Source = cutlass::float_e5m2_t; + using Destination = cutlass::half_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, fe5m2_to_f16_array) { + int const kN = 27; + using Source = cutlass::float_e5m2_t; + using Destination = cutlass::half_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, fe4m3_to_bf16_rn) { + int const kN = 1; + using Source = cutlass::float_e4m3_t; + using Destination = cutlass::bfloat16_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, fe4m3_to_bf16_array) { + int const kN = 27; + using Source = cutlass::float_e4m3_t; + using Destination = cutlass::bfloat16_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, fe5m2_to_bf16_rn) { + int const kN = 1; + using Source = cutlass::float_e5m2_t; + using Destination = cutlass::bfloat16_t; + test::core::kernel::run_test(); +} + +TEST(NumericConversion, fe5m2_to_bf16_array) { + int const kN = 27; + using Source = cutlass::float_e5m2_t; + using Destination = cutlass::bfloat16_t; + test::core::kernel::run_test(); } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/epilogue/thread/activation.cu b/test/unit/epilogue/thread/activation.cu index 041fded4..1ec24dd5 100644 --- a/test/unit/epilogue/thread/activation.cu +++ b/test/unit/epilogue/thread/activation.cu @@ -34,6 +34,7 @@ #include "../../common/cutlass_unit_test.h" +#include "cutlass/layout/layout.h" #include "cutlass/epilogue/thread/activation.h" #include "cutlass/util/host_tensor.h" diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index f7977cda..bd649c84 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -68,7 +68,12 @@ 
cutlass_test_unit_add_executable( simt_sgemm_nt_sm80.cu simt_sgemm_tn_sm80.cu - + + simt_cgemm_nt_sm80.cu + simt_cgemm_tn_sm80.cu + + simt_f8gemm_tn_sm50.cu + simt_cgemm_nn_sm50.cu simt_cgemm_nt_sm50.cu simt_cgemm_tn_sm50.cu @@ -239,6 +244,13 @@ cutlass_test_unit_add_executable( gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm80.cu + # SM90 device level tests + gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu + gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu + gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu + gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu + gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu + gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu ) cutlass_test_unit_add_executable( @@ -430,7 +442,7 @@ cutlass_test_unit_add_executable( BATCH_SOURCES ON BATCH_SIZE 4 - ## SYRK + ## SYRK # Syrk SM80 f64 tests syrk_f64n_f64t_tensor_op_f64_sm80.cu syrk_f64t_f64n_tensor_op_f64_sm80.cu @@ -452,6 +464,12 @@ cutlass_test_unit_add_executable( syrk_cf32n_cf32t_tensor_op_fast_f32_sm80.cu syrk_cf32n_cf32n_tensor_op_fast_f32_sm80.cu + # Syrk SM90 f64 tests + syrk_f64_f64_tensor_op_f64_sm90.cu + + # Syrk SM90 complex f64 tests + syrk_cf64_cf64_tensor_op_f64_sm90.cu + ## HERK # Herk SM80 complex f64 tests herk_cf64h_cf64n_tensor_op_f64_sm80.cu @@ -460,6 +478,9 @@ cutlass_test_unit_add_executable( herk_cf32h_cf32n_tensor_op_f32_sm80.cu herk_cf32h_cf32n_tensor_op_fast_f32_sm80.cu + # Herk SM90 complex f64 tests + herk_cf64_cf64_tensor_op_f64_sm90.cu + ## TRMM # Trmm SM80 f64 tests trmm_f64n_f64n_f64t_tensor_op_f64_ls_sm80.cu @@ -486,6 +507,12 @@ cutlass_test_unit_add_executable( trmm_cf32n_cf32n_cf32t_tensor_op_f32_sm80.cu trmm_cf32n_cf32n_cf32t_tensor_op_fast_f32_sm80.cu + # Trmm SM90 f64 tests + trmm_f64_f64_f64_tensor_op_f64_sm90.cu + + # Trmm SM90 complex f64 tests + trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu + ## SYR2K # Syr2k SM80 f64 tests syr2k_f64n_f64t_tensor_op_f64_sm80.cu @@ -508,6 +535,12 @@ cutlass_test_unit_add_executable( syr2k_cf32n_cf32n_tensor_op_fast_f32_sm80.cu syr2k_cf32n_cf32t_tensor_op_fast_f32_sm80.cu + # Syr2k SM90 f64 tests + syr2k_f64_f64_tensor_op_f64_sm90.cu + + # Syr2k SM90 complex f64 tests + syr2k_cf64_cf64_tensor_op_f64_sm90.cu + ## HER2K # Her2k SM80 complex f64 tests her2k_cf64n_cf64n_tensor_op_f64_sm80.cu @@ -516,6 +549,9 @@ cutlass_test_unit_add_executable( her2k_cf32h_cf32n_tensor_op_f32_sm80.cu her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu + # Her2k SM90 complex f64 tests + her2k_cf64_cf64_tensor_op_f64_sm90.cu + ## SYMM # Symm SM80 f64 tests symm_f64n_f64n_tensor_op_f64_ls_sm80.cu @@ -546,6 +582,12 @@ cutlass_test_unit_add_executable( symm_cf32n_cf32n_tensor_op_fast_f32_ls_sm80.cu symm_cf32n_cf32n_tensor_op_fast_f32_rs_sm80.cu + # Symm SM90 f64 tests + symm_f64_f64_tensor_op_f64_sm90.cu + + # Symm SM90 complex f64 tests + symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu + # Hemm SM80 complex f64 tests hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_sm80.cu hemm_cf64h_cf64n_cf64n_tensor_op_rs_f64_sm80.cu @@ -556,7 +598,10 @@ cutlass_test_unit_add_executable( hemm_cf32h_cf32n_tensor_op_f32_rs_sm80.cu hemm_cf32h_cf32n_tensor_op_fast_f32_ls_sm80.cu hemm_cf32h_cf32n_tensor_op_fast_f32_rs_sm80.cu - ) + + # Hemm SM90 complex f64 tests + hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu +) cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_grouped_blas3 @@ -582,3 +627,13 @@ cutlass_test_unit_add_executable( ) endif() + +if (NOT CUDA_COMPILER MATCHES "[Cc]lang") + +cutlass_test_unit_add_executable( + 
cutlass_test_unit_gemm_device_broadcast + + gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu +) + +endif() diff --git a/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu index f960dfcc..93d677d7 100644 --- a/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu @@ -71,7 +71,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x256x512_64x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) { @@ -93,7 +93,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) { @@ -115,7 +115,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x256x512_64x64x512) { @@ -137,7 +137,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x256x512_64x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x64x512_64x64x512) { @@ -159,7 +159,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x64x512_64x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) { @@ -181,7 +181,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) { @@ -203,7 +203,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) { @@ -225,7 +225,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu index 59e1e367..6cb761bc 100644 --- a/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu +++ b/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu @@ -70,7 +70,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x256x1024_64x64x1024) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x1024_64x64x1024) { @@ -90,7 +90,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x1024_64x64x1024) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x1024_64x64x1024) { @@ -111,7 +111,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x1024_64x64x1024) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x64x1024_64x64x1024) { @@ -131,7 +131,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x64x1024_64x64x1024) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x1024_64x64x1024) { @@ -151,7 +151,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x1024_64x64x1024) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x1024_32x64x1024) { @@ -171,7 +171,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x1024_32x64x1024) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x1024_64x32x1024) { @@ -191,7 +191,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x1024_64x32x1024) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x1024_32x32x1024) { @@ -211,7 +211,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x1024_32x32x1024) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x256x512_64x64x512) { @@ -231,7 +231,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 
128x256x512_64x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) { @@ -251,7 +251,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) { @@ -271,7 +271,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x64x512_64x64x512) { @@ -291,7 +291,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 256x64x512_64x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x512_64x64x512) { @@ -311,7 +311,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x256x512_64x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) { @@ -331,7 +331,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) { @@ -351,7 +351,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) { @@ -371,7 +371,7 @@ TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } //////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu index e6368203..99596dcb 100644 --- a/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu @@ -83,7 +83,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x256x512_64x64x512_8x8 cutlass::arch::OpXorPopc >; - 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8x128) { @@ -114,7 +114,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8x128) { @@ -145,7 +145,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x128) { @@ -176,7 +176,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x128) { @@ -207,7 +207,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x128) { @@ -238,6 +238,6 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x1 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } #endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED diff --git a/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu index c1a32d9d..e83d4fe5 100644 --- a/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu @@ -71,7 +71,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x256x512_64x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 256x128x512_64x64x512) { @@ -93,7 +93,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 256x128x512_64x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x128x512_64x64x512) { @@ -115,7 +115,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x128x512_64x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x256x512_64x64x512) { @@ -137,7 +137,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x256x512_64x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; - 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 256x64x512_64x64x512) { @@ -159,7 +159,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 256x64x512_64x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x128x512_32x64x512) { @@ -180,7 +180,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x128x512_32x64x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x64x512_64x32x512) { @@ -202,7 +202,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x64x512_64x32x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x64x512_32x32x512) { @@ -224,7 +224,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x64x512_32x32x512) { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu index f359db81..f9398d64 100644 --- a/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu @@ -83,7 +83,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x256x512_64x64x512_8x8 cutlass::arch::OpXorPopc >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8x128) { @@ -114,7 +114,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 256x128x512_64x64x512_8x8 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8x128) { @@ -145,7 +145,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x128x512_64x64x512_8x8 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x128) { @@ -176,7 +176,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x128x512_32x64x512_8x8x 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x128) { @@ -207,7 +207,7 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 128x64x512_64x32x512_8x8x 2, 128, 128, false, 
cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x128) { @@ -238,6 +238,6 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x1 2, 128, 128, false, cutlass::arch::OpXorPopc>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } #endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED diff --git a/test/unit/gemm/device/gemm_bf16t_bf16t_bf16t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_bf16t_bf16t_bf16t_tensor_op_f32_sm80.cu index 33a4afd2..44908759 100644 --- a/test/unit/gemm/device/gemm_bf16t_bf16t_bf16t_tensor_op_f32_sm80.cu +++ b/test/unit/gemm/device/gemm_bf16t_bf16t_bf16t_tensor_op_f32_sm80.cu @@ -65,7 +65,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x256x64_64x64x64) { ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x64_64x64x64) { @@ -83,7 +83,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x64_64x64x64) { ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x64_64x64x64) { @@ -101,7 +101,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x64_64x64x64) { ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x64_64x64x64) { @@ -119,7 +119,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x64_64x64x64) { ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x64_64x64x64) { @@ -137,7 +137,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x64_64x64x64) { ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x64_32x64x64) { @@ -155,7 +155,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x64_32x64x64) { ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x64_64x32x64) { @@ -173,7 +173,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x64_64x32x64) { ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 
64x64x64_32x32x64) { @@ -191,7 +191,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x64x64_32x32x64) { ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x256x32_64x64x32) { @@ -209,7 +209,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x256x32_64x64x32) { ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x32_64x64x32) { @@ -227,7 +227,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x128x32_64x64x32) { ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x32_64x64x32) { @@ -245,7 +245,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x128x32_64x64x32) { ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x32_64x64x32) { @@ -263,7 +263,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 256x64x32_64x64x32) { ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x32_64x64x32) { @@ -281,7 +281,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x256x32_64x64x32) { ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x32_32x64x32) { @@ -299,7 +299,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x128x32_32x64x32) { ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x32_64x32x32) { @@ -317,7 +317,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 128x64x32_64x32x32) { ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x64x32_32x32x32) { @@ -335,7 +335,7 @@ TEST(SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32, 64x64x32_32x32x32) { ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } //////////////////////////////////////////////////////////////////////////////// diff --git 
a/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu b/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu index 722d0d98..54bb64ba 100644 --- a/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu +++ b/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu @@ -56,7 +56,7 @@ // Operands data type: complex // Rounding: float -> tfloat32_t (half_ulp_truncate) // Instruction operand data type: tfloat32_t (real part) and tfloat32_t (imaginary part) -// Math instruction: MMA.1688.F32.TF32 +// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 // Instruction output/accumulation data type: f32 (real part) and f32 (imaginary part) // Output data type: complex ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu b/test/unit/gemm/device/gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu index 1b263a4c..dcb81f95 100644 --- a/test/unit/gemm/device/gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu +++ b/test/unit/gemm/device/gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu @@ -56,7 +56,7 @@ // Operands data type: complex // Rounding: float -> tfloat32_t (round to nearest) // Instruction operand data type: tfloat32_t (real part) and tfloat32_t (imaginary part) -// Math instruction: MMA.1688.F32.TF32 +// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 // Instruction output/accumulation data type: f32 (real part) and f32 (imaginary part) // Output data type: complex ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu b/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu new file mode 100644 index 00000000..4c247012 --- /dev/null +++ b/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu @@ -0,0 +1,198 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface with Hopper FP64 +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_complex.h" + + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<16, 16, 8>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x16_16x32x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + 
cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 64x64x8_16x32x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<16, 32, 8>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu b/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu new file mode 100644 index 00000000..5c5537c9 --- /dev/null +++ b/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu @@ -0,0 +1,252 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface with Hopper FP64 +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_complex.h" + + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x16_16x16x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x8_16x16x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<16, 16, 8>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_16x32x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + 
cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_16x32x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<16, 32, 8>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x16_32x32x16) { + + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 64x64x8_32x32x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu b/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu new file mode 100644 index 00000000..3ead0b7a --- /dev/null +++ b/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu @@ -0,0 +1,197 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface with Hopper FP64 +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_complex.h" + + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<16, 16, 8>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x8_32x16x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + 
Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<32, 16, 8>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 64x64x16_32x16x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAddGaussianComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu b/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu new file mode 100644 index 00000000..220e65aa --- /dev/null +++ b/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu @@ -0,0 +1,305 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface with Hopper FP64 +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_complex.h" + + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x8_16x16x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<16, 16, 8>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x8_32x32x8) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + 
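Every SM90 test added in this file varies only the threadblock and warp shapes around the same recipe: instantiate cutlass::gemm::device::GemmComplex for cutlass::arch::Sm90 with the 16x8x4 double-precision tensor-op instruction shape and hand the resulting type to the testbed (the Gaussian-complex variants earlier in this patch use the same skeleton plus the extra transform/operator arguments). The sketch below factors that recipe into a helper for the row-major A x column-major B ("t/n") layout combination used in this file. It is a sketch only: the cutlass::complex<double> element type, the <Gemm> argument on TestAllGemmComplex, and the bool return of that testbed call are assumptions based on the usual CUTLASS spellings, and run_sm90_cf64_tn_test is a hypothetical helper name, not part of the patch.

// Sketch only: the shape-parameterized pattern shared by the SM90 TESTs in this file.
// Assumes the usual CUTLASS spellings cutlass::complex<double>, GemmComplex<...>, and
// test::gemm::device::TestAllGemmComplex<Gemm>() from testbed_complex.h.
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_complex.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "testbed_complex.h"

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

template <typename ThreadblockShape, typename WarpShape, int Stages = 3>
bool run_sm90_cf64_tn_test() {
  using Element = cutlass::complex<double>;   // assumed: f64 real and imaginary parts

  using Gemm = cutlass::gemm::device::GemmComplex<
    Element, cutlass::layout::RowMajor,       // A is row-major ("t")
    Element, cutlass::layout::ColumnMajor,    // B is column-major ("n")
    Element, cutlass::layout::RowMajor,       // C/D are row-major ("t")
    Element,                                  // accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm90,
    ThreadblockShape,
    WarpShape,
    cutlass::gemm::GemmShape<16, 8, 4>,       // FP64 tensor-op instruction shape
    cutlass::epilogue::thread::LinearCombination<Element, 1, Element, Element>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    Stages>;

  // Assumed to exercise the full testbed (problem-size sweep plus host reference check).
  return test::gemm::device::TestAllGemmComplex<Gemm>();
}

// Usage mirroring the 64x64x8_32x32x8 case above:
//   EXPECT_TRUE((run_sm90_cf64_tn_test<cutlass::gemm::GemmShape<64, 64, 8>,
//                                      cutlass::gemm::GemmShape<32, 32, 8>>()));

#endif // CUTLASS_ARCH_MMA_SM90_SUPPORTED
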
+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x8_32x32x8) {
+
+  using Element = cutlass::complex<double>;
+
+  using Gemm = cutlass::gemm::device::GemmComplex<
+    Element,
+    cutlass::layout::RowMajor,
+    Element,
+    cutlass::layout::ColumnMajor,
+    Element,
+    cutlass::layout::RowMajor,
+    Element,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm90,
+    cutlass::gemm::GemmShape<64, 128, 8>,
+    cutlass::gemm::GemmShape<32, 32, 8>,
+    cutlass::gemm::GemmShape<16, 8, 4>,
+    cutlass::epilogue::thread::LinearCombination<
+      Element,
+      1,
+      Element,
+      Element
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    3
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
+}
+
+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x8_32x32x8) {
+
+  using Element = cutlass::complex<double>;
+
+  using Gemm = cutlass::gemm::device::GemmComplex<
+    Element,
+    cutlass::layout::RowMajor,
+    Element,
+    cutlass::layout::ColumnMajor,
+    Element,
+    cutlass::layout::RowMajor,
+    Element,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm90,
+    cutlass::gemm::GemmShape<64, 128, 8>,
+    cutlass::gemm::GemmShape<32, 32, 8>,
+    cutlass::gemm::GemmShape<16, 8, 4>,
+    cutlass::epilogue::thread::LinearCombination<
+      Element,
+      1,
+      Element,
+      Element
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    3
+  >;
+
+
+  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x16_16x16x16) {
+
+  using Element = cutlass::complex<double>;
+
+  using Gemm = cutlass::gemm::device::GemmComplex<
+    Element,
+    cutlass::layout::RowMajor,
+    Element,
+    cutlass::layout::ColumnMajor,
+    Element,
+    cutlass::layout::RowMajor,
+    Element,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm90,
+    cutlass::gemm::GemmShape<32, 32, 16>,
+    cutlass::gemm::GemmShape<16, 16, 16>,
+    cutlass::gemm::GemmShape<16, 8, 4>,
+    cutlass::epilogue::thread::LinearCombination<
+      Element,
+      1,
+      Element,
+      Element
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    3
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
+}
+
+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x64x16_32x32x16) {
+
+  using Element = cutlass::complex<double>;
+
+  using Gemm = cutlass::gemm::device::GemmComplex<
+    Element,
+    cutlass::layout::RowMajor,
+    Element,
+    cutlass::layout::ColumnMajor,
+    Element,
+    cutlass::layout::RowMajor,
+    Element,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm90,
+    cutlass::gemm::GemmShape<64, 64, 16>,
+    cutlass::gemm::GemmShape<32, 32, 16>,
+    cutlass::gemm::GemmShape<16, 8, 4>,
+    cutlass::epilogue::thread::LinearCombination<
+      Element,
+      1,
+      Element,
+      Element
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    3
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllGemmComplex<Gemm>());
+}
+
+TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 64x128x16_32x32x16) {
+
+  using Element = cutlass::complex<double>;
+
+  using Gemm = cutlass::gemm::device::GemmComplex<
+    Element,
+    cutlass::layout::RowMajor,
+    Element,
+    cutlass::layout::ColumnMajor,
+    Element,
+    cutlass::layout::RowMajor,
+    Element,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm90,
+    cutlass::gemm::GemmShape<64, 128, 16>,
+    cutlass::gemm::GemmShape<32, 32, 16>,
+    cutlass::gemm::GemmShape<16, 8, 4>,
+    cutlass::epilogue::thread::LinearCombination<
+      Element,
+      1,
+      Element,
+      Element
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    3
+
>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x16_32x32x16) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<64, 128, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu index ed40a285..785840b7 100644 --- a/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu +++ b/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu @@ -45,6 +45,7 @@ #include "cutlass/util/tensor_view_io.h" #include "testbed.h" +#include "testbed_universal.h" #if (CUTLASS_ARCH_MMA_SM80_SUPPORTED) @@ -104,6 +105,44 @@ CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x64_64x64x EXPECT_TRUE(test::gemm::device::TestAllGemm()); } ) +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64_sk, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, cutlass::layout::ColumnMajor, + cutlass::half_t, cutlass::layout::RowMajor, + ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} ) + +CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32n_tensor_op_f32, 128x128x64_64x64x64_sk, { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, cutlass::layout::ColumnMajor, + cutlass::half_t, cutlass::layout::RowMajor, + ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} ) + CUTLASS_TEST_L1(SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64, { using ElementOutput = float; using ElementAccumulator = float; diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu 
b/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu new file mode 100644 index 00000000..3a011c50 --- /dev/null +++ b/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu @@ -0,0 +1,440 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for GEMM + broadcast interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" + +#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_with_broadcast.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination_bias_relu.h" +#include "cutlass/epilogue/thread/linear_combination_residual_block.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_elementwise.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" + +template +struct TestbedUtils { + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; // Input A + cutlass::HostTensor tensor_B; // Input B + cutlass::HostTensor tensor_C; // Input C + cutlass::HostTensor tensor_D1; // Input D + cutlass::HostTensor tensor_D2; // Input D + cutlass::HostTensor tensor_Y1; // Input Y + cutlass::HostTensor tensor_Y2; // Input Y + cutlass::HostTensor tensor_Y_ref; + + // + // Methods + // + + TestbedUtils( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // TODO: Implement the rest + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize({1, problem_size.n()}); + tensor_D1.resize(problem_size.mn()); + 
tensor_D2.resize(problem_size.mn()); + tensor_Y1.resize(problem_size.mn()); + tensor_Y2.resize(problem_size.mn()); + tensor_Y_ref.resize(problem_size.mn()); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + // Initialize D data to smaller data range. This helps avoid large roundoff errors. + int d_scope_min = -2; + int d_scope_max = 2; + cutlass::reference::host::TensorFillRandomUniform(tensor_D1.host_view(), seed + 2016, d_scope_max, d_scope_min, 0); + cutlass::reference::host::TensorFillRandomUniform(tensor_D2.host_view(), seed + 2015, d_scope_max, d_scope_min, 0); + + EXPECT_TRUE(initialize_tensor(tensor_Y1.host_view(), cutlass::Distribution::AllZeros, 0)); + EXPECT_TRUE(initialize_tensor(tensor_Y2.host_view(), cutlass::Distribution::AllZeros, 0)); + EXPECT_TRUE(initialize_tensor(tensor_Y_ref.host_view(), cutlass::Distribution::AllZeros, 0)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = GemmElement(1); + tensor_B.host_view().at({0, 0}) = GemmElement(1); + tensor_C.host_view().at({0, 0}) = GemmElement(1); + tensor_D1.host_view().at({0, 0}) = GemmElement(1); + tensor_D2.host_view().at({0, 0}) = GemmElement(1); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D1.sync_device(); + tensor_D2.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, cutlass::HostTensor& tensor_Y_ref, cutlass::HostTensor& tensor_Y) { + + tensor_Y_ref.sync_host(); + tensor_Y.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Y_ref.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Y.host_view()), 0); + + bool passed = true; + float norm_diff = 0; + + norm_diff = cutlass::reference::host::TensorNormDiff(tensor_Y_ref.host_view(), tensor_Y.host_view(), float()); + passed = (norm_diff <= 0.1f); + EXPECT_LT(norm_diff, 0.1f) << " tensor_Y is incorrect"; + + + if (!passed) { + std::ofstream file("errors_testbed_gemm_broadcast_new.txt"); + + + file + << "problem: " << problem_size << "\n\n"; + + file + << "capacity: \n" + << "A: " << tensor_A.capacity() + << "\nB: " << tensor_B.capacity() + << "\nC: " << tensor_C.capacity() + << "\nD1: " << tensor_D1.capacity() + << "\nD2: " << tensor_D2.capacity() + << "\nY: " << tensor_Y.capacity() + << "\n\n" + << "\nY_ref: " << tensor_Y_ref.capacity() + << "\n\n"; + file + << "A =\n" << tensor_A.host_view() + << "\n\nB =\n" << tensor_B.host_view() + << "\n\nC =\n" << tensor_C.host_view() + << "\n\nD1 =\n" << tensor_D1.host_view() + << "\n\nD2 =\n" << tensor_D2.host_view() + << "\n\nY =\n" << tensor_Y.host_view() + << "\n\nY_ref =\n" << tensor_Y_ref.host_view(); + } + + return passed; + } +}; + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + 
+TEST(SM80_Device_GemmWithBroadcast_f16t_f16n_f16t_tensor_op_f16, 128x128_32x3_64x64x32_16x8x16) { + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using OpClass = cutlass::arch::OpClassTensorOp; + using ArchTag = cutlass::arch::Sm80; + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; + const int kStages = 3; + + const int batch_count = 1; + const cutlass::half_t alpha(1); + const cutlass::half_t beta(1); + + const int M = 1024; + const int K = 10240; + const int N = 512; + cutlass::gemm::GemmCoord problem{M, N, K}; + + const int batch_stride_A = 0; + const int batch_stride_B = 0; + const int batch_stride_C1 = 0; + const int batch_stride_C2 = 0; + const int batch_stride_D = 0; + const int batch_stride_Vector = 0; + const int batch_stride_Tensor = 0; + + const int64_t lda = LayoutA::packed({problem.m(), problem.k()}).stride(0); + const int64_t ldb = LayoutB::packed({problem.k(), problem.n()}).stride(0); + const int64_t ldc1 = LayoutC::packed({problem.m(), problem.n()}).stride(0); + const int64_t ldc2 = LayoutC::packed({problem.m(), problem.n()}).stride(0); + const int64_t ldd = LayoutC::packed({problem.m(), problem.n()}).stride(0); + const int64_t ldv = 0; + const int64_t ldt = 0; + + TestbedUtils utils; + utils.initialize(problem); + + // + // Create reference Gemm + // + using GemmRef = cutlass::gemm::device::GemmUniversal< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, + OpClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + ThreadblockSwizzle, kStages>; + + typename GemmRef::Arguments args_ref{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem, + batch_count, + {alpha, beta}, + utils.tensor_A.device_data(), + utils.tensor_B.device_data(), + utils.tensor_C.device_data(), + utils.tensor_Y_ref.device_data(), + batch_stride_A, + batch_stride_B, + batch_stride_C1, + batch_stride_D, + lda, + ldb, + ldv, + ldd, + }; + + GemmRef gemm_op_ref; + size_t workspace_size_ref = GemmRef::get_workspace_size(args_ref); + cutlass::device_memory::allocation workspace_ref(workspace_size_ref); + cutlass::Status status = gemm_op_ref.initialize(args_ref, workspace_ref.get()); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); + + status = gemm_op_ref(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); + + // + // Create GemmWithBroadcast from single source + // + using GemmSingle = cutlass::gemm::device::GemmUniversalWithBroadcast< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, + OpClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, + cutlass::epilogue::thread::LinearCombinationResidualBlock< + ElementOutput, ElementAccumulator, ElementAccumulator, + ElementAccumulator, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::thread::Identity, cutlass::multiplies, cutlass::epilogue::thread::Identity>, + 
ThreadblockSwizzle, kStages>; + + typename GemmSingle::Arguments args_single{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem, + batch_count, + {alpha, beta}, + utils.tensor_A.device_data(), + utils.tensor_B.device_data(), + utils.tensor_D1.device_data(), + utils.tensor_Y1.device_data(), + utils.tensor_C.device_data(), + /* ptr_Tensor = */ nullptr, + batch_stride_A, + batch_stride_B, + batch_stride_C1, + batch_stride_D, + batch_stride_Vector, + batch_stride_Tensor, + lda, + ldb, + ldc1, + ldd, + ldv, + ldt + }; + + GemmSingle gemm_op_single; + size_t workspace_size_single = GemmSingle::get_workspace_size(args_single); + cutlass::device_memory::allocation workspace_single(workspace_size_single); + status = gemm_op_single.initialize(args_single, workspace_single.get()); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); + + status = gemm_op_single(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); + + // Compute the broadcast on the reference previously computed and compare results + utils.tensor_Y_ref.sync_host(); + cutlass::reference::host::TensorMul(utils.tensor_Y_ref.host_view(), utils.tensor_D1.host_view()); + utils.tensor_Y_ref.sync_device(); + utils.compare_reference(problem, utils.tensor_Y_ref, utils.tensor_Y1); + + // + // Create GemmWithBroadcast from two sources + // + using GemmDouble = cutlass::gemm::device::GemmUniversalWithBroadcast< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, ElementAccumulator, + OpClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, + cutlass::epilogue::thread::LinearCombinationResidualBlock< + ElementOutput, ElementAccumulator, ElementAccumulator, + ElementAccumulator, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::thread::Identity, cutlass::multiplies, cutlass::epilogue::thread::Identity, cutlass::plus>, + ThreadblockSwizzle, kStages>; + + typename GemmDouble::Arguments args_double{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem, + batch_count, + {alpha, beta}, + utils.tensor_A.device_data(), + utils.tensor_B.device_data(), + utils.tensor_D1.device_data(), + utils.tensor_D2.device_data(), + utils.tensor_Y2.device_data(), + utils.tensor_C.device_data(), + /* ptr_Tensor = */ nullptr, + batch_stride_A, + batch_stride_B, + batch_stride_C1, + batch_stride_C2, + batch_stride_D, + batch_stride_Vector, + batch_stride_Tensor, + lda, + ldb, + ldc1, + ldc2, + ldd, + ldv, + ldt + }; + + GemmDouble gemm_op_double; + size_t workspace_size_double = GemmDouble::get_workspace_size(args_double); + cutlass::device_memory::allocation workspace_double(workspace_size_double); + status = gemm_op_double.initialize(args_double, workspace_double.get()); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); + + status = gemm_op_double(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << cutlassGetStatusString(status); + + // Compute the broadcast on the reference previously computed and compare results + utils.tensor_Y_ref.sync_host(); + cutlass::reference::host::TensorAdd(utils.tensor_Y_ref.host_view(), utils.tensor_D2.host_view()); + utils.tensor_Y_ref.sync_device(); + utils.compare_reference(problem, utils.tensor_Y_ref, utils.tensor_Y2); +} + +#endif diff --git a/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu b/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu new file mode 100644 index 00000000..650a6d3f --- /dev/null +++ b/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu @@ 
-0,0 +1,223 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file
+    \brief Tests for device-wide GEMM interface with Hopper FP64
+*/
+
+#include <iostream>
+
+#include "../../common/cutlass_unit_test.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/device/gemm.h"
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/reference/host/gemm.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/tensor_view_io.h"
+
+#include "testbed.h"
+
+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 32x32x16_16x16x16_16x8x4) {
+
+  using ElementOutput = double;
+  using ElementAccumulator = double;
+  using ElementCompute = double;
+
+  using Gemm = cutlass::gemm::device::Gemm<
+    double,
+    cutlass::layout::ColumnMajor,
+    double,
+    cutlass::layout::RowMajor,
+    ElementOutput,
+    cutlass::layout::RowMajor,
+    ElementAccumulator,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm90,
+    cutlass::gemm::GemmShape<32, 32, 16>,
+    cutlass::gemm::GemmShape<16, 16, 16>,
+    cutlass::gemm::GemmShape<16, 8, 4>,
+    cutlass::epilogue::thread::LinearCombination<
+      ElementOutput,
+      1,
+      ElementAccumulator,
+      ElementCompute
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    4
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 64x64x16_32x32x16_16x8x4) {
+
+  using ElementOutput = double;
+  using ElementAccumulator = double;
+  using ElementCompute = double;
+
+  using Gemm = cutlass::gemm::device::Gemm<
+    double,
+    cutlass::layout::ColumnMajor,
+    double,
+    cutlass::layout::RowMajor,
+    ElementOutput,
+    cutlass::layout::RowMajor,
+    ElementAccumulator,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm90,
+    cutlass::gemm::GemmShape<64, 64, 16>,
+    cutlass::gemm::GemmShape<32, 32, 16>,
+    cutlass::gemm::GemmShape<16, 8, 4>,
+    cutlass::epilogue::thread::LinearCombination<
+      ElementOutput,
+      1,
+      ElementAccumulator,
+      ElementCompute
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    4
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 128x64x16_64x32x16_16x8x4) {
+
+  using ElementOutput = double;
+  using ElementAccumulator = double;
+  using ElementCompute = double;
+
+  using Gemm = cutlass::gemm::device::Gemm<
+    double,
+    cutlass::layout::ColumnMajor,
+    double,
+    cutlass::layout::RowMajor,
+    ElementOutput,
+    cutlass::layout::RowMajor,
+    ElementAccumulator,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm90,
+    cutlass::gemm::GemmShape<128, 64, 16>,
+    cutlass::gemm::GemmShape<64, 32, 16>,
+    cutlass::gemm::GemmShape<16, 8, 4>,
+    cutlass::epilogue::thread::LinearCombination<
+      ElementOutput,
+      1,
+      ElementAccumulator,
+      ElementCompute
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    4
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllGemm<Gemm>());
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 64x128x16_32x64x16_16x8x4) {
+
+  using ElementOutput = double;
+  using ElementAccumulator = double;
+  using ElementCompute =
double; + + using Gemm = cutlass::gemm::device::Gemm< + double, + cutlass::layout::ColumnMajor, + double, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<64, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 128x128x16_32x64x16_16x8x4) { + + using ElementOutput = double; + using ElementAccumulator = double; + using ElementCompute = double; + + using Gemm = cutlass::gemm::device::Gemm< + double, + cutlass::layout::ColumnMajor, + double, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu b/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu new file mode 100644 index 00000000..99f0c337 --- /dev/null +++ b/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu @@ -0,0 +1,223 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface with Hopper FP64 +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 32x32x16_16x16x16_16x8x4) { + + using ElementOutput = double; + using ElementAccumulator = double; + using ElementCompute = double; + + using Gemm = cutlass::gemm::device::Gemm< + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 64x64x16_32x32x16_16x8x4) { + + using ElementOutput = double; + using ElementAccumulator = double; + using ElementCompute = double; + + using Gemm = cutlass::gemm::device::Gemm< + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 64x128x16_32x64x16_16x8x4) { + + using ElementOutput = double; + using ElementAccumulator = double; + using ElementCompute = double; + + using Gemm = cutlass::gemm::device::Gemm< + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + 
cutlass::arch::Sm90, + cutlass::gemm::GemmShape<64, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 128x64x16_64x32x16_16x8x4) { + + using ElementOutput = double; + using ElementAccumulator = double; + using ElementCompute = double; + + using Gemm = cutlass::gemm::device::Gemm< + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<128, 64, 16>, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 128x128x16_32x64x16_16x8x4) { + + using ElementOutput = double; + using ElementAccumulator = double; + using ElementCompute = double; + + using Gemm = cutlass::gemm::device::Gemm< + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // if (CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu index 9f0b16a1..bbdf60c8 100644 --- a/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu @@ -81,7 +81,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x256x128_64x64x128) { 2 >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x128x128_64x64x128) { @@ -113,7 +113,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x128x128_64x64x128) { 2 >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x128x128_64x64x128) { @@ -145,7 +145,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x128x128_64x64x128) { 2 >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x256x128_64x64x128) { @@ -177,7 +177,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 
64x256x128_64x64x128) { 2 >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x64x128_64x64x128) { @@ -209,7 +209,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 256x64x128_64x64x128) { 2 >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x128x128_32x64x128) { @@ -241,7 +241,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x128x128_32x64x128) { 2 >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x64x128_64x32x128) { @@ -273,7 +273,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 128x64x128_64x32x128) { 2 >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x64x128_32x32x128) { @@ -305,7 +305,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4n_tensor_op_s32, 64x64x128_32x32x128) { 2 >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu index 1a405b7a..28471dbc 100644 --- a/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu @@ -81,7 +81,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x256x128_64x64x128) { 2 >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x128x128_64x64x128) { @@ -113,7 +113,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x128x128_64x64x128) { 2 >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x128x128_64x64x128) { @@ -145,7 +145,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x128x128_64x64x128) { 2 >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x256x128_64x64x128) { @@ -177,7 +177,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x256x128_64x64x128) { 2 >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x64x128_64x64x128) { @@ -209,7 +209,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 256x64x128_64x64x128) { 2 >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x128x128_32x64x128) { @@ -240,7 +240,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x128x128_32x64x128) { 2 >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x64x128_64x32x128) { @@ -272,7 +272,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 128x64x128_64x32x128) { 2 >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } 
TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x64x128_32x32x128) { @@ -304,7 +304,7 @@ TEST(SM75_Device_Gemm_s4t_s4n_s4t_tensor_op_s32, 64x64x128_32x32x128) { 2 >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu new file mode 100644 index 00000000..9a5abfb0 --- /dev/null +++ b/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu @@ -0,0 +1,135 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file
+    \brief Tests for device-wide HEMM interface
+
+
+*/
+
+#include <iostream>
+
+#include "../../common/cutlass_unit_test.h"
+#include "cutlass/blas3.h"
+#include "cutlass/gemm/device/symm.h"
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/reference/host/symm_complex.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/tensor_view_io.h"
+
+#include "testbed_symm_universal.h"
+
+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM90_Device_Hemm_cf64h_cf64n_ls_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) {
+
+  using ElementOutput = cutlass::complex<double>;
+  using ElementAccumulator = cutlass::complex<double>;
+
+  using Hemm = cutlass::gemm::device::Symm<
+    cutlass::complex<double>,
+    cutlass::layout::ColumnMajor,
+    cutlass::SideMode::kLeft,
+    cutlass::FillMode::kLower,
+    cutlass::complex<double>,
+    cutlass::layout::ColumnMajor,
+    ElementOutput,
+    cutlass::layout::ColumnMajor,
+    ElementAccumulator,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm90,
+    cutlass::gemm::GemmShape<32, 32, 16>,
+    cutlass::gemm::GemmShape<16, 16, 16>,
+    cutlass::gemm::GemmShape<16, 8, 4>,
+    cutlass::epilogue::thread::LinearCombination<
+      ElementOutput,
+      1,
+      ElementAccumulator,
+      ElementAccumulator
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    4,
+    1,
+    1,
+    false,
+    cutlass::arch::OpMultiplyAddGaussianComplex,
+    cutlass::BlasMode::kHermitian
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal<Hemm>());
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM90_Device_Hemm_cf64h_cf64n_rs_u_tensor_op_f64, 64x64x16_32x32x16) {
+
+  using ElementOutput = cutlass::complex<double>;
+  using ElementAccumulator = cutlass::complex<double>;
+
+  using Hemm = cutlass::gemm::device::Symm<
+    cutlass::complex<double>,
+    cutlass::layout::ColumnMajor,
+    cutlass::SideMode::kRight,
+    cutlass::FillMode::kUpper,
+    cutlass::complex<double>,
+    cutlass::layout::ColumnMajor,
+    ElementOutput,
+    cutlass::layout::ColumnMajor,
+    ElementAccumulator,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm90,
+    cutlass::gemm::GemmShape<64, 64, 16>,
+    cutlass::gemm::GemmShape<32, 32, 16>,
+    cutlass::gemm::GemmShape<16, 8, 4>,
+    cutlass::epilogue::thread::LinearCombination<
+      ElementOutput,
+      1,
+      ElementAccumulator,
+      ElementAccumulator
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    4,
+    1,
+    1,
+    false,
+    cutlass::arch::OpMultiplyAddComplex,
+    cutlass::BlasMode::kHermitian
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal<Hemm>());
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
diff --git a/test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu
new file mode 100644
index 00000000..b577cfcb
--- /dev/null
+++ b/test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu
@@ -0,0 +1,149 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide HER2K interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" +#include "cutlass/gemm/device/rank_2k.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/rank_2k.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_rank2k_universal.h" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Her2k_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2K = cutlass::gemm::device::Rank2K< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // kStages + 1, // AlignmentA + 1, // AlignmentB + false, // SplitKSerial + cutlass::arch::OpMultiplyAddComplex, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::BlasMode::kHermitian + >; + + EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); 
+} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Her2k_cf64c_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::RowMajor; + + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::RowMajor; + + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2K = cutlass::gemm::device::Rank2K< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // kStages + 1, // AlignmentA + 1, // AlignmentB + false, // SplitKSerial + cutlass::arch::OpMultiplyAddComplex, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kConjugate, + cutlass::BlasMode::kHermitian + >; + + EXPECT_TRUE(test::gemm::device::TestAllRank2KHermitianUniversal()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu new file mode 100644 index 00000000..46604381 --- /dev/null +++ b/test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu @@ -0,0 +1,93 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide HERK interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" +#include "cutlass/gemm/device/rank_k.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/rank_k_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_rank_k_universal.h" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +// HERK operator on CUBLAS_OP_C (row-major + conj) input layouts +TEST(SM90_Device_Herk_cf64h_cf64n_l_tensor_op_f64, 64x64x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::RowMajor; + + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using RankK = cutlass::gemm::device::RankK< + ElementA, + LayoutA, + ElementC, + LayoutC, + cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // kStages + 1, // AlignmentA + false, // SplitKSerial + cutlass::arch::OpMultiplyAddComplex, + cutlass::ComplexTransform::kConjugate, + cutlass::BlasMode::kHermitian + >; + + EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/multistage_testbed.h b/test/unit/gemm/device/multistage_testbed.h index dbc77b9d..1cbbb3db 100644 --- a/test/unit/gemm/device/multistage_testbed.h +++ b/test/unit/gemm/device/multistage_testbed.h @@ -125,7 +125,7 @@ struct MultistageTestbed { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } diff --git a/test/unit/gemm/device/simt_cgemm_nt_sm80.cu b/test/unit/gemm/device/simt_cgemm_nt_sm80.cu new file mode 100644 index 00000000..fe1092c3 --- /dev/null +++ b/test/unit/gemm/device/simt_cgemm_nt_sm80.cu @@ -0,0 +1,265 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_complex.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_complex.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 32x64x8_32x64x1) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 64x64x8_32x64x1) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 128x128x8_32x64x1) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 
64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kConjugate + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 64x128x8_32x64x1) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kConjugate + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 128x64x8_32x64x1) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 8>, + cutlass::gemm::GemmShape<64, 32, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + + +TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 128x128x8_64x64x1) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf32n_cf32t_cf32t_simt_cf32, 128x256x8_64x64x1) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 8>, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kConjugate + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/test/unit/gemm/device/simt_cgemm_tn_sm80.cu b/test/unit/gemm/device/simt_cgemm_tn_sm80.cu new file mode 100644 index 00000000..89b68274 --- /dev/null +++ b/test/unit/gemm/device/simt_cgemm_tn_sm80.cu @@ -0,0 +1,269 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_complex.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_complex.h" + + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 32x64x8_32x64x1) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 64x64x8_32x64x1) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kConjugate + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 128x128x8_32x64x1) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 64x128x8_32x64x1) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kConjugate + >; + + 
EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 128x64x8_64x32x1) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 8>, + cutlass::gemm::GemmShape<64, 32, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kConjugate + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 128x128x8_64x64x1) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +TEST(SM80_Device_Gemm_cf32t_cf32n_cf32t_simt_cf32, 128x256x8_64x64x1) { + + using Element = cutlass::complex; + + using Gemm = cutlass::gemm::device::GemmComplex< + Element, + cutlass::layout::RowMajor, + Element, + cutlass::layout::ColumnMajor, + Element, + cutlass::layout::RowMajor, + Element, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 8>, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + Element, + 1, + Element, + Element>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kConjugate + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmComplex()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/simt_f8gemm_tn_sm50.cu b/test/unit/gemm/device/simt_f8gemm_tn_sm50.cu new file mode 100644 index 00000000..4e362197 --- /dev/null +++ b/test/unit/gemm/device/simt_f8gemm_tn_sm50.cu @@ -0,0 +1,87 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) + +TEST(SM50_Device_Gemm_fe4m3t_fe4m3n_fe4m3t_simt_f32, 32x64x8_32x64x1) { + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementC = cutlass::float_e4m3_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementC, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm50, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementC>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<> + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/simt_qgemm_nn_sm50.cu b/test/unit/gemm/device/simt_qgemm_nn_sm50.cu index 0db0240d..20477698 100644 --- a/test/unit/gemm/device/simt_qgemm_nn_sm50.cu +++ b/test/unit/gemm/device/simt_qgemm_nn_sm50.cu @@ -76,7 +76,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -106,7 +106,7 @@ CUTLASS_TEST_L0(SM50_device_qgemm_nn, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + 
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -136,7 +136,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -166,7 +166,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -196,7 +196,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -226,7 +226,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -256,7 +256,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -286,7 +286,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -316,7 +316,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -346,7 +346,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -376,7 +376,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -406,7 +406,7 @@ CUTLASS_TEST_L0(SM50_device_qgemm_nn, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + 
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -436,7 +436,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -466,7 +466,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -496,7 +496,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -526,7 +526,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -556,7 +556,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -586,7 +586,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -616,7 +616,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -646,7 +646,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -676,7 +676,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -706,7 +706,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -736,7 +736,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -766,7 +766,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -796,7 +796,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nn, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -826,7 +826,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -856,6 +856,6 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nn, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) diff --git a/test/unit/gemm/device/simt_qgemm_nt_sm50.cu b/test/unit/gemm/device/simt_qgemm_nt_sm50.cu index 36fd261d..ac81f438 100644 --- a/test/unit/gemm/device/simt_qgemm_nt_sm50.cu +++ b/test/unit/gemm/device/simt_qgemm_nt_sm50.cu @@ -76,7 +76,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -106,7 +106,7 @@ CUTLASS_TEST_L0(SM50_device_qgemm_nt, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -136,7 +136,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -166,7 +166,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -196,7 
+196,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -226,7 +226,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -256,7 +256,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -286,7 +286,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -316,7 +316,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -346,7 +346,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -376,7 +376,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -406,7 +406,7 @@ CUTLASS_TEST_L0(SM50_device_qgemm_nt, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -436,7 +436,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -466,7 +466,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ 
-496,7 +496,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -526,7 +526,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -556,7 +556,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -586,7 +586,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -616,7 +616,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -646,7 +646,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -676,7 +676,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -706,7 +706,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -736,7 +736,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -766,7 +766,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) 
//////////////////////////////////////////////////////////////////////////////// @@ -796,7 +796,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_nt, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -826,7 +826,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -856,6 +856,6 @@ CUTLASS_TEST_L1(SM50_device_qgemm_nt, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) diff --git a/test/unit/gemm/device/simt_qgemm_tn_sm50.cu b/test/unit/gemm/device/simt_qgemm_tn_sm50.cu index 191b057a..30159697 100644 --- a/test/unit/gemm/device/simt_qgemm_tn_sm50.cu +++ b/test/unit/gemm/device/simt_qgemm_tn_sm50.cu @@ -76,7 +76,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -106,7 +106,7 @@ CUTLASS_TEST_L0(SM50_device_qgemm_tn, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -136,7 +136,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -166,7 +166,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -196,7 +196,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -226,7 +226,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -256,7 +256,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 32x32x8_32x16x1_4x4_8x4_1x2, { 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -286,7 +286,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -316,7 +316,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -346,7 +346,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -376,7 +376,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -406,7 +406,7 @@ CUTLASS_TEST_L0(SM50_device_qgemm_tn, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -436,7 +436,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -466,7 +466,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -496,7 +496,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -526,7 +526,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -556,7 +556,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 
32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -586,7 +586,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -616,7 +616,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -646,7 +646,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -676,7 +676,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -706,7 +706,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -736,7 +736,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -766,7 +766,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -796,7 +796,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tn, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -826,7 +826,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tn, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -856,6 +856,6 @@ 
CUTLASS_TEST_L1(SM50_device_qgemm_tn, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) diff --git a/test/unit/gemm/device/simt_qgemm_tt_sm50.cu b/test/unit/gemm/device/simt_qgemm_tt_sm50.cu index 1b531f98..02e23702 100644 --- a/test/unit/gemm/device/simt_qgemm_tt_sm50.cu +++ b/test/unit/gemm/device/simt_qgemm_tt_sm50.cu @@ -76,7 +76,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 8x32x8_8x32x1_2x4_4x8_1x1, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -106,7 +106,7 @@ CUTLASS_TEST_L0(SM50_device_qgemm_tt, 16x32x8_16x32x1_4x4_4x8_1x1, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -136,7 +136,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 8x32x8_8x16x1_2x2_4x8_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -166,7 +166,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 8x64x8_8x32x1_2x4_4x8_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -196,7 +196,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 16x32x8_16x16x1_4x2_4x8_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -226,7 +226,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 16x64x8_16x32x1_4x4_4x8_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -256,7 +256,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 32x32x8_32x16x1_4x4_8x4_1x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -286,7 +286,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_16x32x1_4x4_4x8_2x1, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -316,7 +316,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 16x32x8_8x16x1_2x2_4x8_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + 
EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -346,7 +346,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 16x64x8_8x32x1_2x4_4x8_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -376,7 +376,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_16x16x1_4x2_4x8_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -406,7 +406,7 @@ CUTLASS_TEST_L0(SM50_device_qgemm_tt, 32x64x8_16x32x1_4x4_4x8_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -436,7 +436,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 64x32x8_32x16x1_4x4_8x4_2x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -466,7 +466,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 16x64x16_8x16x1_2x2_4x8_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -496,7 +496,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_16x8x1_2x2_8x4_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -526,7 +526,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 32x64x8_16x16x1_4x2_4x8_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -556,7 +556,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 32x128x8_16x32x1_4x4_4x8_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -586,7 +586,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 64x64x8_32x16x1_4x4_8x4_2x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -616,7 +616,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x32x8_8x16x1_2x2_4x8_4x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -646,7 +646,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x32x8_16x16x1_4x2_4x8_4x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -676,7 +676,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x64x8_16x32x1_4x4_4x8_4x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -706,7 +706,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 128x32x8_32x16x1_4x4_8x4_4x2, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -736,7 +736,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 32x64x16_8x16x1_2x2_4x8_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -766,7 +766,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x32x16_16x8x1_2x2_8x4_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -796,7 +796,7 @@ CUTLASS_TEST_L2(SM50_device_qgemm_tt, 64x64x8_16x16x1_4x2_4x8_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -826,7 +826,7 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 64x128x8_16x32x1_4x4_4x8_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) //////////////////////////////////////////////////////////////////////////////// @@ -856,6 +856,6 @@ CUTLASS_TEST_L1(SM50_device_qgemm_tt, 128x64x8_32x16x1_4x4_8x4_4x4, { cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 2 // Stages >; - EXPECT_TRUE(test::gemm::device::TestAllGemm()); + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } ) diff --git a/test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu new file mode 100644 index 00000000..544c1520 --- /dev/null +++ b/test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu @@ -0,0 +1,133 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide SYMM interface + + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" +#include "cutlass/gemm/device/symm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/symm_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_symm_universal.h" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Symm_cf64n_cf64n_ls_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) { + + using ElementOutput = cutlass::complex; + using ElementAccumulator = cutlass::complex; + + using Symm = cutlass::gemm::device::Symm< + cutlass::complex, + cutlass::layout::ColumnMajor, + cutlass::SideMode::kLeft, + cutlass::FillMode::kLower, + cutlass::complex, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + 1, + 1, + false, + cutlass::arch::OpMultiplyAddGaussianComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Symm_cf64n_cf64n_rs_u_tensor_op_f64, 64x64x16_32x32x16) { + + using ElementOutput = cutlass::complex; + using ElementAccumulator 
= cutlass::complex; + + using Symm = cutlass::gemm::device::Symm< + cutlass::complex, + cutlass::layout::ColumnMajor, + cutlass::SideMode::kRight, + cutlass::FillMode::kUpper, + cutlass::complex, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + 1, + 1, + false, + cutlass::arch::OpMultiplyAddComplex + >; + + EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu new file mode 100644 index 00000000..5dca4868 --- /dev/null +++ b/test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu @@ -0,0 +1,135 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide SYMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" +#include "cutlass/gemm/device/symm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/symm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_symm_universal.h" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Symm_f64n_f64n_rs_l_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Symm = cutlass::gemm::device::Symm< + ElementA, + LayoutA, + cutlass::SideMode::kRight, + cutlass::FillMode::kLower, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Symm_f64t_f64t_ls_l_tensor_op_f64, 128x128x16_32x64x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = double; + + using Symm = cutlass::gemm::device::Symm< + ElementA, + LayoutA, + cutlass::SideMode::kLeft, + cutlass::FillMode::kLower, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSymmUniversal()); + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu new file mode 100644 index 00000000..7c886030 --- /dev/null +++ b/test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu @@ -0,0 +1,150 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide SYRK interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" +#include "cutlass/gemm/device/rank_2k.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/rank_2k.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_rank2k_universal.h" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Syr2k_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2K = cutlass::gemm::device::Rank2K< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // kStages + 1, // AlignmentA + 1, // AlignmentB + false, // SplitKSerial + cutlass::arch::OpMultiplyAddComplex, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::BlasMode::kSymmetric + >; + + EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Syr2k_cf64n_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = 
cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2K = cutlass::gemm::device::Rank2K< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // kStages + 1, // AlignmentA + 1, // AlignmentB + false, // SplitKSerial + cutlass::arch::OpMultiplyAddComplex, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + cutlass::BlasMode::kSymmetric + >; + + EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu new file mode 100644 index 00000000..f72ee797 --- /dev/null +++ b/test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu @@ -0,0 +1,134 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide SYRK interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" +#include "cutlass/gemm/device/rank_2k.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/rank_2k.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_rank2k_universal.h" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Syr2k_f64n_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2K = cutlass::gemm::device::Rank2K< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Syr2k_f64t_f64n_l_tensor_op_f64, 128x128x16_32x64x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2K = cutlass::gemm::device::Rank2K< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllRank2KUniversal()); + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu new file mode 100644 index 00000000..3f393d64 --- /dev/null +++ b/test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide SYRK interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" +#include "cutlass/gemm/device/rank_k.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/rank_k_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_rank_k_universal.h" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Syrk_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using RankK = cutlass::gemm::device::RankK< + ElementA, + LayoutA, + ElementC, + LayoutC, + cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // kStages + 1, // AlignmentA + false, // SplitKSerial + cutlass::arch::OpMultiplyAddComplex, + cutlass::ComplexTransform::kNone, + cutlass::BlasMode::kSymmetric + >; + + EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Syrk_cf64n_cf64t_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::RowMajor; + using 
ElementAccumulator = cutlass::complex; + + using RankK = cutlass::gemm::device::RankK< + ElementA, + LayoutA, + ElementC, + LayoutC, + cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // kStages + 1, // AlignmentA + false, // SplitKSerial + cutlass::arch::OpMultiplyAddGaussianComplex, + cutlass::ComplexTransform::kNone, + cutlass::BlasMode::kSymmetric + >; + + EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu new file mode 100644 index 00000000..d8d8687a --- /dev/null +++ b/test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu @@ -0,0 +1,126 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide SYRK interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" +#include "cutlass/gemm/device/rank_k.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/rank_k_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_rank_k_universal.h" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Syrk_f64n_f64t_l_tensor_op_f64, 128x64x16_64x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = double; + + using RankK = cutlass::gemm::device::RankK< + ElementA, + LayoutA, + ElementC, + LayoutC, + cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<128, 64, 16>, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Syrk_f64t_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using RankK = cutlass::gemm::device::RankK< + ElementA, + LayoutA, + ElementC, + LayoutC, + cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllRankKUniversal()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/testbed.h b/test/unit/gemm/device/testbed.h index 9b33e0a2..9d2a729d 100644 --- a/test/unit/gemm/device/testbed.h +++ b/test/unit/gemm/device/testbed.h @@ -50,9 +50,11 @@ #include "cutlass/util/reference/host/gemm.h" #include "testbed_utils.h" +#include "testbed_universal.h" #include "cutlass/layout/matrix.h" #include "cutlass/matrix_coord.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" namespace test { namespace gemm { @@ -309,7 +311,7 @@ struct Testbed { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } @@ -319,10 +321,19 @@ struct Testbed { /// Executes one test bool run( - cutlass::gemm::GemmCoord problem_size, + cutlass::gemm::GemmCoord problem_size, int split_k_slices = 1, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = 
ElementCompute(0)) { + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { +/* + std::cout << "\n-----------------------\n"; + std::cout << "problem size: " << problem_size << "\n"; + std::cout << "split_k_slices: " << split_k_slices << "\n"; + std::cout << "alpha: " << alpha << "\n"; + std::cout << "beta: " << beta << "\n"; + std::cout << "-----------------------\n\n"; +*/ // Waive test if insufficient CUDA device if (!sufficient()) { @@ -387,7 +398,7 @@ struct Testbed { ///////////////////////////////////////////////////////////////////////////////////////////////// template -bool TestAllGemm( +bool TestAllGemmBasic( const typename Gemm::LayoutA::Stride& stride_factor_A = typename Gemm::LayoutA::Stride(), const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(), const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) { @@ -477,6 +488,52 @@ bool TestAllGemm( return passed; } +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemm( + const typename Gemm::LayoutA::Stride& stride_factor_A, + const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(), + const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) +{ + // Test basic GEMM with non-default stride factors + return TestAllGemmBasic(stride_factor_A, stride_factor_B, stride_factor_C); +} + +template +bool TestAllGemm() +{ +#ifdef NDEBUG + // Non-debug builds also test basic GEMM with default stride factors + if (!TestAllGemmBasic()) { + return false; + } +#endif // NDEBUG + + // Test universal GEMM +#if 0 + // Define the universal kernel + using UniversalKernel = cutlass::gemm::kernel::GemmUniversal< + typename Gemm::GemmKernel::Mma, // Mma + typename Gemm::GemmKernel::Epilogue, // Epilogue + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<> // ThreadblockSwizzle + >; +#else + // Define the streamk universal kernel + using UniversalKernel = cutlass::gemm::kernel::GemmUniversalStreamk< + typename Gemm::GemmKernel::Mma, // Mma + typename Gemm::GemmKernel::Epilogue, // Epilogue + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK // ThreadblockSwizzle + >; +#endif + + // Define the universal adaptor + using UniversalGemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Test universal GEMM + return TestAllGemmUniversal(); +} + ///////////////////////////////////////////////////////////////////////////////////////////////// template bool TestGemmPerf(int iterations = 1) { diff --git a/test/unit/gemm/device/testbed_complex.h b/test/unit/gemm/device/testbed_complex.h index e6893026..aa3eef15 100644 --- a/test/unit/gemm/device/testbed_complex.h +++ b/test/unit/gemm/device/testbed_complex.h @@ -128,7 +128,7 @@ struct TestbedComplex : public Testbed { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } diff --git a/test/unit/gemm/device/testbed_gemm_with_broadcast.h b/test/unit/gemm/device/testbed_gemm_with_broadcast.h index 242736ad..58528b0e 100644 --- a/test/unit/gemm/device/testbed_gemm_with_broadcast.h +++ b/test/unit/gemm/device/testbed_gemm_with_broadcast.h @@ -388,7 +388,7 @@ struct TestbedGemmWithBroadcast { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if 
(properties.sharedMemPerBlockOptin < smem_size) { return false; } diff --git a/test/unit/gemm/device/testbed_gemm_with_reduction.h b/test/unit/gemm/device/testbed_gemm_with_reduction.h index e51eed20..366a8209 100644 --- a/test/unit/gemm/device/testbed_gemm_with_reduction.h +++ b/test/unit/gemm/device/testbed_gemm_with_reduction.h @@ -375,7 +375,7 @@ struct TestbedGemmWithReduction { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } diff --git a/test/unit/gemm/device/testbed_grouped_scheduler.h b/test/unit/gemm/device/testbed_grouped_scheduler.h index 0bd409e0..5bb3e64b 100644 --- a/test/unit/gemm/device/testbed_grouped_scheduler.h +++ b/test/unit/gemm/device/testbed_grouped_scheduler.h @@ -312,7 +312,8 @@ template struct TestbedGroupedGemmScheduler { - using BaselinePV = BaselineProblemVisitor, + using PSHelper = cutlass::gemm::kernel::detail::GemmGroupedProblemSizeHelper; + using BaselinePV = BaselineProblemVisitor; diff --git a/test/unit/gemm/device/testbed_interleaved.h b/test/unit/gemm/device/testbed_interleaved.h index ce4e2c9f..77de5b16 100644 --- a/test/unit/gemm/device/testbed_interleaved.h +++ b/test/unit/gemm/device/testbed_interleaved.h @@ -130,7 +130,7 @@ struct InterleavedTestbed { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } diff --git a/test/unit/gemm/device/testbed_planar_complex.h b/test/unit/gemm/device/testbed_planar_complex.h index 52aa6982..e041232b 100644 --- a/test/unit/gemm/device/testbed_planar_complex.h +++ b/test/unit/gemm/device/testbed_planar_complex.h @@ -140,7 +140,7 @@ public: throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } diff --git a/test/unit/gemm/device/testbed_rank2k_universal.h b/test/unit/gemm/device/testbed_rank2k_universal.h index 4442d99f..13fbe257 100644 --- a/test/unit/gemm/device/testbed_rank2k_universal.h +++ b/test/unit/gemm/device/testbed_rank2k_universal.h @@ -298,7 +298,7 @@ struct TestbedRank2KUniversal { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } diff --git a/test/unit/gemm/device/testbed_rank_k_universal.h b/test/unit/gemm/device/testbed_rank_k_universal.h index d4a946c5..644703b8 100644 --- a/test/unit/gemm/device/testbed_rank_k_universal.h +++ b/test/unit/gemm/device/testbed_rank_k_universal.h @@ -286,7 +286,7 @@ struct TestbedRank2KUniversal { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } diff --git a/test/unit/gemm/device/testbed_sparse.h b/test/unit/gemm/device/testbed_sparse.h index 1509d9dd..885f0d1d 100644 --- a/test/unit/gemm/device/testbed_sparse.h +++ b/test/unit/gemm/device/testbed_sparse.h @@ -323,7 +323,7 @@ struct SparseTestbed { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } diff --git a/test/unit/gemm/device/testbed_splitk.h 
b/test/unit/gemm/device/testbed_splitk.h index fcc136c1..da14e2b2 100644 --- a/test/unit/gemm/device/testbed_splitk.h +++ b/test/unit/gemm/device/testbed_splitk.h @@ -88,7 +88,7 @@ struct TestbedSplitK : public Testbed { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } diff --git a/test/unit/gemm/device/testbed_symm_universal.h b/test/unit/gemm/device/testbed_symm_universal.h index 14218d64..d91ca29c 100644 --- a/test/unit/gemm/device/testbed_symm_universal.h +++ b/test/unit/gemm/device/testbed_symm_universal.h @@ -324,7 +324,7 @@ struct TestbedSymmUniversal { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } diff --git a/test/unit/gemm/device/testbed_trmm_universal.h b/test/unit/gemm/device/testbed_trmm_universal.h index 13c3c44f..e540e6ba 100644 --- a/test/unit/gemm/device/testbed_trmm_universal.h +++ b/test/unit/gemm/device/testbed_trmm_universal.h @@ -364,7 +364,7 @@ struct TestbedTrmmUniversal { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } diff --git a/test/unit/gemm/device/testbed_universal.h b/test/unit/gemm/device/testbed_universal.h index ca85e062..8de39e01 100644 --- a/test/unit/gemm/device/testbed_universal.h +++ b/test/unit/gemm/device/testbed_universal.h @@ -58,7 +58,7 @@ namespace device { ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct TestbedUniversal { using ElementAccumulator = typename Gemm::ElementAccumulator; @@ -158,9 +158,10 @@ struct TestbedUniversal { // It is possible to randomly initialize to all zeros, so override this with non-zeros // in the upper left corner of each operand. - tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); - tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); - tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + cutlass::Coord<2> origin(0); + tensor_A.host_view().at(origin) = typename Gemm::ElementA(1); + tensor_B.host_view().at(origin) = typename Gemm::ElementB(1); + tensor_C.host_view().at(origin) = typename Gemm::ElementC(1); cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); @@ -253,6 +254,17 @@ struct TestbedUniversal { ElementAccumulator(0) ); + if (Relu) { + for (int i = 0; i < problem_size.m(); ++i) { + for (int j = 0; j < problem_size.n(); ++j) { + reference_D.at(cutlass::MatrixCoord(i, j)) = + ((ElementCompute)reference_D.at(cutlass::MatrixCoord(i, j)) < (ElementCompute)0) + ? 
(typename Gemm::ElementC)0 + : reference_D.at(cutlass::MatrixCoord(i, j)); + } + } + } + return compare_reference(problem_size, alpha, beta); } @@ -278,7 +290,7 @@ struct TestbedUniversal { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } @@ -288,10 +300,20 @@ struct TestbedUniversal { /// Executes one test bool run( cutlass::gemm::GemmUniversalMode mode, - cutlass::gemm::GemmCoord problem_size, + cutlass::gemm::GemmCoord problem_size, int batch_count = 1, - ElementCompute alpha = ElementCompute(1), - ElementCompute beta = ElementCompute(0)) { + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { +/* + std::cout << "\n-----------------------\n"; + std::cout << "mode: " << (int) mode << "\n"; + std::cout << "problem size: " << problem_size << "\n"; + std::cout << "batch_count: " << batch_count << "\n"; + std::cout << "alpha: " << alpha << "\n"; + std::cout << "beta: " << beta << "\n"; + std::cout << "-----------------------\n\n"; +*/ // Waive test if insufficient CUDA device if (!sufficient()) { @@ -359,7 +381,7 @@ struct TestbedUniversal { }; ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template bool TestGemmUniversal( cutlass::gemm::GemmCoord const & problem_size, cutlass::gemm::GemmUniversalMode mode, @@ -369,7 +391,7 @@ bool TestGemmUniversal( bool passed = true; - TestbedUniversal testbed; + TestbedUniversal testbed; using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; @@ -384,7 +406,7 @@ bool TestGemmUniversal( return passed; } -template +template bool TestAllGemmUniversal() { bool passed = true; @@ -412,9 +434,9 @@ bool TestAllGemmUniversal() { cutlass::platform::is_same::value && (cutlass::platform::is_same::value || cutlass::platform::is_same::value) ? 4 : kAlignment; - - - + + + cutlass::gemm::GemmUniversalMode modes[] = { cutlass::gemm::GemmUniversalMode::kGemm, }; @@ -428,8 +450,8 @@ bool TestAllGemmUniversal() { }; int problem_size_k[] = { - kAlignmentK, - Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK, + kAlignmentK, + Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK, Gemm::ThreadblockShape::kK * Gemm::kStages * 3 - kAlignmentK }; @@ -468,7 +490,7 @@ bool TestAllGemmUniversal() { cutlass::gemm::GemmCoord problem_size(m, n, k); - TestbedUniversal testbed; + TestbedUniversal testbed; passed = testbed.run( mode, diff --git a/test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu new file mode 100644 index 00000000..5aab4899 --- /dev/null +++ b/test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu @@ -0,0 +1,137 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide TRMM interface + + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" +#include "cutlass/gemm/device/trmm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/trmm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_trmm_universal.h" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Trmm_cf64n_cf64n_cf64t_ls_u_nu_tensor_op_f64_gaussian, 32x32x16_16x16x16) { + + using ElementOutput = cutlass::complex; + using ElementAccumulator = cutlass::complex; + + using Trmm = cutlass::gemm::device::Trmm< + cutlass::complex, + cutlass::layout::ColumnMajor, + cutlass::SideMode::kLeft, + cutlass::FillMode::kUpper, + cutlass::DiagType::kNonUnit, + cutlass::complex, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + 1, + 1, + false, + cutlass::arch::OpMultiplyAddGaussianComplex, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Trmm_cf64h_cf64n_cf64t_ls_u_nu_tensor_op_f64, 64x64x16_32x32x16) { + + using ElementOutput = cutlass::complex; + using ElementAccumulator = cutlass::complex; + + using Trmm = cutlass::gemm::device::Trmm< + cutlass::complex, + cutlass::layout::ColumnMajor, + cutlass::SideMode::kLeft, + cutlass::FillMode::kUpper, + cutlass::DiagType::kNonUnit, + cutlass::complex, + cutlass::layout::ColumnMajor, + ElementOutput, 
+ cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + 1, + 1, + false, + cutlass::arch::OpMultiplyAddComplex, + cutlass::ComplexTransform::kConjugate + >; + + EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu new file mode 100644 index 00000000..85fbc9b2 --- /dev/null +++ b/test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide TRMM interface + + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" +#include "cutlass/gemm/device/trmm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/trmm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_trmm_universal.h" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Trmm_f64n_f64n_f64t_rs_l_nu_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementOutput = double; + using ElementAccumulator = double; + + using Trmm = cutlass::gemm::device::Trmm< + double, + cutlass::layout::ColumnMajor, + cutlass::SideMode::kRight, + cutlass::FillMode::kLower, + cutlass::DiagType::kNonUnit, + double, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Trmm_f64t_f64t_f64n_rs_l_nu_tensor_op_f64, 64x64x16_32x32x16) { + + using ElementOutput = double; + using ElementAccumulator = double; + + using Trmm = cutlass::gemm::device::Trmm< + double, + cutlass::layout::RowMajor, + cutlass::SideMode::kRight, + cutlass::FillMode::kLower, + cutlass::DiagType::kNonUnit, + double, + cutlass::layout::RowMajor, + ElementOutput, + cutlass::layout::ColumnMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm90, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 4>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 1, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h b/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h index 7ae54704..0a89ced8 100644 --- a/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h +++ b/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h @@ -257,7 +257,7 @@ struct SparseTestbed { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - if (properties.sharedMemPerMultiprocessor < smem_size) { + if (properties.sharedMemPerBlockOptin < smem_size) { return false; } diff --git a/test/unit/gemm/warp/CMakeLists.txt b/test/unit/gemm/warp/CMakeLists.txt index acef876d..efdc3711 100644 --- a/test/unit/gemm/warp/CMakeLists.txt +++ b/test/unit/gemm/warp/CMakeLists.txt @@ -37,6 +37,8 @@ cutlass_test_unit_add_executable( gemm_complex_sm80.cu gemm_sparse_sm80.cu gemm_gaussian_complex_sm80.cu + gemm_sm90.cu + 
gemm_complex_sm90.cu wmma_sm70.cu wmma_sm72.cu wmma_sm75.cu diff --git a/test/unit/gemm/warp/gemm_complex_sm80.cu b/test/unit/gemm/warp/gemm_complex_sm80.cu index d0312345..98d9f382 100644 --- a/test/unit/gemm/warp/gemm_complex_sm80.cu +++ b/test/unit/gemm/warp/gemm_complex_sm80.cu @@ -56,7 +56,7 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// // complex * complex => complex // Input data type: complex -// Math instruction: MMA.884.F64.F64 +// Math instruction: mma.sync.aligned.m8n8k4.f64.f64.f64.f64 // Output data type: complex /////////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM80_warp_gemm_complex_tensor_op_f64, 8x8x4_8x8x4_nt) { @@ -293,7 +293,7 @@ TEST(SM80_warp_gemm_complex_tensor_op_f64, 16x16x4_8x8x4_tn) { /////////////////////////////////////////////////////////////////////////////////////////////////// // complex * complex => complex // Input data type: complex -// Math instruction: MMA.1688.F32.TF32 +// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 // Output data type: complex // Shared memory layout: Congrous //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -495,7 +495,7 @@ TEST(SM80_warp_gemm_complex_tensor_op_f32, 32x32x8_16x8x8_ct) { /////////////////////////////////////////////////////////////////////////////////////////////////// // complex * complex => complex // Input data type: complex -// Math instruction: MMA.1688.F32.TF32 +// Math instruction: mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 // Output data type: complex // Shared memory layout: Crosswise //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -526,7 +526,7 @@ TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x16x8_16x8x8_tn) { .run(); } -// TEST FAILS crosswise complex TN MMA.1688.F32.TF32 test fails for k = 2*8 = 16 +// TEST FAILS crosswise complex TN mma.sync.aligned.m16n8k8.f32.tf32.tf32.f32 test fails for k = 2*8 = 16 TEST(SM80_warp_gemm_complex_tensor_op_f32, 16x16x16_16x8x8_tn) { using Shape = cutlass::gemm::GemmShape<16, 16, 16>; diff --git a/test/unit/gemm/warp/gemm_complex_sm90.cu b/test/unit/gemm/warp/gemm_complex_sm90.cu new file mode 100644 index 00000000..0baf7a45 --- /dev/null +++ b/test/unit/gemm/warp/gemm_complex_sm90.cu @@ -0,0 +1,334 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + + \brief Unit tests for thread-level GEMM with Hopper FP64 +*/ + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" + +#include "cutlass/gemm/warp/default_mma_complex_tensor_op.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x8x4_16x8x4_nt) { + + using Shape = cutlass::gemm::GemmShape<16, 8, 4>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x16x4_16x8x4_nt) { + + using Shape = cutlass::gemm::GemmShape<16, 16, 4>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x32x4_16x8x4_nt) { + + using Shape = cutlass::gemm::GemmShape<16, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + 
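// Reference sketch: the SM90 warp-level complex tensor-op tests in this file all follow
// the same recipe, so one instance is written out below with every template argument
// spelled in full. It assumes the problem shape passed as the second template argument
// of test::gemm::warp::TestbedComplex equals the warp-level Shape, following the
// convention of the existing SM80 warp-level complex tests; treat it as an illustrative
// sketch rather than an additional test case.

TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x32x4_16x8x4_nt_reference_sketch) {

  // Warp-level tile and the SM90 double-precision tensor-op instruction shape.
  using Shape = cutlass::gemm::GemmShape<32, 32, 4>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>;

  // Double-precision complex operands and accumulators.
  using Element = cutlass::complex<double>;
  using ElementC = cutlass::complex<double>;

  // Congruous shared-memory layouts for the 'nt' case (column-major A, row-major B).
  using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b;
  using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b;

  // DefaultMmaComplexTensorOp composes the warp-level complex MMA from the shapes,
  // element types, and operand layouts above.
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp<
    Shape,
    InstructionShape,
    Element,
    LayoutA,
    Element,
    LayoutB,
    ElementC,
    cutlass::layout::RowMajor
  >::Type;

  // Run the reference-checked warp-level testbed on a 32x32x4 problem.
  test::gemm::warp::TestbedComplex<MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 4> >().run();
}
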
+TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x16x4_16x8x4_nt) { + + using Shape = cutlass::gemm::GemmShape<32, 16, 4>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x32x4_16x8x4_nt) { + + using Shape = cutlass::gemm::GemmShape<32, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x32x4_16x8x4_nh) { + + using Shape = cutlass::gemm::GemmShape<32, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kConjugate + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x32x4_16x8x4_ct) { + + using Shape = cutlass::gemm::GemmShape<32, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous128b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous128b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kConjugate, + cutlass::ComplexTransform::kNone + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x8x4_16x8x4_tn) { + + using Shape = cutlass::gemm::GemmShape<16, 8, 4>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + 
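// Related sketch: the SparseTestbed hunk earlier in this patch replaces the check on
// sharedMemPerMultiprocessor with sharedMemPerBlockOptin, i.e. the amount of dynamic
// shared memory a single thread block may opt in to, which is what actually bounds a
// kernel launch. The snippet below is a minimal, self-contained illustration of
// querying that limit and granting it to a kernel with the CUDA runtime;
// example_kernel, configure_dynamic_smem, and smem_size are placeholders, not names
// taken from CUTLASS.

#include <cuda_runtime.h>

__global__ void example_kernel() {}  // placeholder kernel

bool configure_dynamic_smem(int device, int smem_size) {
  cudaDeviceProp properties;
  if (cudaGetDeviceProperties(&properties, device) != cudaSuccess) {
    return false;
  }

  // sharedMemPerBlockOptin is the per-block opt-in limit; sharedMemPerMultiprocessor
  // is the per-SM total and can overstate what a single block is allowed to use.
  if (properties.sharedMemPerBlockOptin < static_cast<size_t>(smem_size)) {
    return false;
  }

  // Requests above the default 48 KB of dynamic shared memory must be opted into
  // explicitly before the kernel is launched.
  return cudaFuncSetAttribute(example_kernel,
                              cudaFuncAttributeMaxDynamicSharedMemorySize,
                              smem_size) == cudaSuccess;
}
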
+TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x16x4_16x8x4_tn) { + + using Shape = cutlass::gemm::GemmShape<16, 16, 4>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TestbedComplex >().run(); +} + +TEST(SM90_warp_gemm_complex_tensor_op_f64, 32x32x16_16x8x4_tn) { + + using Shape = cutlass::gemm::GemmShape<32, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TestbedComplex().run(); +} + +TEST(SM90_warp_gemm_complex_tensor_op_f64, 64x64x4_16x8x4_tn) { + + using Shape = cutlass::gemm::GemmShape<64, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise128x4; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TestbedComplex().run(); +} + +#endif // if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/warp/gemm_sm90.cu b/test/unit/gemm/warp/gemm_sm90.cu new file mode 100644 index 00000000..77e3dbcd --- /dev/null +++ b/test/unit/gemm/warp/gemm_sm90.cu @@ -0,0 +1,206 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + + \brief Unit tests for thread-level GEMM with Hopper FP64 +*/ + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" + +#include "cutlass/gemm/warp/default_mma_tensor_op.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +TEST(SM90_warp_gemm_tensor_op_congruous_f64, 16x16x4_16x16x4_16x8x4) { + using Shape = cutlass::gemm::GemmShape<16, 16, 4>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_warp_gemm_tensor_op_congruous_f64, 32x16x4_32x16x4_16x8x4) { + using Shape = cutlass::gemm::GemmShape<32, 16, 4>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_warp_gemm_tensor_op_congruous_f64, 32x32x4_32x32x4_16x8x4) { + using Shape = cutlass::gemm::GemmShape<32, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_warp_gemm_tensor_op_congruous_f64, 32x64x4_32x64x4_16x8x4) { + using Shape 
= cutlass::gemm::GemmShape<32, 64, 4>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous64b; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous64b; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 16x16x16_16x16x16_16x8x4) { + using Shape = cutlass::gemm::GemmShape<16, 16, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 32x32x16_32x32x16_16x8x4) { + using Shape = cutlass::gemm::GemmShape<32, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 64x32x16_64x32x16_16x8x4) { + using Shape = cutlass::gemm::GemmShape<64, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 32x64x16_32x64x16_16x8x4) { + using Shape = cutlass::gemm::GemmShape<32, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicand64bCrosswise; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd>::Type; + + test::gemm::warp::Testbed >() + 
.run(); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/tools/library/include/cutlass/library/arch_mappings.h b/tools/library/include/cutlass/library/arch_mappings.h index 3638ca10..a1ea8f59 100644 --- a/tools/library/include/cutlass/library/arch_mappings.h +++ b/tools/library/include/cutlass/library/arch_mappings.h @@ -97,6 +97,11 @@ template struct ArchMap { static int const kMax = 1024; }; +template struct ArchMap { + static int const kMin = 90; + static int const kMax = 1024; +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace library diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index b2e68097..e557084d 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -1445,6 +1445,9 @@ struct ConvArguments { /// pointer to implicit gemm matrix B void const *B; + /// pointer to reordered matrix B + void const *reordered_B; + /// pointer to implicit gemm matrix C void const *C; diff --git a/tools/library/scripts/conv2d_operation.py b/tools/library/scripts/conv2d_operation.py index 0ba4307f..23f8e6e6 100644 --- a/tools/library/scripts/conv2d_operation.py +++ b/tools/library/scripts/conv2d_operation.py @@ -17,7 +17,8 @@ from library import * class Conv2dOperation: # def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \ - stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity1): + stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity1, \ + group_mode = GroupMode.NoneGroup): self.operation_kind = OperationKind.Conv2d self.arch = arch @@ -31,6 +32,7 @@ class Conv2dOperation: self.iterator_algorithm = iterator_algorithm self.stride_support = stride_support self.swizzling_functor = swizzling_functor + self.group_mode = group_mode # def is_complex(self): complex_operators = [ @@ -95,17 +97,18 @@ class Conv2dOperation: opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] - threadblock = "%dx%d_%dx%d" % ( - self.tile_description.threadblock_shape[0], - self.tile_description.threadblock_shape[1], - self.tile_description.threadblock_shape[2], - self.tile_description.stages - ) + threadblock = self.tile_description.procedural_name() + + # grouped conv + if self.group_mode != GroupMode.NoneGroup: + group_conv_name = f"{GroupModeNames[self.group_mode]}_" + else: + group_conv_name = "" if self.stride_support == StrideSupport.Unity: - configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_align${alignment}" + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_${group_conv_name}align${alignment}" else: - configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}" + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${group_conv_name}align${alignment}" return SubstituteTemplate( configuration_name, @@ -115,6 +118,7 @@ class Conv2dOperation: 'threadblock': threadblock, 'layout': self.layout_name(), 'alignment': "%d" % self.A.alignment, + 'group_conv_name': group_conv_name } ) @@ -162,7 +166,77 @@ class EmitConv2dInstance: ${align_b} 
>::Kernel; """ + self.template_group_conv = """ + // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" + using ${operation_name}_base = + typename cutlass::conv::kernel::DefaultConv2dGroup${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator}, + ${group_mode}, + ${iterator_algorithm}, + ${stride_support}, + ${align_a}, + ${align_b} + >::Kernel; +""" + self.template_depthwise_direct_conv = """ + // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" + using ${operation_name}_base = + typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConv${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::conv::TensorNHWCShape<${threadblock_output_shape_n}, ${threadblock_output_shape_p}, ${threadblock_output_shape_q}, ${groups_per_cta}>, + cutlass::MatrixShape<${filter_shape_r}, ${filter_shape_s}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue}, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + >, + cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< + 1, + ${threadblock_output_shape_n}, + ${threadblock_output_shape_p}, + ${threadblock_output_shape_q}>, + ${stages}, + ${math_operator}, + ${iterator_algorithm}, + ${stride_support}, + cutlass::MatrixShape<${stride_r}, ${stride_s}>, + cutlass::MatrixShape<${dilation_r}, ${dilation_s}> + >::Kernel; +""" def emit(self, operation): @@ -206,7 +280,32 @@ class EmitConv2dInstance: 'align_b': str(operation.B.alignment), } - return SubstituteTemplate(self.template, values) + if operation.group_mode == GroupMode.NoneGroup: + return SubstituteTemplate(self.template, values) + + elif operation.group_mode == GroupMode.Depthwise: + values['group_mode'] = GroupModeTag[operation.group_mode] + # Setup other template params + values['threadblock_output_shape_n'] = str(operation.tile_description.threadblock_output_shape[0]) + values['threadblock_output_shape_p'] = str(operation.tile_description.threadblock_output_shape[1]) + values['threadblock_output_shape_q'] = str(operation.tile_description.threadblock_output_shape[2]) + + values['groups_per_cta'] = str(operation.tile_description.threadblock_output_shape[3]) + + values['filter_shape_r'] = str(operation.tile_description.filter_shape[0]) + values['filter_shape_s'] = str(operation.tile_description.filter_shape[1]) + + values['stride_r'] = str(operation.tile_description.stride[0]) + values['stride_s'] = 
str(operation.tile_description.stride[1]) + + values['dilation_r'] = str(operation.tile_description.dilation[0]) + values['dilation_s'] = str(operation.tile_description.dilation[1]) + + return SubstituteTemplate(self.template_depthwise_direct_conv, values) + + else: + values['group_mode'] = GroupModeTag[operation.group_mode] + return SubstituteTemplate(self.template_group_conv, values) ################################################################################################### # @@ -292,6 +391,16 @@ void initialize_${configuration_name}(Manifest &manifest) { Operation_${operation_name}>( "${operation_name}")); +""" + + self.configuration_direct_conv_instance = """ + using Operation_${operation_name} = cutlass::conv::device::DirectConvolution< + ${operation_name}>; + + manifest.append(new cutlass::library::DirectConv2dOperation< + Operation_${operation_name}>( + "${operation_name}")); + """ self.configuration_epilogue = """ @@ -334,10 +443,16 @@ void initialize_${configuration_name}(Manifest &manifest) { })) for operation in self.operations: - self.configuration_file.write(SubstituteTemplate(self.configuration_instance, { - 'configuration_name': self.configuration_name, - 'operation_name': operation.procedural_name() - })) + if operation.group_mode == GroupMode.Depthwise: + self.configuration_file.write(SubstituteTemplate(self.configuration_direct_conv_instance, { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name() + })) + else: + self.configuration_file.write(SubstituteTemplate(self.configuration_instance, { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name() + })) self.configuration_file.write(self.configuration_epilogue) self.configuration_file.write(self.epilogue_template) diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index f0d84699..87777f5b 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -11,6 +11,8 @@ import argparse from library import * from manifest import * +from itertools import product + ################################################################################################### # @@ -49,6 +51,8 @@ def EpilogueAlignment(max_alignment, tile, epilogue_steps = 8): def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ swizzling_functor = SwizzlingFunctor.Identity8): +# Use StreamK decomposition for basic GEMMs +# swizzling_functor = SwizzlingFunctor.StreamK): if complex_transforms is None: complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] @@ -373,11 +377,26 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme # Strided support for Analytic and Optimized Fprop for iterator_algorithm in iterator_algorithms: - new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ - A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) - - manifest.append(new_operation) - operations.append(new_operation) + new_operations = [ + # None grouped kernel + Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_), + ] + + # Instance group conv kernel + if tile.math_instruction.opcode_class == OpcodeClass.TensorOp and A.layout == 
LayoutType.TensorNHWC: + # SingleGroup kernel + new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_, group_mode=GroupMode.SingleGroup)) + + # Analytic iterator supports MultipleGroup mode + if iterator_algorithm == IteratorAlgorithm.Analytic: + new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_, group_mode=GroupMode.MultipleGroup)) + + for new_operation in new_operations: + manifest.append(new_operation) + operations.append(new_operation) # # Conv2d Dgrad @@ -593,6 +612,62 @@ def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignme return operations +# Convolution for Depthwise 2d conv +def CreateDepthwiseConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \ + conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): + + element_a, element_b, element_c, element_epilogue = data_type + + # iterator algorithm (FixedStrideDilation, Optimized) + iterator_algorithms = [IteratorAlgorithm.FixedStrideDilation, IteratorAlgorithm.Optimized] + + # by default, only generate the largest tile size, largest alignment, and optimized iterator + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + operations = [] + + for tile in tile_descriptions: + for alignment in alignment_constraints: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment) + B = TensorDescription(element_b, layout[1], alignment) + C = TensorDescription(element_c, layout[2], alignment_c) + + swizzling_functor_ = swizzling_functor + + if ConvKind.Fprop in conv_kinds: + + # Strided support for Optimized and FixedStridedDilation Depthwise Conv + for iterator_algorithm in iterator_algorithms: + stride_support = StrideSupport.Strided + if iterator_algorithm == IteratorAlgorithm.FixedStrideDilation: + if tile.stride == [-1, -1] or tile.dilation == [-1,-1]: + continue + stride_support = StrideSupport.Fixed + + if iterator_algorithm == IteratorAlgorithm.Optimized: + if tile.stride != [-1, -1] or tile.dilation != [-1,-1]: + continue + new_operation = Conv2dOperation(ConvKind.Fprop, + iterator_algorithm, + tile.minimum_compute_capability, + tile, + A, B, C, + element_epilogue, + stride_support, + epilogue_functor, + swizzling_functor_, + group_mode=GroupMode.Depthwise) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations ################################################################################################### ################################################################################################### @@ -748,10 +823,83 @@ def GenerateSM60_Simt(manifest, cuda_version): CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) # +def GenerateSM60_Simt_DepthwiseConv2d(manifest, cuda_version): + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 60 + max_cc = 1024 + + alignment_constraints = [8,] + + filter_3x3 = [3, 3] + filter_5x5 = [5, 5] + + # 
[stride_h, stride_w] + # [-1, -1] means all stride size. + strides = [[-1,-1], [1, 1], [2, 2]] + # [dilation_h, dilation_w] + # [-1, -1] means all dilation size. + dilations = [[-1,-1], [1, 1], [2, 2]] + + #groups per thread block + g16 = 16 + g32 = 32 + g64 = 64 + + #output shape per thread block + npq_1x4x4 = [1, 4, 4] + npq_1x8x8 = [1, 8, 8] + npq_1x10x10 = [1, 10, 10] + + tile_descriptions = [] + for math_inst in math_instructions: + for stride, dilation in product(strides, dilations): + tile_descriptions.extend([ + # filter3x3 ThreadBlock_output, filter, stage, warp + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g32], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g64], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g16], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + + Direct2dConvFixedStrideDilationTileDescription(npq_1x10x10+[g64], filter_3x3, 2, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g32], filter_3x3, 4, stride, dilation, [4, 1, 1], math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g64], filter_3x3, 4, stride, dilation,[4, 1, 1], math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g16], filter_3x3, 4, stride, dilation, [4, 1, 1], math_inst, min_cc, max_cc), + + # filter5x5 ThreadBlock_output, filter, stage, warp + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g32], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g64], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g16], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + + Direct2dConvFixedStrideDilationTileDescription(npq_1x10x10+[g64], filter_5x5, 2, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g32], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g64], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g16], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc) + ]) + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateDepthwiseConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# # def GenerateSM60(manifest, cuda_version): GenerateSM60_Simt(manifest, cuda_version) + GenerateSM60_Simt_DepthwiseConv2d(manifest, cuda_version) ################################################################################################### ################################################################################################### @@ -3813,6 +3961,627 @@ def GenerateSM80(manifest, cuda_version): GenerateSM80_Simt_complex(manifest, cuda_version) ################################################################################################### + +# +def GenerateSM90_TensorOp_1684(manifest, cuda_version): + + if not 
CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 90 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 16], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 256, 16], 3, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + +# + +# +def GenerateSM90_TensorOp_1684_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 90 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8 ], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8 ], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8 ], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 8 ], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 8 ], 4, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 4, 
[1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 3, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 90 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 90 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64] + + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, 
BlasMode.symmetric) +# + +# +def GenerateSM90_TensorOp_1684_rank_k_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 90 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64] + + # SYRK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HERK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) + +# + +# +def GenerateSM90_TensorOp_1684_rank_k_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 90 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ComplexTransform.none,] + + # SYRK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HERK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM90_TensorOp_1684_trmm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + 
[16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 90 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints) +# + +# +def GenerateSM90_TensorOp_1684_trmm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 90 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + ComplexTransform.none, ComplexTransform.conj, + ] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + + +# +def GenerateSM90_TensorOp_1684_trmm_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 90 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + ComplexTransform.none, ComplexTransform.conj, + ] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM90_TensorOp_1684_symm(manifest, cuda_version): + + if not 
CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 90 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM90_TensorOp_1684_symm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 90 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + # SYMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HEMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 90 + max_cc = 1024 + + alignment_constraints = [1,] + 
+ tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ComplexTransform.none,] + + # SYMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HEMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +################################################################################################### + +# +def GenerateSM90(manifest, cuda_version): + + GenerateSM90_TensorOp_1684(manifest, cuda_version) + GenerateSM90_TensorOp_1684_complex(manifest, cuda_version) + GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version) + + GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version) + GenerateSM90_TensorOp_1684_rank_k_complex(manifest, cuda_version) + GenerateSM90_TensorOp_1684_rank_k_complex_gaussian(manifest, cuda_version) + GenerateSM90_TensorOp_1684_trmm(manifest, cuda_version) + GenerateSM90_TensorOp_1684_trmm_complex(manifest, cuda_version) + GenerateSM90_TensorOp_1684_trmm_complex_gaussian(manifest, cuda_version) + GenerateSM90_TensorOp_1684_symm(manifest, cuda_version) + GenerateSM90_TensorOp_1684_symm_complex(manifest, cuda_version) + GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version) + ################################################################################################### if __name__ == "__main__": @@ -3842,6 +4611,8 @@ if __name__ == "__main__": GenerateSM70(manifest, args.cuda_version) GenerateSM75(manifest, args.cuda_version) GenerateSM80(manifest, args.cuda_version) + GenerateSM90(manifest, args.cuda_version) + if 'library' in args.generator_target.split(','): manifest.emit(GeneratorTarget.Library) diff --git a/tools/library/scripts/library.py b/tools/library/scripts/library.py index 93c93ec6..3dd57409 100644 --- a/tools/library/scripts/library.py +++ b/tools/library/scripts/library.py @@ -471,9 +471,11 @@ SharedMemPerCC = { 70: 96, # 96KB of SMEM 72: 96, # 96KB of SMEM 75: 64, # 64KB of SMEM - 80: 160, # 164KB of SMEM - 4KB reserved for the driver - 86: 100, # 100KB of SMEM - 87: 160, # 164KB of SMEM - 4KB reserved for the driver + 80: 163, # 163KB of SMEM - 1KB reserved for the driver + 86: 99, # 99KB of SMEM - 1KB reserved for the driver + 87: 163, # 163KB of SMEM - 1KB reserved for the driver + 89: 99, # 99KB of SMEM - 1KB reserved for the driver + 90: 227, # 227KB of SMEM - 1KB reserved for the driver } ################################################################################################### @@ -561,7 +563,8 @@ class SwizzlingFunctor(enum.Enum): StridedDgradIdentity1 = enum_auto() StridedDgradIdentity4 = enum_auto() StridedDgradHorizontal = enum_auto() - + StreamK = enum_auto() + # SwizzlingFunctorTag = { SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>', @@ -572,6 +575,7 @@ SwizzlingFunctorTag = { SwizzlingFunctor.StridedDgradIdentity1: 
'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>', SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>', SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle', + SwizzlingFunctor.StreamK: 'cutlass::gemm::threadblock::ThreadblockSwizzleStreamK', } # @@ -618,38 +622,65 @@ class IteratorAlgorithm(enum.Enum): Optimized = enum_auto() FixedChannels = enum_auto() FewChannels = enum_auto() + FixedStrideDilation = enum_auto() # IteratorAlgorithmTag = { IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic', IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized', IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels', - IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels' + IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels', + IteratorAlgorithm.FixedStrideDilation: 'cutlass::conv::IteratorAlgorithm::kFixedStrideDilation' } IteratorAlgorithmNames = { IteratorAlgorithm.Analytic: 'analytic', IteratorAlgorithm.Optimized: 'optimized', IteratorAlgorithm.FixedChannels: 'fixed_channels', - IteratorAlgorithm.FewChannels: 'few_channels' + IteratorAlgorithm.FewChannels: 'few_channels', + IteratorAlgorithm.FixedStrideDilation: 'fixed_stride_dilation' } # class StrideSupport(enum.Enum): Strided = enum_auto() Unity = enum_auto() + Fixed = enum_auto() # StrideSupportTag = { StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided', StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity', + StrideSupport.Fixed: 'cutlass::conv::StrideSupport::kFixed' } StrideSupportNames = { StrideSupport.Strided: '', StrideSupport.Unity: 'unity_stride', + StrideSupport.Fixed: 'fixed_stride' } +# +class GroupMode(enum.Enum): + NoneGroup = enum_auto() # dense conv (G=1) + SingleGroup = enum_auto() # grouped convolution (single group per CTA) + MultipleGroup = enum_auto() # grouped convolution ( multiple groups per CTA) + Depthwise = enum_auto() # Depthwise convolution ( C=K=G ) + +# +GroupModeTag = { + GroupMode.NoneGroup: 'cutlass::conv::GroupMode::kNone', + GroupMode.SingleGroup: 'cutlass::conv::GroupMode::kSingleGroup', + GroupMode.MultipleGroup: 'cutlass::conv::GroupMode::kMultipleGroup', + GroupMode.Depthwise: 'cutlass::conv::GroupMode::kDepthwise', +} + +GroupModeNames = { + GroupMode.NoneGroup: '', + GroupMode.SingleGroup: 'single_group', + GroupMode.MultipleGroup: 'multiple_group', + GroupMode.Depthwise: 'depthwise', +} ################################################################################################### @@ -677,6 +708,39 @@ class TileDescription: def procedural_name(self): return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) +# +class Direct2dConvFixedStrideDilationTileDescription: + def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute): + self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]] + self.threadblock_output_shape = threadblock_output_shape + self.filter_shape = filter_shape + self.stages = stages + self.warp_count = warp_count + self.stride = stride + self.dilation = dilation + self.math_instruction = math_instruction + self.minimum_compute_capability = min_compute 
+ self.maximum_compute_capability = max_compute + + def procedural_name(self): + str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0], + self.threadblock_shape[1], + self.threadblock_shape[2], + self.threadblock_output_shape[0], + self.threadblock_output_shape[1], + self.threadblock_output_shape[2], + self.threadblock_output_shape[3], + self.stages, + self.filter_shape[0], + self.filter_shape[1]) + # Fixed Strided and dilation + if self.stride != [-1, -1] and self.dilation != [-1, -1]: + str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0], + self.stride[1], + self.dilation[0], + self.dilation[1]) + return str_name + # class TensorDescription: def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none): diff --git a/tools/library/scripts/pycutlass/README.md b/tools/library/scripts/pycutlass/README.md index 498e1aa8..1fd905e6 100644 --- a/tools/library/scripts/pycutlass/README.md +++ b/tools/library/scripts/pycutlass/README.md @@ -85,18 +85,11 @@ You can run the PyCUTLASS on NGC PyTorch container. ```shell docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:22.09-py3 ``` -PyCUTLASS requires additional dependency Boost C++ library, which can be installed with -```bash -apt-get update -apt-get -y install libboost-all-dev -``` - - ### Environment variables PyCUTLASSS requires two environment variables: -* `CUTLASS_PATH`: the root directory of CUTLASS -* `CUDA_INSTALL_PATH`: the directory where cuda toolkit is installed +* `CUTLASS_PATH`: the root directory of CUTLASS. You can set this from the location at which you cloned CUTLASS via: `export CUTLASS_PATH=$(pwd)`. +* `CUDA_INSTALL_PATH`: the directory where cuda toolkit is installed. If running in bash with `nvcc` installed under a CUDA toolkit, you can set this to the location of your `nvcc` installation via: `export CUDA_INSTALL_PATH=$(which nvcc | awk -F'/bin/nvcc' '{print $1}')` After setting these two environment variables, PyCUTLASS can be installed with ```shell diff --git a/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h b/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h index 6a840ce1..0499d709 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h +++ b/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm_universal_with_visitor.h @@ -38,6 +38,7 @@ #include "cutlass/cutlass.h" #include "cutlass/fast_math.h" #include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/params_universal_base.h" #include "cutlass/matrix_coord.h" #include "cutlass/complex.h" #include "cutlass/semaphore.h" @@ -104,16 +105,12 @@ public: // /// Argument structure - struct Arguments { + struct Arguments : UniversalArgumentsBase { // // Data members // - GemmUniversalMode mode; - GemmCoord problem_size; - int batch_count; - typename EpilogueVisitor::Arguments epilogue_visitor; void const * ptr_A; @@ -124,7 +121,6 @@ public: int64_t batch_stride_A; int64_t batch_stride_B; int64_t batch_stride_C; - int64_t batch_stride_D; typename LayoutA::Stride stride_a; typename LayoutB::Stride stride_b; @@ -145,8 +141,6 @@ public: // Arguments(): - mode(GemmUniversalMode::kGemm), - batch_count(1), ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), ptr_gather_A_indices(nullptr), ptr_gather_B_indices(nullptr), @@ -174,12 +168,10 @@ public: int const *ptr_gather_B_indices = nullptr, int const *ptr_scatter_D_indices = nullptr ): - mode(mode), - problem_size(problem_size), - 
batch_count(batch_count), + UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), epilogue_visitor(epilogue_visitor), ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), - batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), ptr_scatter_D_indices(ptr_scatter_D_indices) { @@ -212,12 +204,10 @@ public: int const *ptr_gather_B_indices = nullptr, int const *ptr_scatter_D_indices = nullptr ): - mode(mode), - problem_size(problem_size), - batch_count(batch_count), + UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), epilogue_visitor(epilogue_visitor), ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), - batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), ptr_scatter_D_indices(ptr_scatter_D_indices) { @@ -248,11 +238,19 @@ public: // /// Parameters structure - struct Params { + struct Params : UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC> { - cutlass::gemm::GemmCoord problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; + using ParamsBase = UniversalParamsBase< + ThreadblockSwizzle, + ThreadblockShape, + ElementA, + ElementB, + ElementC>; typename Mma::IteratorA::Params params_A; typename Mma::IteratorB::Params params_B; @@ -261,10 +259,6 @@ public: typename EpilogueVisitor::Params epilogue_visitor; - GemmUniversalMode mode; - int batch_count; - int gemm_k_size; - void * ptr_A; void * ptr_B; void * ptr_C; @@ -273,7 +267,6 @@ public: int64_t batch_stride_A; int64_t batch_stride_B; int64_t batch_stride_C; - int64_t batch_stride_D; int * ptr_gather_A_indices; int * ptr_gather_B_indices; @@ -285,47 +278,21 @@ public: // Methods // - CUTLASS_HOST_DEVICE - Params(): - swizzle_log_tile(0), - params_A(0), - params_B(0), - params_C(0), - params_D(0), - batch_count(0), - gemm_k_size(0), - mode(cutlass::gemm::GemmUniversalMode::kGemm), - ptr_A(nullptr), - ptr_B(nullptr), - ptr_C(nullptr), - ptr_D(nullptr), - batch_stride_A(0), - batch_stride_B(0), - batch_stride_C(0), - batch_stride_D(0), - ptr_gather_A_indices(nullptr), - ptr_gather_B_indices(nullptr), - ptr_scatter_D_indices(nullptr), - semaphore(nullptr) { } + /// Default constructor + Params() = default; CUTLASS_HOST_DEVICE Params( Arguments const &args, - cutlass::gemm::GemmCoord const & grid_tiled_shape, - int gemm_k_size, - void *workspace = nullptr + int device_sms, + int sm_occupancy ): - problem_size(args.problem_size), - grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + ParamsBase(args, device_sms, sm_occupancy), params_A(args.lda ? make_Coord_with_padding(args.lda) : args.stride_a), params_B(args.ldb ? make_Coord_with_padding(args.ldb) : args.stride_b), params_C(args.ldc ? make_Coord_with_padding(args.ldc) : args.stride_c), params_D(args.ldd ? 
make_Coord_with_padding(args.ldd) : args.stride_d), epilogue_visitor(args.epilogue_visitor), - mode(args.mode), - batch_count(args.batch_count), - gemm_k_size(gemm_k_size), ptr_A(const_cast(args.ptr_A)), ptr_B(const_cast(args.ptr_B)), ptr_C(const_cast(args.ptr_C)), @@ -333,11 +300,9 @@ public: batch_stride_A(args.batch_stride_A), batch_stride_B(args.batch_stride_B), batch_stride_C(args.batch_stride_C), - batch_stride_D(args.batch_stride_D), ptr_gather_A_indices(const_cast(args.ptr_gather_A_indices)), ptr_gather_B_indices(const_cast(args.ptr_gather_B_indices)), - ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)), - semaphore(static_cast(workspace)) { + ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)) { } @@ -358,7 +323,6 @@ public: batch_stride_A = args.batch_stride_A; batch_stride_B = args.batch_stride_B; batch_stride_C = args.batch_stride_C; - batch_stride_D = args.batch_stride_D; epilogue_visitor = args.epilogue_visitor; @@ -466,12 +430,6 @@ public: return can_implement(args.problem_size); } - static size_t get_extra_workspace_size(Arguments const &args, - cutlass::gemm::GemmCoord const &grid_tiled_shape) { - - return 0; - } - /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage) { diff --git a/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h b/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h index f6ff7565..7869ce12 100644 --- a/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h +++ b/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h @@ -38,11 +38,19 @@ #include "cutlass/gemm/threadblock/threadblock_swizzle.h" #include "cutlass/conv/threadblock/threadblock_swizzle.h" -#include +#include #include namespace py = pybind11; +std::string demangle(const char* mangled_name) { + std::size_t len = 0; + int status = 0; + std::unique_ptr ptr( + __cxxabiv1::__cxa_demangle(mangled_name, nullptr, &len, &status)); + return ptr.get(); +} + template void bind_identity_swizzle(py::module & m, std::string name) { py::class_(m, name.c_str(), @@ -80,7 +88,7 @@ void bind_identity_swizzle(py::module & m, std::string name) { py::arg("tiled_shape"), R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc") .def("tag", [](const T & swizzle){ - return boost::core::demangle(typeid(T).name()); + return demangle(typeid(T).name()); }, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc"); } @@ -101,7 +109,7 @@ void bind_swizzle(py::module & m, std::string name, std::string doc) { py::arg("tiled_shape"), R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc") .def("tag", [](const T & swizzle){ - return boost::core::demangle(typeid(T).name()); + return demangle(typeid(T).name()); }, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc"); } @@ -124,7 +132,7 @@ void bind_dgrad_swizzle(py::module & m, std::string name) { }, py::arg("tiled_shape"), R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc") .def("tag", [](const T & swizzle){ - return boost::core::demangle(typeid(T).name()); + return demangle(typeid(T).name()); }, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc"); } diff --git a/tools/library/scripts/pycutlass/src/pycutlass/c_types.py b/tools/library/scripts/pycutlass/src/pycutlass/c_types.py index a7b3af03..58ac7c1c 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/c_types.py @@ -69,9 
+69,12 @@ def get_gemm_arguments(epilogue_functor): class _GemmArguments(ctypes.Structure): _fields_ = [ + # Arguments from UniversalArgumentsBase ("mode", ctypes.c_int), ("problem_size", GemmCoord_), ("batch_count", ctypes.c_int), + ("batch_stride_D", ctypes.c_longlong), + # Remaining arguments ("epilogue", _EpilogueOutputOpParams), ("ptr_A", ctypes.c_void_p), ("ptr_B", ctypes.c_void_p), @@ -80,7 +83,6 @@ def get_gemm_arguments(epilogue_functor): ("batch_stride_A", ctypes.c_longlong), ("batch_stride_B", ctypes.c_longlong), ("batch_stride_C", ctypes.c_longlong), - ("batch_stride_D", ctypes.c_longlong), ("stride_a", ctypes.c_longlong), ("stride_b", ctypes.c_longlong), ("stride_c", ctypes.c_longlong), diff --git a/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py b/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py index 51d30ed4..54719221 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py @@ -229,7 +229,7 @@ class GemmArguments(ArgumentBase): elif operand in ["c", "d"]: tensor_coord = problem_size.mn() else: - raise ValueError("unknonw operand: " + operand) + raise ValueError("unknown operand: " + operand) layout = tensor_layout.packed(tensor_coord) @@ -245,22 +245,27 @@ class GemmArguments(ArgumentBase): ) if self.gemm_mode == cutlass.gemm.Mode.Array: arguments = self.operation.argument_type( - self.gemm_mode, problem_size_, self.batch_count, self.output_op, + # Arguments from UniversalArgumentsBase + self.gemm_mode, problem_size_, self.batch_count, 0, + # Remaining arguments + self.output_op, int(self.ptr_A_array_buffer.ptr), int(self.ptr_B_array_buffer.ptr), int(self.ptr_C_array_buffer.ptr), int(self.ptr_D_array_buffer.ptr), - 0, 0, 0, 0, + 0, 0, 0, self.lda, self.ldb, self.ldc, self.ldd, self.lda, self.ldb, self.ldc, self.ldd, 0, 0, 0 ) else: arguments = self.operation.argument_type( - self.gemm_mode, problem_size_, self.batch_count, self.output_op, + # Arguments from UniversalArgumentsBase + self.gemm_mode, problem_size_, self.batch_count, self.batched_stride_D, + # Remaining arguments + self.output_op, int(self.ptr_A), int(self.ptr_B), int(self.ptr_C), int(self.ptr_D), self.batched_stride_A, self.batched_stride_B, self.batched_stride_C, - self.batched_stride_D, self.lda, self.ldb, self.ldc, self.ldd, self.lda, self.ldb, self.ldc, self.ldd, 0, 0, 0 @@ -299,8 +304,7 @@ class GemmArguments(ArgumentBase): arguments, grid_tiled_shape, gemm_k_size = self.arguments res_arg = self.operation.rt_module.get_args( - ctypes.byref(arguments), ctypes.byref(grid_tiled_shape), - gemm_k_size, ctypes.c_void_p(int(device_workspace))) + ctypes.byref(arguments), ctypes.c_void_p(int(device_workspace))) host_workspace = bytearray(res_arg.contents) device_workspace = None @@ -582,10 +586,15 @@ extern "C" { } // Get the params as byte array - char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, \ - cutlass::gemm::GemmCoord* grid_tiled_shape, int gemm_k_size, int* workspace){ + char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int* workspace){ ${operation_name}_base::Params* params; - params = new ${operation_name}_base::Params(*argument, *grid_tiled_shape, gemm_k_size, workspace); + params = new ${operation_name}_base::Params(*argument, + -1, // SM count. Only used for stream-K + -1 // Occupancy. 
Only used for stream-K + ); + + // Semaphore holds the pointer to the workspace in the Params struct + params->semaphore = workspace; char *bytes = ((char*)(params)); char *output = new char[sizeof(${operation_name}_base::Params)]; diff --git a/tools/library/scripts/pycutlass/src/pycutlass/library.py b/tools/library/scripts/pycutlass/src/pycutlass/library.py index 61b59a6c..f38cb615 100644 --- a/tools/library/scripts/pycutlass/src/pycutlass/library.py +++ b/tools/library/scripts/pycutlass/src/pycutlass/library.py @@ -116,13 +116,11 @@ DataTypeNames = { DataTypeTag = { cutlass.dtype.b1: "cutlass::uint1b_t", - cutlass.dtype.u2: "cutlass::uint2b_t", cutlass.dtype.u4: "cutlass::uint4b_t", cutlass.dtype.u8: "uint8_t", cutlass.dtype.u16: "uint16_t", cutlass.dtype.u32: "uint32_t", cutlass.dtype.u64: "uint64_t", - cutlass.dtype.s2: "cutlass::int2b_t", cutlass.dtype.s4: "cutlass::int4b_t", cutlass.int8: "int8_t", cutlass.dtype.s16: "int16_t", @@ -138,13 +136,11 @@ DataTypeTag = { cutlass.dtype.cf32: "cutlass::complex", cutlass.dtype.ctf32: "cutlass::complex", cutlass.dtype.cf64: "cutlass::complex", - cutlass.dtype.cu2: "cutlass::complex", cutlass.dtype.cu4: "cutlass::complex", cutlass.dtype.cu8: "cutlass::complex", cutlass.dtype.cu16: "cutlass::complex", cutlass.dtype.cu32: "cutlass::complex", cutlass.dtype.cu64: "cutlass::complex", - cutlass.dtype.cs2: "cutlass::complex", cutlass.dtype.cs4: "cutlass::complex", cutlass.dtype.cs8: "cutlass::complex", cutlass.dtype.cs16: "cutlass::complex", diff --git a/tools/library/scripts/pycutlass/test/example/run_all_example.sh b/tools/library/scripts/pycutlass/test/example/run_all_example.sh old mode 100644 new mode 100755 index 8f68bc30..c05eb048 --- a/tools/library/scripts/pycutlass/test/example/run_all_example.sh +++ b/tools/library/scripts/pycutlass/test/example/run_all_example.sh @@ -1,4 +1,4 @@ -pushd $CUTLASS_PATH/examples/40_cutlass_py/ +pushd $CUTLASS_PATH/examples/40_cutlass_py/customizable python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 diff --git a/tools/library/src/conv2d_operation.h b/tools/library/src/conv2d_operation.h index 4ee58f0d..ef93951c 100644 --- a/tools/library/src/conv2d_operation.h +++ b/tools/library/src/conv2d_operation.h @@ -36,9 +36,12 @@ #include #include "cutlass/cutlass.h" #include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/kernel/default_conv2d_group_fprop.h" +#include "cutlass/conv/kernel/default_depthwise_fprop.h" #include "cutlass/conv/kernel/default_conv2d_dgrad.h" #include "cutlass/conv/kernel/default_conv2d_wgrad.h" #include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/conv/device/direct_convolution.h" #include "cutlass/library/library.h" #include "library_internal.h" @@ -98,9 +101,9 @@ public: description_.tile_description.threadblock_stages = Operator::kStages; description_.tile_description.warp_count = make_Coord( - Operator::ImplicitGemmKernel::WarpCount::kM, - Operator::ImplicitGemmKernel::WarpCount::kN, - Operator::ImplicitGemmKernel::WarpCount::kK); + Operator::UnderlyingKernel::WarpCount::kM, + Operator::UnderlyingKernel::WarpCount::kN, + Operator::UnderlyingKernel::WarpCount::kK); description_.tile_description.math_instruction.instruction_shape = make_Coord( Operator::InstructionShape::kM, @@ -381,6 +384,258 @@ 
public: } }; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// DirectConv2d library operation class for cutlass profiler +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class DirectConv2dOperation : public Conv2dOperation { +public: + + using Operator = Operator_; + using Base = Conv2dOperation; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +public: + /// Constructor + DirectConv2dOperation(char const *name = "unknown_direct)conv2d_fprop") : Conv2dOperation(name) { + this->description_.conv_kind = ConvKindMap::kId; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + Conv2dConfiguration const *configuration) { + + + operator_args.problem_size = configuration->problem_size; + + operator_args.ref_A = + { + nullptr, + LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_reordered_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_C = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_D = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.split_k_mode = configuration->split_k_mode; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + ConvArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); + operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); + operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); + operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); + operator_args.ref_reordered_B.reset(static_cast(const_cast(arguments->reordered_B))); + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void 
const *arguments_ptr) const { + + Conv2dConfiguration const *configuration = + static_cast(configuration_ptr); + + ConvArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + //std::cout << "initialize library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->initialize(args, device_workspace, stream); + + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + //std::cout << "run library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->run(stream); + } + + /// Call print_operator_args from the Conv2dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Conv2dOperation::OperatorArguments" << std::endl + << " problem_size:" << std::endl + << operator_args.problem_size << std::endl + << " split_k_mode: " + << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? 
"serial" : "parallel") << std::endl + << " epilouge (alpha, beta): " + << operator_args.output_op.alpha << ", " + << operator_args.output_op.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ref_A.data() << ", {" + << operator_args.ref_A.stride(0) << ", " + << operator_args.ref_A.stride(1) << ", " + << operator_args.ref_A.stride(2) << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ref_B.data() << ", {" + << operator_args.ref_B.stride(0) << ", " + << operator_args.ref_B.stride(1) << ", " + << operator_args.ref_B.stride(2) << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ref_C.data() << ", {" + << operator_args.ref_C.stride(0) << ", " + << operator_args.ref_C.stride(1) << ", " + << operator_args.ref_C.stride(2) << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ref_D.data() << ", {" + << operator_args.ref_D.stride(0) << ", " + << operator_args.ref_D.stride(1) << ", " + << operator_args.ref_D.stride(2) << "}" << std::endl; + } +}; + } // namespace library } // namespace cutlass diff --git a/tools/library/src/conv3d_operation.h b/tools/library/src/conv3d_operation.h index 34d88a6e..8a844f8c 100644 --- a/tools/library/src/conv3d_operation.h +++ b/tools/library/src/conv3d_operation.h @@ -98,9 +98,9 @@ public: description_.tile_description.threadblock_stages = Operator::kStages; description_.tile_description.warp_count = make_Coord( - Operator::ImplicitGemmKernel::WarpCount::kM, - Operator::ImplicitGemmKernel::WarpCount::kN, - Operator::ImplicitGemmKernel::WarpCount::kK); + Operator::UnderlyingKernel::WarpCount::kM, + Operator::UnderlyingKernel::WarpCount::kN, + Operator::UnderlyingKernel::WarpCount::kK); description_.tile_description.math_instruction.instruction_shape = make_Coord( Operator::InstructionShape::kM, diff --git a/tools/library/src/gemm_operation.h b/tools/library/src/gemm_operation.h index 84ce0a91..0d3ee187 100644 --- a/tools/library/src/gemm_operation.h +++ b/tools/library/src/gemm_operation.h @@ -305,7 +305,7 @@ public: Operator *op = static_cast(host_workspace); - status = op->update(args, device_workspace); + status = op->update(args); if (status != Status::kSuccess) { return status; @@ -507,7 +507,7 @@ public: Operator *op = static_cast(host_workspace); - status = op->update(args, device_workspace); + status = op->update(args); if (status != Status::kSuccess) { return status; @@ -725,8 +725,8 @@ public: } Operator *op = static_cast(host_workspace); - - status = op->update(args, device_workspace); + + status = op->update(args); if (status != Status::kSuccess) { return status; @@ -917,30 +917,30 @@ public: /// Runs the kernel virtual Status run( void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, + void *host_workspace, + void *device_workspace = nullptr, cudaStream_t stream = nullptr) const { OperatorArguments args; - + Status status = update_arguments_( - args, + args, static_cast(arguments_ptr)); if (status != Status::kSuccess) { return status; } - + Operator *op = static_cast(host_workspace); - - status = op->update(args, device_workspace); + + status = op->update(args); if (status != Status::kSuccess) { return status; } - + status = op->run(stream); - + return status; } }; @@ -1134,7 +1134,7 @@ public: Operator *op = static_cast(host_workspace); - status = op->update(args, device_workspace); + status = op->update(args); if (status != Status::kSuccess) { return status; @@ -1336,7 +1336,7 @@ public: Operator *op = static_cast(host_workspace); - status = 
op->update(args, device_workspace); + status = op->update(args); if (status != Status::kSuccess) { return status; diff --git a/tools/profiler/CMakeLists.txt b/tools/profiler/CMakeLists.txt index c619b157..6074175a 100644 --- a/tools/profiler/CMakeLists.txt +++ b/tools/profiler/CMakeLists.txt @@ -117,3 +117,4 @@ cutlass_add_executable_tests( CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_SYMM DISABLE_EXECUTABLE_INSTALL_RULE ) + diff --git a/tools/profiler/src/conv2d_operation_profiler.cu b/tools/profiler/src/conv2d_operation_profiler.cu index 95ebd09d..f5ae5418 100644 --- a/tools/profiler/src/conv2d_operation_profiler.cu +++ b/tools/profiler/src/conv2d_operation_profiler.cu @@ -67,6 +67,7 @@ Conv2dOperationProfiler::Conv2dOperationProfiler(Options const &options): {ArgumentTypeID::kInteger, {"s", "filter_s"}, "Filter S dimension of the Conv2d problem space"}, {ArgumentTypeID::kInteger, {"p", "output_p"}, "Output P dimension of the Conv2d problem space"}, {ArgumentTypeID::kInteger, {"q", "output_q"}, "Output Q dimension of the Conv2d problem space"}, + {ArgumentTypeID::kInteger, {"g", "groups"}, "Number of convolution groups"}, {ArgumentTypeID::kInteger, {"pad_h"}, "Padding in H direction"}, {ArgumentTypeID::kInteger, {"pad_w"}, "Padding in W direction"}, {ArgumentTypeID::kInteger, {"stride_h"}, "Stride in H direction"}, @@ -233,6 +234,11 @@ Status Conv2dOperationProfiler::initialize_configuration( problem_.s = 3; } + if (!arg_as_int(problem_.groups, "g", problem_space, problem)) { + // default value + problem_.groups = 1; + } + if (!arg_as_int(problem_.pad_h, "pad_h", problem_space, problem)) { // default value problem_.pad_h = 1; @@ -382,7 +388,7 @@ Status Conv2dOperationProfiler::initialize_configuration( int(problem_.dilation_w), static_cast(static_cast(problem_.conv_mode)), int(problem_.split_k_slices), - 1 // groups + int(problem_.groups) ); conv_workspace_.configuration.split_k_mode = static_cast(static_cast(problem_.split_k_mode)); @@ -454,6 +460,8 @@ void Conv2dOperationProfiler::initialize_result_( set_argument(result, "p", problem_space, problem_.p); set_argument(result, "q", problem_space, problem_.q); + set_argument(result, "g", problem_space, problem_.groups); + set_argument(result, "pad_h", problem_space, problem_.pad_h); set_argument(result, "pad_w", problem_space, problem_.pad_w); @@ -624,6 +632,19 @@ Status Conv2dOperationProfiler::initialize_workspace( conv_workspace_.problem_count ); + if(problem_.groups == problem_.c && problem_.groups == problem_.k){ + // Depthwise direct conv kernel needs reorder the filter. 
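For context on the new `g`/`groups` profiler option: grouped Conv2d fprop shrinks the implicit-GEMM reduction dimension by the group count, because each filter only touches `C / groups` input channels. The sketch below mirrors that accounting; the function name is hypothetical, and the profiler's own version lives in C++ (`Conv2dProblem::eq_gemm_size` / `extent_b`).

```python
# Minimal sketch of the groups-aware implicit-GEMM sizing used when
# profiling grouped Conv2d fprop. grouped_fprop_eq_gemm is an
# illustrative name, not a profiler API.
def grouped_fprop_eq_gemm(n, p, q, k, r, s, c, groups=1):
    assert c % groups == 0 and k % groups == 0
    gemm_m = n * p * q               # one GEMM row per output pixel
    gemm_n = k                       # one GEMM column per filter
    gemm_k = r * s * (c // groups)   # each filter sees only C / groups channels
    return gemm_m, gemm_n, gemm_k

# Depthwise 3x3 layer with C = K = groups = 32:
# grouped_fprop_eq_gemm(n=1, p=56, q=56, k=32, r=3, s=3, c=32, groups=32)
# -> (3136, 32, 9)
```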
+ conv_workspace_.reordered_B = device_context.allocate_tensor( + options, + "B", + operation_desc.B.element, + operation_desc.B.layout, + problem_.extent_b(operation_desc.conv_kind), + conv_workspace_.configuration.stride_b, + conv_workspace_.problem_count + ); + } + conv_workspace_.C = device_context.allocate_tensor( options, "C", @@ -738,6 +759,12 @@ bool Conv2dOperationProfiler::verify_cutlass( conv_workspace_.arguments.beta = problem_.beta.data(); conv_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; + if (conv_workspace_.reordered_B != nullptr){ + conv_workspace_.arguments.reordered_B = conv_workspace_.reordered_B->data(); + }else{ + conv_workspace_.arguments.reordered_B = nullptr; + } + conv_workspace_.Computed->copy_from_device(conv_workspace_.C->data()); if (conv_workspace_.configuration.split_k_mode == conv::SplitKMode::kParallel) { diff --git a/tools/profiler/src/conv2d_operation_profiler.h b/tools/profiler/src/conv2d_operation_profiler.h index 39eead34..9af57b09 100644 --- a/tools/profiler/src/conv2d_operation_profiler.h +++ b/tools/profiler/src/conv2d_operation_profiler.h @@ -75,6 +75,7 @@ public: struct Conv2dProblem { int64_t n, h, w, c, p, q, k, r, s; + int64_t groups; int64_t pad_h, pad_w; int64_t stride_h, stride_w; int64_t dilation_h, dilation_w; @@ -114,7 +115,7 @@ public: cutlass::gemm::GemmCoord eq_gemm_size(library::ConvKind const &conv_kind) const { switch (conv_kind) { - case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * p * q), int(k), int(r * s * c)); + case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * p * q), int(k), int(r * s * c / groups)); case library::ConvKind::kDgrad: return cutlass::gemm::GemmCoord(int(n * h * w), int(c), int(k * r * s)); case library::ConvKind::kWgrad: return cutlass::gemm::GemmCoord(int(k), int(r * s * c), int(n * p * q)); default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); @@ -136,7 +137,7 @@ public: std::vector extent_b(library::ConvKind const &conv_kind) const { switch (conv_kind) { - case library::ConvKind::kFprop: return {int(k), int(r), int(s), int(c)}; + case library::ConvKind::kFprop: return {int(k), int(r), int(s), int(c / groups)}; case library::ConvKind::kDgrad: return {int(k), int(r), int(s), int(c)}; case library::ConvKind::kWgrad: return {int(n), int(h), int(w), int(c)}; default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); @@ -228,6 +229,7 @@ public: /// Conv device allocations DeviceAllocation *A; DeviceAllocation *B; + DeviceAllocation *reordered_B; DeviceAllocation *C; DeviceAllocation *Computed; DeviceAllocation *Reference; @@ -270,6 +272,7 @@ public: Conv2dWorkspace() : A(nullptr), B(nullptr), + reordered_B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) {} @@ -317,10 +320,10 @@ public: stride_activations.push_back(int(problem.h) * int(problem.w) * int(problem.c)); - stride_filters.push_back(int(problem.c)); - stride_filters.push_back(int(problem.s) * int(problem.c)); + stride_filters.push_back(int(problem.c / problem.groups)); + stride_filters.push_back(int(problem.s) * int(problem.c / problem.groups)); stride_filters.push_back(int(problem.r) * int(problem.s) * - int(problem.c)); + int(problem.c / problem.groups)); stride_output.push_back(int(problem.k)); stride_output.push_back(int(problem.q) * int(problem.k)); diff --git a/tools/profiler/src/cudnn_helpers.cpp b/tools/profiler/src/cudnn_helpers.cpp index 0ab6c12a..61b19622 100644 --- a/tools/profiler/src/cudnn_helpers.cpp +++ 
b/tools/profiler/src/cudnn_helpers.cpp @@ -195,7 +195,12 @@ bool get_cudnn_mathtype(cudnnMathType_t &cudnn_math_type, library::ConvDescripti return true; } case library::OpcodeClassID::kSimt: - return false; + #if (defined(CUDNN_VERSION) && CUDNN_VERSION <= 8000) + cudnn_math_type = CUDNN_DEFAULT_MATH; + #else + cudnn_math_type = CUDNN_FMA_MATH; + #endif + return true; } return false; diff --git a/tools/profiler/src/cudnn_helpers.h b/tools/profiler/src/cudnn_helpers.h index f18e78a0..9488e8b0 100644 --- a/tools/profiler/src/cudnn_helpers.h +++ b/tools/profiler/src/cudnn_helpers.h @@ -245,7 +245,7 @@ struct cudnnConvDispatcher { data_type_filter, layout_filter, configuration.problem_size.K, - configuration.problem_size.C, + configuration.problem_size.C / configuration.problem_size.groups, configuration.problem_size.R, configuration.problem_size.S )); diff --git a/tools/util/include/cutlass/util/host_uncompress.h b/tools/util/include/cutlass/util/host_uncompress.h index 7fbfc3dd..38c189aa 100644 --- a/tools/util/include/cutlass/util/host_uncompress.h +++ b/tools/util/include/cutlass/util/host_uncompress.h @@ -42,6 +42,7 @@ namespace cutlass { +// uncompress sparse tensor core A matrix template void uncompress(TensorRef uncompressed_tensor_a, @@ -119,5 +120,38 @@ void uncompress(TensorRef uncompressed_tensor_a, } } } + +// uncompress ELL block sparse matrix +template +void uncompress_ell_block_sparse( + TensorRef uncompressed_tensor_a, + TensorRef tensor_a, + TensorRef ell_idx, + int rows, int cols, + int ell_num_cols, int ell_blocksize) { + + for (int r = 0; r < rows / ell_blocksize; ++r) { + for (int c = 0; c < ell_num_cols / ell_blocksize; ++c) { + + ElementE idx = ell_idx.at(MatrixCoord(r, c)); + + if (idx != -1) { + int row_begin = r * ell_blocksize; + int col_begin_real = idx * ell_blocksize; + int col_begin = c * ell_blocksize; + + for (int i = 0; i < ell_blocksize; ++i) { + for (int j = 0; j < ell_blocksize; ++j) { + uncompressed_tensor_a.at(MatrixCoord(row_begin + i, col_begin_real + j)) = + tensor_a.at( + MatrixCoord(row_begin + i, col_begin +j)); + } + } + } + } + } +} + } // namespace cutlass diff --git a/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h b/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h index 17d008ec..9f2cc0d3 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h @@ -68,8 +68,8 @@ struct TensorFuncBinaryOp { /// View of left-hand-side tensor TensorView view_d; - TensorRef ref_a; - TensorRef ref_b; + TensorRef view_a; + TensorRef view_b; BinaryFunc func; // @@ -82,8 +82,8 @@ struct TensorFuncBinaryOp { /// Constructor TensorFuncBinaryOp( TensorView const & view_d_, - TensorRef const & ref_a_, - TensorRef const & ref_b_, + TensorRef const & view_a_, + TensorRef const & view_b_, BinaryFunc func = BinaryFunc() ): view_d(view_d_), view_a(view_a_), view_b(view_b_), func(func) { } @@ -284,7 +284,7 @@ void TensorDiv( TensorView d, ///< destination tensor view TensorRef a ///< A tensor reference ) { - TensorMul(d, d, a); + TensorDiv(d, d, a); } @@ -312,7 +312,7 @@ void TensorModulus( LayoutA, ElementB, LayoutB, - cutlass::modulus + cutlass::divides > func(d, a, b); TensorForEach( @@ -331,7 +331,7 @@ void TensorModulus( TensorView d, ///< destination tensor view TensorRef a ///< A tensor reference ) { - TensorMul(d, d, a); + TensorDiv(d, d, a); } 
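The `uncompress_ell_block_sparse` helper added above expands a Blocked-ELL operand (a values matrix plus a per-block-row column-index matrix, with `-1` marking an absent block) into a dense matrix for host-side verification. The NumPy sketch below mirrors that expansion under the same conventions; the names are illustrative and not part of the CUTLASS utilities.

```python
import numpy as np

# Minimal sketch of the Blocked-ELL expansion performed by the
# uncompress_ell_block_sparse helper above. `values` has shape
# (rows, ell_num_cols); ell_idx[r, c] gives the dense block-column of the
# c-th stored block in block-row r, or -1 if that slot is empty.
def uncompress_blocked_ell(values, ell_idx, rows, cols, ell_num_cols, blocksize):
    dense = np.zeros((rows, cols), dtype=values.dtype)
    for r in range(rows // blocksize):
        for c in range(ell_num_cols // blocksize):
            idx = int(ell_idx[r, c])
            if idx == -1:
                continue                 # empty block
            rb = r * blocksize           # block-row offset (both matrices)
            src = c * blocksize          # column offset in the compressed values
            dst = idx * blocksize        # column offset in the dense result
            dense[rb:rb + blocksize, dst:dst + blocksize] = \
                values[rb:rb + blocksize, src:src + blocksize]
    return dense
```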
/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/tools/util/include/cutlass/util/reference/host/tensor_fill.h index a795d5f8..0a3bd94d 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_fill.h @@ -1272,7 +1272,7 @@ template struct RandomSparseMetaFunc { uint64_t seed; - double range; + int range; int MetaSizeInBits; // @@ -1302,9 +1302,8 @@ struct RandomSparseMetaFunc { Element result = 0x0; for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { - double rnd = double(std::rand()) / double(RAND_MAX); - rnd = range * rnd; - Element meta = MetaArray[(int)rnd]; + int rnd = std::rand() % range; + Element meta = MetaArray[rnd]; result = (Element)(result | ((Element)(meta << (i * 4)))); } @@ -1393,6 +1392,37 @@ void BlockFillRandomSparseMeta( } } +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a ell block index matrix with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomEllIdx( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + int rows, int ell_cols, int cols) { ///< dimension of the matrix + + std::srand((unsigned)seed); + + for (int i = 0; i < rows; ++i) { + int col_idx = std::rand() % cols; + + for (int j = 0; j < ell_cols; ++j) { + dst.at({i, j}) = col_idx; + + if (col_idx != -1) { + if (col_idx == (cols - 1)) { + col_idx = -1; + } else { + col_idx = std::rand() % (cols - col_idx - 1) + col_idx + 1; + } + } + } + } +} + /////////////////////////////////////////////////////////////////////////////////////////////////// /// Copies a diagonal in from host memory without modifying off-diagonal elements.
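`TensorFillRandomEllIdx` above generates the ELL index matrix consumed by the block-sparse example and by the uncompress helper: each block row receives strictly increasing random block-column indices, and once the last column has been emitted the remaining slots are filled with `-1`. The Python sketch below reproduces the same sampling scheme; the function name and default seed are illustrative only.

```python
import random

# Minimal sketch of the sampling scheme used by TensorFillRandomEllIdx:
# each block row gets strictly increasing random column indices, and once
# the last column (cols - 1) has been emitted the remaining entries are -1.
def fill_random_ell_idx(rows, ell_cols, cols, seed=2022):
    random.seed(seed)
    idx = [[-1] * ell_cols for _ in range(rows)]
    for i in range(rows):
        col = random.randrange(cols)
        for j in range(ell_cols):
            idx[i][j] = col
            if col != -1:
                col = -1 if col == cols - 1 else random.randrange(col + 1, cols)
    return idx
```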