@ -49,16 +49,16 @@ struct ContractionKernel {
|
||||
|
||||
/// Scalar element type, fixed to float. It is passed as the last template
/// argument of EpilogueThread.
/// NOTE(review): presumably the type of the epilogue scaling factors — confirm.
using ElementScalar = float;
|
||||
/// Accumulator element type, fixed to float. It is passed as the third
/// template argument of EpilogueThread.
using ElementAccum = float;
|
||||
/// Thread-level epilogue functor: cutlass LinearCombination instantiated with
/// ElementC, the literal 1, ElementAccum and ElementScalar, in that order.
/// NOTE(review): in upstream CUTLASS these parameters are documented as the
/// output element type, the element count per access, the accumulator type
/// and the compute type — confirm against the CUTLASS version in use.
using EpilogueThread = cutlass::epilogue::thread::LinearCombination<ElementC,
|
||||
1,
|
||||
ElementAccum,
|
||||
ElementScalar>;
|
||||
/// NOTE(review): the same alias is repeated here with identical tokens. This
/// looks like the before/after sides of a whitespace-only diff hunk rather
/// than an intentional redeclaration — confirm against the real file.
using EpilogueThread = cutlass::epilogue::thread::LinearCombination<ElementC,
|
||||
1,
|
||||
ElementAccum,
|
||||
ElementScalar>;
|
||||
|
||||
/// GMMA major mode for operand A, chosen at compile time from kTransA:
/// Major::MN when kTransA is false, Major::K when it is true.
static constexpr cute::GMMA::Major majorA = ! kTransA ? cute::GMMA::Major::MN : cute::GMMA::Major::K;
|
||||
/// GMMA major mode for operand B, chosen at compile time from kTransB:
/// Major::K when kTransB is false, Major::MN when it is true. The mapping is
/// mirrored relative to majorA.
static constexpr cute::GMMA::Major majorB = ! kTransB ? cute::GMMA::Major::K : cute::GMMA::Major::MN;
|
||||
|
||||
/// Kernel config
|
||||
/// Integer type used for strides: 64-bit signed.
typedef int64_t stride_type;
|
||||
/// NOTE(review): identical typedef repeated; likely the other side of a
/// whitespace-only diff hunk rather than a second declaration — confirm.
typedef int64_t stride_type;
|
||||
/// Integer type used for extents: 32-bit signed.
typedef int32_t extent_type;
|
||||
|
||||
/// Compile-time null pointer of type `const stride_type*`.
/// NOTE(review): presumably a placeholder for "no stride array supplied" —
/// confirm at the use sites.
static constexpr const stride_type* stride_null = nullptr;
|
||||
@ -117,7 +117,7 @@ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
/// Collective epilogue built from cutlass DefaultEpilogue, with StrideC given
/// for both stride parameters, EpilogueThread as the thread-level functor and
/// cutlass::gemm::EpilogueDefault as the last argument.
using EpilogueOutputOp = cutlass::epilogue::collective::DefaultEpilogue<StrideC, StrideC, EpilogueThread, cutlass::gemm::EpilogueDefault>;
|
||||
/// EpilogueOutputOp wrapped in Sm90TmaWarpSpecializedAdapter.
/// NOTE(review): presumably needed so the default epilogue can be paired with
/// an SM90 TMA warp-specialized mainloop — confirm against CollectiveOp.
using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<EpilogueOutputOp>;
|
||||
/// Kernel type: GemmUniversal composed from ProblemShape, CollectiveOp and
/// CollectiveEpilogue.
/// NOTE(review): `ProblemShape,` appears on two consecutive entries below.
/// This looks like the before/after sides of a whitespace-only diff hunk, not
/// a doubled template argument — confirm against the real file.
using Kernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape,
|
||||
ProblemShape,
|
||||
CollectiveOp,
|
||||
CollectiveEpilogue>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user