Improve sm90 mixed dtype kernel (#1883)

This commit is contained in:
Sergey Klevtsov
2024-10-17 17:06:38 -07:00
committed by GitHub
parent 755194a7bd
commit 08101d9d0c
11 changed files with 994 additions and 80 deletions

View File

@ -340,7 +340,7 @@ auto
all_of(T const& t, F&& f)
{
if constexpr (is_tuple<T>::value) {
return detail::apply(t, [&] (auto const&... a) { return (true_type{} && ... && f(a)); }, tuple_seq<T>{});
return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (true_type{} && ... && a); }, tuple_seq<T>{});
} else {
return f(t);
}
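The reworked all_of first maps the predicate over the tuple with cute::transform and then folds the per-element results, rather than invoking f inside the fold expression itself. For reference, a minimal sketch of the usage pattern the alignment check later in this commit relies on (illustrative only; the concrete tuple and the divisor 8 are hypothetical):

// Sketch: all_of over a flattened stride tuple, mirroring check_alignment below.
auto strides = cute::make_tuple(cute::Int<1>{}, 128, 4096);
bool ok = cute::all_of(strides, [](auto s) {
  // static strides pass unconditionally; dynamic ones must be multiples of 8
  return cute::is_static<decltype(s)>::value || (s % 8 == 0);
});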

View File

@ -198,13 +198,22 @@ is_major(Stride = {}) {
return cute::is_constant<1, decltype(cute::front(cute::get<ModeIndex>(cute::remove_pointer_t<Stride>{})))>::value;
}
template<int ModeIndex, class Shape, class Stride>
constexpr bool
is_major(cute::Layout<Shape,Stride> = {}) {
return is_major<ModeIndex>(Stride{});
}
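This overload lets is_major be queried directly on a cute::Layout by forwarding to the stride-based check. A small sketch of what it reports (hypothetical layout types, using the CuTe aliases and the cutlass::detail namespace as in this file):

// An M-major (column-major) layout carries a static stride-1 in mode 0.
using LayoutColMajor = cute::Layout<cute::Shape<int,int>, cute::Stride<cute::Int<1>, int>>;
static_assert( cutlass::detail::is_major<0>(LayoutColMajor{}));
static_assert(!cutlass::detail::is_major<1>(LayoutColMajor{}));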
// Note : This method can be used for deducing the Layout Tag of A, C, D Matrices
template<class StrideA>
constexpr
auto
stride_to_layout_tag_A() {
using InternalStrideA = cute::remove_pointer_t<StrideA>;
if constexpr (is_major<0, StrideA>()) { // M major
if constexpr (cute::is_layout<InternalStrideA>::value) {
return stride_to_layout_tag_A<decltype(cute::stride(InternalStrideA{}))>();
}
else if constexpr (is_major<0, StrideA>()) { // M major
return layout::ColumnMajor{};
}
// Specialize for sparse layout
@ -224,7 +233,11 @@ template<class StrideB>
constexpr
auto
stride_to_layout_tag_B() {
if constexpr (is_major<0, StrideB>()) { // N major
using InternalStrideB = cute::remove_pointer_t<StrideB>;
if constexpr (cute::is_layout<InternalStrideB>::value) {
return stride_to_layout_tag_B<decltype(cute::stride(InternalStrideB{}))>();
}
else if constexpr (is_major<0, StrideB>()) { // N major
return layout::RowMajor{};
}
else { // K major
@ -238,7 +251,11 @@ template<class StrideC>
constexpr
auto
stride_to_layout_tag_C() {
if constexpr (is_major<0, StrideC>()) { // M major
using InternalStrideC = cute::remove_pointer_t<StrideC>;
if constexpr (cute::is_layout<InternalStrideC>::value) {
return stride_to_layout_tag_C<decltype(cute::stride(InternalStrideC{}))>();
}
else if constexpr (is_major<0, StrideC>()) { // M major
return layout::ColumnMajor{};
}
else { // N major
@ -349,28 +366,25 @@ get_output_alignment_bits() {
return 128;
}
// Return the shape that is associated with stride-1 mode, or 1 if not found
template<typename Shape, typename Stride>
CUTLASS_HOST_DEVICE constexpr
auto
get_contiguous_shape(Shape const & shape, Stride const & stride) {
using namespace cute;
auto idx = find_if(append(flatten(stride), _1{}), [](auto s){ return is_constant<1,decltype(s)>{}; });
return get<decltype(idx)::value>(append(flatten(shape), _1{}));
}
// Check if tensor shape satisfies a given major alignment
// Check if tensor layout satisfies a given major alignment
template<int Alignment, class Shape, class Stride>
CUTLASS_HOST_DEVICE constexpr
bool
check_alignment(Shape const & shape, Stride const & stride) {
return is_major<0>(stride)
? get_contiguous_shape(cute::get<0>(shape), cute::get<0>(stride)) % Alignment == 0
: get_contiguous_shape(cute::get<1>(shape), cute::get<1>(stride)) % Alignment == 0;
check_alignment(cute::Layout<Shape,Stride> const& layout) {
// Condition: the shape must be divisible by Alignment without rounding
bool shape_check = cute::size(layout.shape()) == Alignment * cute::size(cute::upcast<Alignment>(layout));
// Condition: every dynamic stride must be a multiple of Alignment
bool stride_check = cute::all_of(cute::flatten(layout.stride()), [](auto s){ return cute::is_static<decltype(s)>::value || (s % Alignment == 0); });
return shape_check && stride_check;
}
// Check if tensor shape satisfies a given major alignment
// Check if tensor layout satisfies a given major alignment
template<int Alignment, class Shape, class Stride>
CUTLASS_HOST_DEVICE constexpr
bool
check_alignment(Shape const& shape, Stride const& stride) {
return check_alignment<Alignment>(cute::make_layout(shape, stride));
}
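The layout-based check_alignment therefore enforces two things: the layout's size must survive an upcast by Alignment (i.e. the contiguous extent divides evenly), and every dynamic stride must be a multiple of Alignment. A minimal usage sketch (hypothetical extents; Alignment 8 corresponds to 128-bit TMA alignment for 16-bit elements):

// Sketch: row-major (K-major) A of extent (M,K,L) with leading dimension lda.
int M = 1024, K = 512, L = 2, lda = 512;
auto layout_A = cute::make_layout(cute::make_shape (M, K, L),
                                  cute::make_stride(lda, cute::Int<1>{}, M * lda));
bool ok = cutlass::detail::check_alignment<8>(layout_A);  // holds for these extents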
template<int B, int M, int S>
CUTLASS_HOST_DEVICE constexpr

View File

@ -327,13 +327,23 @@ public:
if constexpr (is_destination_supported) {
constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits<ElementD>();
constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits<ElementD>::value;
implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_D>(shape, StrideD{});
if constexpr (cute::is_same_v<CopyOpS2G, SM90_TMA_STORE_IM2COL>) { // ignore L stride for implicit gemm
implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_D>(take<0,2>(shape), take<0,2>(StrideD{}));
}
else {
implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_D>(shape, StrideD{});
}
}
if constexpr (not cute::is_void_v<ElementC>) {
constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits<ElementC>();
constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits<ElementC>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_C>(shape, StrideC{});
if constexpr (cute::is_same_v<CopyOpG2S, SM90_TMA_LOAD_IM2COL>) { // ignore L stride for implicit gemm
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_C>(take<0,2>(shape), take<0,2>(StrideC{}));
}
else {
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_C>(shape, StrideC{});
}
}
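Here cute::take<0,2> keeps only modes [0,2) of the shape and stride tuples, so for the IM2COL copies the alignment check deliberately ignores the trailing L mode, as the inline comments note. A tiny illustration (hypothetical extents):

// Sketch: dropping the L mode before the alignment check.
int M = 224, N = 64, L = 8;
auto shape_MNL = cute::make_shape(M, N, L);
auto shape_MN  = cute::take<0,2>(shape_MNL);   // (M, N)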
if (!implementable) {

View File

@ -409,8 +409,18 @@ public:
static constexpr bool IsANarrow = sizeof_bits<ElementA>::value < sizeof_bits<ElementB>::value;
using GmemLayoutATag = GmemLayoutATag_;
using GmemLayoutBTag = GmemLayoutBTag_;
template<class T>
static auto get_stride(T const& t) {
if constexpr (not cute::is_layout<T>::value) {
return t;
}
else {
return cute::stride(t);
}
}
using GmemLayoutATag = decltype(get_stride(GmemLayoutATag_{}));
using GmemLayoutBTag = decltype(get_stride(GmemLayoutBTag_{}));
using ElementPairA = cute::conditional_t<IsANarrow && NeitherIsTuple, cute::tuple<ElementA>, ElementPairA_>;
using ElementPairB = cute::conditional_t<!IsANarrow && NeitherIsTuple, cute::tuple<ElementB>, ElementPairB_>;
@ -464,8 +474,8 @@ public:
using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
// We pack the scale data with the operand that will be optionally scaled and converted before MMA.
using StrideA = TagToStrideA_t<GmemLayoutATag>;
using StrideB = TagToStrideB_t<GmemLayoutBTag>;
using StrideA = cute::conditional_t<cute::is_layout<GmemLayoutATag_>::value, GmemLayoutATag_, TagToStrideA_t<GmemLayoutATag>>;
using StrideB = cute::conditional_t<cute::is_layout<GmemLayoutBTag_>::value, GmemLayoutBTag_, TagToStrideB_t<GmemLayoutBTag>>;
using CollectiveOp = CollectiveMma<
DispatchPolicy,

View File

@ -182,6 +182,7 @@ public:
using InternalSmemLayoutAtomB = cute::conditional_t<!SwapAB, SmemLayoutAtomB, SmemLayoutAtomA>;
using InternalSmemCopyAtomA = cute::conditional_t<!SwapAB, SmemCopyAtomA, SmemCopyAtomB>;
using InternalSmemCopyAtomB = cute::conditional_t<!SwapAB, SmemCopyAtomB, SmemCopyAtomA>;
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
// For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
@ -228,14 +229,25 @@ public:
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must evenly divide tile k shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(tile_to_shape(
InternalSmemLayoutAtomA{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
using SmemLayoutB = decltype(tile_to_shape(
InternalSmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
template<class LayoutAtom, class TileShape, class Stride>
static constexpr
CUTLASS_HOST_DEVICE
auto get_smem_layout(LayoutAtom layout_atom, TileShape const& tile_shape, Stride const& stride) {
if constexpr (not cute::is_layout<Stride>::value) {
return tile_to_shape(
layout_atom,
append(tile_shape, Int<DispatchPolicy::Stages>{}),
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,Stride>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{});
}
else {
auto gmem_tile = composition(stride, tile_shape);
return make_layout_like(append(gmem_tile, make_layout(Int<DispatchPolicy::Stages>{}, 0)));
}
}
using SmemLayoutA = decltype(get_smem_layout(InternalSmemLayoutAtomA{}, select<0,2>(TileShape{}), InternalStrideA{}));
using SmemLayoutB = decltype(get_smem_layout(InternalSmemLayoutAtomB{}, select<1,2>(TileShape{}), InternalStrideB{}));
// It is assumed that the scales and zero-points share the same smem layout
using SmemLayoutScale = decltype(tile_to_shape(
@ -381,6 +393,18 @@ public:
uint32_t mma_promotion_interval = 4;
};
template<class Shape, class Stride>
static constexpr
CUTLASS_HOST_DEVICE
auto get_gmem_layout(Shape const& shape, Stride const& stride) {
if constexpr (not cute::is_layout<Stride>::value) {
return make_layout(shape, stride);
}
else {
return stride;
}
}
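get_gmem_layout is the pivot for the new stride-or-layout duality: a conventional stride is combined with the runtime problem shape via make_layout, while a user-supplied cute::Layout is returned unchanged and its own shape takes precedence. A sketch of both paths (hypothetical extents; the call is shown unqualified although it is a static member of the collective):

// Stride path: layout built from shape + K-major stride.
int M = 1024, K = 512, L = 2;
auto lA = get_gmem_layout(cute::make_shape(M, K, L),
                          cute::make_stride(int64_t(K), cute::Int<1>{}, int64_t(M) * K));
// Layout path: a pre-built (possibly non-canonical) layout is passed through as-is.
auto custom = cute::make_layout(cute::make_shape(M, K, L),
                                cute::make_stride(cute::Int<1>{}, int64_t(M), int64_t(M) * K));
auto lA2 = get_gmem_layout(cute::make_shape(M, K, L), custom);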
// Device side kernel params
struct Params {
private:
@ -394,10 +418,14 @@ public:
TransformB_>;
public:
// Assumption: StrideA is congruent with Problem_MK
using LayoutA = decltype(get_gmem_layout(repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}));
using LayoutB = decltype(get_gmem_layout(repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}));
using TMA_A = decltype(make_tma_copy_A_sm90<TmaElementA>(
GmemTiledCopyA{},
make_tensor(Outer::get_logical_ptr(static_cast<InternalElementA const*>(nullptr)), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}),
make_tensor(Outer::get_logical_ptr(static_cast<InternalElementA const*>(nullptr)), LayoutA{}),
SmemLayoutA{}(_,_,cute::Int<0>{}),
TileShape{},
ClusterShape{})); // mcast along N mode for this M load, if any
@ -419,7 +447,7 @@ public:
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy_B_sm90(
GmemTiledCopyB{},
make_tensor(Outer::get_logical_ptr(static_cast<InternalElementB const*>(nullptr)), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}),
make_tensor(Outer::get_logical_ptr(static_cast<InternalElementB const*>(nullptr)), LayoutB{}),
SmemLayoutB{}(_,_,cute::Int<0>{}),
TileShape{},
ClusterShape{})); // mcast along M mode for this N load, if any
@ -431,6 +459,8 @@ public:
int group_size;
uint32_t tma_transaction_bytes = TmaTransactionBytes;
int reload_factor = (group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{});
InternalStrideA dA;
InternalStrideB dB;
};
//
@ -469,8 +499,8 @@ public:
dB = args.dA;
}
Tensor tensor_a = make_tensor(get_logical_ptr(ptr_A), make_layout(make_shape(M,K,L), dA));
Tensor tensor_b = make_tensor(get_logical_ptr(ptr_B), make_layout(make_shape(N,K,L), dB));
Tensor tensor_a = make_tensor(get_logical_ptr(ptr_A), get_gmem_layout(make_shape(M,K,L), dA));
Tensor tensor_b = make_tensor(get_logical_ptr(ptr_B), get_gmem_layout(make_shape(N,K,L), dB));
typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90<TmaElementA>(
GmemTiledCopyA{},
tensor_a,
@ -490,7 +520,7 @@ public:
uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, tma_transaction_bytes, 1 };
return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, tma_transaction_bytes, 1, dA, dB };
}
else if constexpr (ModeHasScales) {
auto scale_k = (K + args.group_size - 1) / args.group_size;
@ -505,7 +535,7 @@ public:
_1{}); // mcast along N mode for this M load, if any
if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale) {
return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}) };
return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}), dA, dB };
}
else if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor tensor_zero = make_tensor(get_logical_ptr(args.ptr_Z), make_layout(make_shape(M,scale_k,L), dS));
@ -515,7 +545,7 @@ public:
SmemLayoutScale{}(_,_,cute::Int<0>{}),
ScaleTileShape{},
_1{}); // mcast along N mode for this M load, if any
return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}) };
return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}), dA, dB };
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in to_underlying_arguments.");
}
@ -533,33 +563,37 @@ public:
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M,N,K,L] = problem_shape_MNKL;
bool implementable = true;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
bool check_aligned_A = cutlass::detail::check_alignment<min_tma_aligned_elements_A>(get_gmem_layout(cute::make_shape(M,K,L), args.dA));
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
bool check_aligned_B = cutlass::detail::check_alignment<min_tma_aligned_elements_B>(get_gmem_layout(cute::make_shape(N,K,L), args.dB));
bool check_aligned_S = true;
bool check_aligned_Z = true;
bool check_mode_args = true;
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
implementable = implementable && (args.ptr_S == nullptr);
implementable = implementable && (args.ptr_Z == nullptr);
check_mode_args = check_mode_args && (args.ptr_S == nullptr);
check_mode_args = check_mode_args && (args.ptr_Z == nullptr);
}
else if constexpr (ModeHasScales) {
const int scale_mn = SwapAB ? N : M;
const int scale_k = (K + args.group_size - 1) / args.group_size;
constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits<ElementScale>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0));
implementable = implementable && args.group_size != 0;
implementable = implementable && (args.ptr_S != nullptr);
check_aligned_S = cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), args.dS);
check_mode_args = check_mode_args && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0));
check_mode_args = check_mode_args && args.group_size != 0;
check_mode_args = check_mode_args && (args.ptr_S != nullptr);
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
implementable = implementable && (args.ptr_Z == nullptr);
check_mode_args = check_mode_args && (args.ptr_Z == nullptr);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits<ElementZero>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_zero>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
implementable = implementable && (args.ptr_Z != nullptr);
check_aligned_Z = cutlass::detail::check_alignment<min_tma_aligned_elements_zero>(cute::make_shape(scale_mn,scale_k,L), args.dS);
check_mode_args = check_mode_args && (args.ptr_Z != nullptr);
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
@ -569,10 +603,23 @@ public:
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
}
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
if (!check_mode_args) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Invalid arguments for the selected conversion mode.\n");
}
return implementable;
if (!check_aligned_A) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor A doesn't meet the minimum alignment requirements for TMA.\n");
}
if (!check_aligned_B) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor B doesn't meet the minimum alignment requirements for TMA.\n");
}
if (!check_aligned_S) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor S (scale) doesn't meet the minimum alignment requirements for TMA.\n");
}
if (!check_aligned_Z) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor Z (zeros) doesn't meet the minimum alignment requirements for TMA.\n");
}
return check_mode_args && check_aligned_A && check_aligned_B && check_aligned_S && check_aligned_Z;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
@ -618,8 +665,8 @@ public:
// TMA requires special handling of strides to deal with coord codomain mapping
// Represent the full tensors -- get these from TMA
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(get_gmem_layout(make_shape(M,K,L), mainloop_params.dA))); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(shape(get_gmem_layout(make_shape(N,K,L), mainloop_params.dB))); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
@ -680,8 +727,6 @@ public:
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in TMA load.");
}
int lane_predicate = cute::elect_one_sync();
Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE)
@ -748,8 +793,10 @@ public:
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
if (cute::elect_one_sync()) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
if (cute::elect_one_sync()) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
if (cute::elect_one_sync()) {
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
}
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// Nothing extra to do.
@ -920,6 +967,12 @@ public:
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
warpgroup_arrive();
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
warpgroup_commit_batch();
if (k_block < K_BLOCK_MAX - 2) { // prefetch next block
copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view,
partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage);
@ -927,11 +980,6 @@ public:
if (k_block < K_BLOCK_MAX - 1) {
transform_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1);
}
warpgroup_arrive();
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
warpgroup_commit_batch();
}
--k_tile_count;
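The net effect of this hunk is a reordering of the unrolled K loop: the WGMMA for the current k_block is issued first (arrive, gemm, commit), and only afterwards are the shared-to-register copy for k_block + 2 and the dtype conversion for k_block + 1 issued, so the conversion work overlaps with the asynchronous MMA instead of preceding it. In outline (a condensed restatement of the loop above, details elided):

CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
  warpgroup_arrive();
  cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); // async WGMMA
  tiled_mma.accumulate_ = GMMA::ScaleOut::One;
  warpgroup_commit_batch();                     // MMA for k_block now in flight
  if (k_block < K_BLOCK_MAX - 2) { /* prefetch operand fragments for k_block + 2 */ }
  if (k_block < K_BLOCK_MAX - 1) { /* convert (and scale) A fragments for k_block + 1 */ }
}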

View File

@ -70,7 +70,7 @@ using cutlass::detail::StrideToLayoutTagC_t;
template<int ModeIndex, class Stride>
constexpr bool
is_major(Stride = {}) {
return ::cutlass::detail::is_major<ModeIndex, Stride>();
return ::cutlass::detail::is_major<ModeIndex>(Stride{});
}
template<class Stride>