Updates for CUTLASS 3.4.1 (#1346)

* Updates for CUTLASS 3.4.1

* minor epilogue (epi) change
This commit is contained in:
ANIKET SHIVAM
2024-02-15 12:48:34 -08:00
committed by GitHub
parent 47a3ebbea9
commit bbe579a9e3
49 changed files with 800 additions and 451 deletions

View File

@@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@@ -56,9 +56,6 @@
#include "gemm_testbed_3x_evt.hpp"
#include "sm90_evt_operations.hpp"
#define CUTLASS_ARCH_MMA_SM90_SUPPORTED
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
using namespace cute;
@@ -132,7 +129,7 @@ bool testEVTAuxStoreWithoutD() {
D_block.reset(m * n);
aux_store_D_block.reset(m * n);
Gemm gemm_op_base;
auto stride_A = cutlass::make_cute_packed_stride(
typename GemmKernel::StrideA{}, cute::make_shape(m, k, cute::Int<1>{}));
auto stride_B = cutlass::make_cute_packed_stride(
@@ -141,7 +138,7 @@ bool testEVTAuxStoreWithoutD() {
typename GemmKernel::StrideC{}, cute::make_shape(m, n, cute::Int<1>{}));
auto stride_D = cutlass::make_cute_packed_stride(
typename GemmKernel::StrideD{}, cute::make_shape(m, n, cute::Int<1>{}));
auto arguments_base = typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
@@ -178,12 +175,12 @@ bool testEVTAuxStoreWithoutD() {
/*hw_info=*/{},
/*scheduler_args=*/{}
};
constexpr float beta [[maybe_unused]] = 1.0;
constexpr float alpha [[maybe_unused]] = 1.0;
using ElementC = typename GemmWithoutD::ElementC;
if constexpr (not has_c) {
arguments_base.epilogue.thread = {
// binary op : alpha * acc
@@ -282,7 +279,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t
>;
using namespace cutlass::epilogue::fusion;
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
@@ -324,10 +321,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
return *(GemmKernel *)(nullptr);
};
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
@@ -352,7 +349,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t
>;
using namespace cutlass::epilogue::fusion;
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
@@ -394,10 +391,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
return *(GemmKernel *)(nullptr);
};
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
@@ -467,7 +464,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 12
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
@@ -492,7 +489,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t
>;
using namespace cutlass::epilogue::fusion;
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
@@ -534,10 +531,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
return *(GemmKernel *)(nullptr);
};
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
@@ -562,7 +559,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t
>;
using namespace cutlass::epilogue::fusion;
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
@@ -604,10 +601,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
return *(GemmKernel *)(nullptr);
};
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
@@ -677,7 +674,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 12
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;