CUTLASS 3.5.1 (#1623)

* CUTLASS 3.5.1

* updates, optimizations, fixes
This commit is contained in:
Vijay Thakkar
2024-07-29 08:46:24 -04:00
committed by GitHub
parent 56b46e2d13
commit be60a0b272
312 changed files with 19793 additions and 6775 deletions

View File

@ -49,16 +49,16 @@ struct ContractionKernel {
using ElementScalar = float;
using ElementAccum = float;
using EpilogueThread = cutlass::epilogue::thread::LinearCombination<ElementC,
1,
ElementAccum,
ElementScalar>;
using EpilogueThread = cutlass::epilogue::thread::LinearCombination<ElementC,
1,
ElementAccum,
ElementScalar>;
static constexpr cute::GMMA::Major majorA = ! kTransA ? cute::GMMA::Major::MN : cute::GMMA::Major::K;
static constexpr cute::GMMA::Major majorB = ! kTransB ? cute::GMMA::Major::K : cute::GMMA::Major::MN;
/// Kernel config
typedef int64_t stride_type;
typedef int64_t stride_type;
typedef int32_t extent_type;
static constexpr const stride_type* stride_null = nullptr;
@ -117,7 +117,7 @@ using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
using EpilogueOutputOp = cutlass::epilogue::collective::DefaultEpilogue<StrideC, StrideC, EpilogueThread, cutlass::gemm::EpilogueDefault>;
using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<EpilogueOutputOp>;
using Kernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
ProblemShape,
CollectiveOp,
CollectiveEpilogue>;