Updates for CUTLASS 3.4.1 (#1346)
* Updates for CUTLASS 3.4.1 * minor epi change
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
@ -56,9 +56,6 @@
|
||||
#include "gemm_testbed_3x_evt.hpp"
|
||||
#include "sm90_evt_operations.hpp"
|
||||
|
||||
|
||||
#define CUTLASS_ARCH_MMA_SM90_SUPPORTED
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
@ -132,7 +129,7 @@ bool testEVTAuxStoreWithoutD() {
|
||||
D_block.reset(m * n);
|
||||
aux_store_D_block.reset(m * n);
|
||||
Gemm gemm_op_base;
|
||||
|
||||
|
||||
auto stride_A = cutlass::make_cute_packed_stride(
|
||||
typename GemmKernel::StrideA{}, cute::make_shape(m, k, cute::Int<1>{}));
|
||||
auto stride_B = cutlass::make_cute_packed_stride(
|
||||
@ -141,7 +138,7 @@ bool testEVTAuxStoreWithoutD() {
|
||||
typename GemmKernel::StrideC{}, cute::make_shape(m, n, cute::Int<1>{}));
|
||||
auto stride_D = cutlass::make_cute_packed_stride(
|
||||
typename GemmKernel::StrideD{}, cute::make_shape(m, n, cute::Int<1>{}));
|
||||
|
||||
|
||||
auto arguments_base = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_size,
|
||||
@ -178,12 +175,12 @@ bool testEVTAuxStoreWithoutD() {
|
||||
/*hw_info=*/{},
|
||||
/*scheduler_args=*/{}
|
||||
};
|
||||
|
||||
|
||||
constexpr float beta [[maybe_unused]] = 1.0;
|
||||
constexpr float alpha [[maybe_unused]] = 1.0;
|
||||
|
||||
|
||||
using ElementC = typename GemmWithoutD::ElementC;
|
||||
|
||||
|
||||
if constexpr (not has_c) {
|
||||
arguments_base.epilogue.thread = {
|
||||
// binary op : alpha * acc
|
||||
@ -282,7 +279,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
|
||||
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
|
||||
EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t
|
||||
>;
|
||||
|
||||
|
||||
using namespace cutlass::epilogue::fusion;
|
||||
|
||||
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
@ -324,10 +321,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
|
||||
|
||||
return *(GemmKernel *)(nullptr);
|
||||
};
|
||||
|
||||
|
||||
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
|
||||
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
|
||||
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
|
||||
|
||||
@ -352,7 +349,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
|
||||
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
|
||||
EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t
|
||||
>;
|
||||
|
||||
|
||||
using namespace cutlass::epilogue::fusion;
|
||||
|
||||
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
@ -394,10 +391,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
|
||||
|
||||
return *(GemmKernel *)(nullptr);
|
||||
};
|
||||
|
||||
|
||||
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
|
||||
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
|
||||
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
|
||||
|
||||
@ -467,7 +464,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 12
|
||||
|
||||
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
|
||||
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
|
||||
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
|
||||
|
||||
@ -492,7 +489,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
|
||||
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
|
||||
EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t
|
||||
>;
|
||||
|
||||
|
||||
using namespace cutlass::epilogue::fusion;
|
||||
|
||||
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
@ -534,10 +531,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25
|
||||
|
||||
return *(GemmKernel *)(nullptr);
|
||||
};
|
||||
|
||||
|
||||
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
|
||||
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
|
||||
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
|
||||
|
||||
@ -562,7 +559,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
|
||||
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
|
||||
EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t
|
||||
>;
|
||||
|
||||
|
||||
using namespace cutlass::epilogue::fusion;
|
||||
|
||||
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
@ -604,10 +601,10 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25
|
||||
|
||||
return *(GemmKernel *)(nullptr);
|
||||
};
|
||||
|
||||
|
||||
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
|
||||
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
|
||||
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
|
||||
|
||||
@ -677,7 +674,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 12
|
||||
|
||||
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
|
||||
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
|
||||
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user