cherry-pick feature/hopper-blockwise-generalization-optimization (#2270)

commit 2b78c2fe31 (parent 697126019e)
Author: Lain
Date: 2025-04-29 13:47:22 -07:00
Committed by: GitHub
9 changed files with 503 additions and 216 deletions

View File

@@ -322,7 +322,11 @@ struct DescriptorIterator
CUTE_HOST_DEVICE constexpr
DescriptorIterator operator+(Index const& offset) const
{
return { GmmaDescriptor{desc_ + uint64_t(offset)} };
// Use a 32-bit calculation rather than a 64-bit one, since only the low word of the descriptor is updated
GmmaDescriptor ret;
ret.reg32_[0] = desc_.reg32_[0] + uint32_t(offset);
ret.reg32_[1] = desc_.reg32_[1];
return { ret };
}
};
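A host-side sketch of why the narrower add is safe (the union layout mirrors `cute::GmmaDescriptor`, which aliases the 64-bit descriptor as two 32-bit registers; the constant values are illustrative):

```cpp
#include <cassert>
#include <cstdint>

// Minimal stand-in for cute::GmmaDescriptor's storage: a 64-bit descriptor
// aliased as two 32-bit registers.
union Desc {
  uint64_t desc_;
  uint32_t reg32_[2];
};

int main() {
  Desc d;
  d.desc_ = 0x4000000000004000ull;  // high word: layout bits; low word: address bits

  uint32_t offset = 0x200;  // descriptor offset; always far below 2^31

  // Original: full 64-bit add across the register pair.
  Desc a;
  a.desc_ = d.desc_ + uint64_t(offset);

  // Optimized: add into the low register only and copy the high register.
  Desc b;
  b.reg32_[0] = d.reg32_[0] + offset;
  b.reg32_[1] = d.reg32_[1];

  // Identical as long as the low-word add never carries out of bit 31.
  assert(a.desc_ == b.desc_);
  return 0;
}
```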

View File

@@ -1065,6 +1065,7 @@ struct CollectiveBuilder<
cute::enable_if_t<
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum> or
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum> or
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum> or
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum>) and
not detail::is_use_rmem_A<ElementA, GmemLayoutPairA, ElementB, GmemLayoutPairB>()
>
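The builder hunk widens an `enable_if_t` dispatch to admit the Pingpong block-scaled schedules; a generic sketch of the selection pattern (toy tag names, not the CUTLASS types):

```cpp
#include <type_traits>

// Toy schedule tags standing in for the CUTLASS kernel schedules.
struct CooperativeBlockScaled {};
struct PingpongBlockScaled {};
struct PlainCooperative {};

// Primary template: no builder available.
template <class Schedule, class Enable = void>
struct Builder;  // intentionally undefined

// Specialization enabled only for the block-scaled schedules,
// mirroring the enable_if_t / is_same_v chain in the hunk above.
template <class Schedule>
struct Builder<Schedule,
               std::enable_if_t<std::is_same_v<Schedule, CooperativeBlockScaled> ||
                                std::is_same_v<Schedule, PingpongBlockScaled>>> {
  static constexpr bool block_scaled = true;
};

static_assert(Builder<CooperativeBlockScaled>::block_scaled);
static_assert(Builder<PingpongBlockScaled>::block_scaled);  // newly admitted schedule
// Builder<PlainCooperative> would fail to instantiate: no matching specialization.
```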

View File

@@ -77,7 +77,6 @@ private:
// `multiply` the partial accumulators by the scale and `add` them to the main accumulator (FFMA).
CUTLASS_DEVICE
void scale_core(ElementAccumulator const &scale) {
warpgroup_wait<0>();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accum_); ++i) {
accum_(i) += accum_temp_(i) * scale;
@@ -96,7 +95,6 @@ private:
static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape.");
warpgroup_wait<0>();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accum_); ++i) {
accum_(i) += accum_temp_(i) * scale(i);
@@ -121,7 +119,6 @@ private:
static_assert(LayoutAccum{}.shape() == LayoutScaleA{}.shape(), "Accumulator and scaleA must have same shape.");
static_assert(LayoutAccum{}.shape() == LayoutScaleB{}.shape(), "Accumulator and scaleB must have same shape.");
warpgroup_wait<0>();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accum_); ++i) {
accum_(i) += accum_temp_(i) * scaleA(i) * scaleB(i);
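These helpers fold the temporary FP8 accumulators into the persistent accumulator with one FMA per element; the `warpgroup_wait<0>()` calls removed here reappear in the mainloop hunks below, which now wait before scaling. A scalar sketch of the promotion (hypothetical free-function form):

```cpp
#include <array>
#include <cstddef>

// Scalar model of scale_core: fold the temporary accumulator, scaled by the
// per-block factor, into the persistent accumulator (one FFMA per element).
template <std::size_t N>
void scale_core(std::array<float, N>& accum,
                std::array<float, N> const& accum_temp,
                float scale) {
  for (std::size_t i = 0; i < N; ++i) {
    accum[i] += accum_temp[i] * scale;  // FFMA
  }
}
```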

View File

@@ -105,6 +105,7 @@ struct CollectiveMma<
using ElementBlockScale = ElementAccumulator;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using GmemTiledCopyScaleTMA = cute::SM90_TMA_LOAD;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
@@ -118,9 +119,6 @@ struct CollectiveMma<
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
// 33 threads per CTA are producers (1 for operand tile `tma`, and 32 for scales `cp.async`)
static constexpr int NumProducerThreadEvents = 33;
static constexpr int ScaleGranularityM = size<0,0>(LayoutSFA{});
static constexpr int ScaleGranularityN = size<0,0>(LayoutSFB{});
static constexpr int ScaleGranularityK = size<1,0>(LayoutSFA{});
@@ -133,6 +131,12 @@ struct CollectiveMma<
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
static constexpr int ScaleTmaThreshold = 32;
static constexpr bool IsTmaLoadSFA = ScaleMsPerTile >= ScaleTmaThreshold && ScaleNsPerTile < ScaleTmaThreshold;
static constexpr bool IsTmaLoadSFB = ScaleNsPerTile >= ScaleTmaThreshold && ScaleMsPerTile < ScaleTmaThreshold;
// 1 producer thread per CTA when both scale tensors are TMA-loaded; otherwise 33 (1 for operand tile `tma`, and 32 for scales `cp.async`)
static constexpr int NumProducerThreadEvents = ((IsTmaLoadSFA && IsTmaLoadSFB)? 1 : 33);
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
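A constexpr sketch of the new routing (the granularity pairs are illustrative): with `ScaleTmaThreshold = 32`, a groupwise side (many scales per tile) is promoted to TMA only while the opposite side stays blockwise, and a fully blockwise problem keeps both scales on `cp.async`:

```cpp
constexpr int ScaleTmaThreshold = 32;

// Mirrors the IsTmaLoadSFA / IsTmaLoadSFB predicates above.
constexpr bool is_tma_sfa(int ms, int ns) { return ms >= ScaleTmaThreshold && ns < ScaleTmaThreshold; }
constexpr bool is_tma_sfb(int ms, int ns) { return ns >= ScaleTmaThreshold && ms < ScaleTmaThreshold; }

// Groupwise A (e.g. 128 scales per tile) with blockwise B: SFA via TMA.
static_assert(is_tma_sfa(128, 1) && !is_tma_sfb(128, 1));
// Blockwise A with groupwise B: SFB via TMA.
static_assert(!is_tma_sfa(1, 128) && is_tma_sfb(1, 128));
// Fully blockwise (1 x 1): both scales remain on cp.async, so 33 producer threads.
static_assert(!is_tma_sfa(1, 1) && !is_tma_sfb(1, 1));
```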
@@ -191,10 +195,10 @@ struct CollectiveMma<
struct SharedStorage
{
struct TensorStorage : cute::aligned_struct<128> {
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A; // mxk
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B; // nxk
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutSFA>> smem_SFA; // ScaleMsPerTile x k
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutSFB>> smem_SFB; // ScaleNsPerTile x k
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A; // TILE_M x PIPE_K
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B; // TILE_N x PIPE_K
CUTE_ALIGNAS(128) cute::array<ElementBlockScale, cute::cosize_v<SmemLayoutSFA>> smem_SFA; // ScaleMsPerTile x PIPE_K
CUTE_ALIGNAS(128) cute::array<ElementBlockScale, cute::cosize_v<SmemLayoutSFB>> smem_SFB; // ScaleNsPerTile x PIPE_K
} tensors;
using PipelineStorage = typename MainloopPipeline::SharedStorage;
@@ -218,29 +222,60 @@ struct CollectiveMma<
// Device side kernel params
struct Params {
static auto getTmaSFA() {
if constexpr (IsTmaLoadSFA) {
return make_tma_copy(
GmemTiledCopyScaleTMA{},
make_tensor(static_cast<ElementBlockScale const*>(nullptr), filter_zeros(LayoutSFA{})),
filter_zeros(SmemLayoutSFA{}(_,_,_0{})),
Shape<Int<ScaleMsPerTile>, Int<1>>{},
_1{});
}
else {
return nullptr;
}
}
static auto getTmaSFB() {
if constexpr (IsTmaLoadSFB) {
return make_tma_copy(
GmemTiledCopyScaleTMA{},
make_tensor(static_cast<ElementBlockScale const*>(nullptr), filter_zeros(LayoutSFB{})),
filter_zeros(SmemLayoutSFB{}(_,_,_0{})),
Shape<Int<ScaleNsPerTile>, Int<1>>{},
_1{});
}
else {
return nullptr;
}
}
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy_A_sm90(
GmemTiledCopyA{},
make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
SmemLayoutA{}(_,_,0),
SmemLayoutA{}(_,_,_0{}),
TileShape{},
ClusterShape{}));
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy_B_sm90(
GmemTiledCopyB{},
make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
SmemLayoutB{}(_,_,0),
SmemLayoutB{}(_,_,_0{}),
TileShape{},
ClusterShape{}));
// NOTE: Does make_tma_copy support a stride of 0?
using TMA_SFA = decltype(getTmaSFA());
using TMA_SFB = decltype(getTmaSFB());
TMA_A tma_load_a;
TMA_B tma_load_b;
TMA_SFA tma_load_sfa;
TMA_SFB tma_load_sfb;
uint32_t tma_transaction_bytes = TmaTransactionBytes;
uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
// Block scaling factors for A and B
ElementBlockScale const* ptr_SFA;
LayoutSFA layout_SFA;
ElementBlockScale const* ptr_SFA;
ElementBlockScale const* ptr_SFB;
LayoutSFA layout_SFA;
LayoutSFB layout_SFB;
};
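`getTmaSFA`/`getTmaSFB` lean on `auto` return-type deduction with `if constexpr`: only the taken branch is instantiated, so `TMA_SFA`/`TMA_SFB` collapse to a real TMA copy type or `std::nullptr_t`. A self-contained sketch of the idiom:

```cpp
#include <cstddef>
#include <type_traits>

struct TmaCopy { /* stand-in for the object make_tma_copy returns */ };

template <bool UseTma>
struct Params {
  // Only the taken branch participates in return-type deduction, so the
  // member type differs per specialization: TmaCopy when enabled,
  // std::nullptr_t otherwise.
  static auto get_tma() {
    if constexpr (UseTma) { return TmaCopy{}; }
    else                  { return nullptr; }
  }
  using TMA = decltype(get_tma());
  TMA tma_load;
};

static_assert(std::is_same_v<Params<true>::TMA, TmaCopy>);
static_assert(std::is_same_v<Params<false>::TMA, std::nullptr_t>);
```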
@@ -259,7 +294,11 @@ struct CollectiveMma<
auto ptr_A = reinterpret_cast<ElementA const*>(args.ptr_A);
auto ptr_B = reinterpret_cast<ElementB const*>(args.ptr_B);
auto ptr_SFA = reinterpret_cast<ElementBlockScale const*>(args.ptr_SFA);
auto ptr_SFB = reinterpret_cast<ElementBlockScale const*>(args.ptr_SFB);
Tensor tensor_sfa = make_tensor(ptr_SFA, filter_zeros(args.layout_SFA));
Tensor tensor_sfb = make_tensor(ptr_SFB, filter_zeros(args.layout_SFB));
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90(
@@ -274,20 +313,42 @@ struct CollectiveMma<
SmemLayoutB{}(_,_,cute::Int<0>{}),
TileShape{},
ClusterShape{});
typename Params::TMA_SFA tma_load_sfa{};
if constexpr (IsTmaLoadSFA) {
tma_load_sfa = make_tma_copy(
GmemTiledCopyScaleTMA{},
tensor_sfa,
filter_zeros(SmemLayoutSFA{})(_,_,cute::Int<0>{}),
Shape<Int<ScaleMsPerTile>, Int<1>>{},
_1{});
}
typename Params::TMA_SFB tma_load_sfb{};
if constexpr (IsTmaLoadSFB) {
tma_load_sfb = make_tma_copy(
GmemTiledCopyScaleTMA{},
tensor_sfb,
filter_zeros(SmemLayoutSFB{})(_,_,cute::Int<0>{}),
Shape<Int<ScaleNsPerTile>, Int<1>>{},
_1{});
}
uint32_t transaction_bytes_mk = TmaTransactionBytesMK;
uint32_t transaction_bytes_nk = TmaTransactionBytesNK;
uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk;
uint32_t transaction_bytes_sfa = TmaTransactionBytesSFA;
uint32_t transaction_bytes_sfb = TmaTransactionBytesSFB;
uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk + transaction_bytes_sfa + transaction_bytes_sfb;
return {
tma_load_a,
tma_load_b,
tma_load_sfa,
tma_load_sfb,
transaction_bytes,
transaction_bytes_mk,
transaction_bytes_nk,
args.ptr_SFA,
args.layout_SFA,
args.ptr_SFB,
args.layout_SFB
args.layout_SFA,
args.layout_SFB,
};
}
@@ -302,20 +363,39 @@ struct CollectiveMma<
bool implementable = true;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
if (!cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{})) {
implementable = false;
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum alignment requirements for using TMA to load tensor A.\n");
}
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
if (!cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{})) {
implementable = false;
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum alignment requirements for using TMA to load tensor B.\n");
}
constexpr int min_tma_aligned_elements_S = tma_alignment_bits / cutlass::sizeof_bits<ElementBlockScale>::value;
if (IsTmaLoadSFA && !cutlass::detail::check_alignment<min_tma_aligned_elements_S>(args.layout_SFA)) {
implementable = false;
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum alignment requirements for using TMA to load scale A.\n");
}
if (IsTmaLoadSFB && !cutlass::detail::check_alignment<min_tma_aligned_elements_S>(args.layout_SFB)) {
implementable = false;
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum alignment requirements for using TMA to load scale B.\n");
}
/* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA instructions. */
constexpr int pipe_k = size<2>(TileShape{}) / tile_size<2>(TiledMma{});
implementable = implementable && (args.mma_promotion_interval % 4 == 0) && (args.mma_promotion_interval == ScalePromotionInterval);
implementable = implementable && (pipe_k % 4 == 0) && (pipe_k <= args.mma_promotion_interval);
if (args.mma_promotion_interval % 4 != 0 ||
args.mma_promotion_interval != ScalePromotionInterval ||
args.mma_promotion_interval % pipe_k != 0 ||
pipe_k > args.mma_promotion_interval) {
implementable = false;
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Argument mma_promotion_interval is invalid.\n");
}
// We expect full tiles in K
implementable = implementable && (K % size<2>(TileShape{}) == 0);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
if (K % size<2>(TileShape{}) != 0) {
implementable = false;
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size K is incompatible with tile size.\n");
}
return implementable;
}
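The new scale checks reuse TMA's 128-bit alignment rule; the minimum contiguous-element counts work out as follows (a sketch, assuming 128-bit `tma_alignment_bits`, FP8 operands, and float scales):

```cpp
constexpr int tma_alignment_bits = 128;  // TMA requires 16-byte-aligned accesses

// Minimum contiguous elements per element width, as in can_implement:
constexpr int min_elems_fp8  = tma_alignment_bits / 8;   // e4m3/e5m2 operands: 16 elements
constexpr int min_elems_fp32 = tma_alignment_bits / 32;  // float block scales:  4 elements

static_assert(min_elems_fp8 == 16 && min_elems_fp32 == 4);
```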
@@ -326,7 +406,12 @@ struct CollectiveMma<
cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value));
static constexpr uint32_t TmaTransactionBytesNK =
cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
static constexpr uint32_t TmaTransactionBytesSFA =
(IsTmaLoadSFA? cutlass::bits_to_bytes(ScaleMsPerTile * static_cast<uint32_t>(sizeof_bits<ElementBlockScale>::value)): 0);
static constexpr uint32_t TmaTransactionBytesSFB =
(IsTmaLoadSFB? cutlass::bits_to_bytes(ScaleNsPerTile * static_cast<uint32_t>(sizeof_bits<ElementBlockScale>::value)): 0);
static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK + TmaTransactionBytesSFA + TmaTransactionBytesSFB;
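Illustrative per-stage arithmetic for the enlarged budget, assuming a 128x128x128 FP8 tile with TMA-loaded per-row scales on A and cp.async scales on B; only TMA-loaded bytes count toward the barrier's expected transaction bytes, which is what the `IsTmaLoadSFA ? ... : 0` selection encodes:

```cpp
#include <cstdint>

// Illustrative per-stage byte accounting (values assumed; the real constants
// come from the smem layouts).
constexpr uint32_t bytes_mk  = 128u * 128u * 1u;  // ElementA = FP8, 1 byte each
constexpr uint32_t bytes_nk  = 128u * 128u * 1u;  // ElementB = FP8, 1 byte each
constexpr uint32_t bytes_sfa = 128u * 4u;         // ScaleMsPerTile float scales via TMA
constexpr uint32_t bytes_sfb = 0u;                // SFB on cp.async: not in the barrier budget

static_assert(bytes_mk + bytes_nk + bytes_sfa + bytes_sfb == 33280u);
```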
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
@@ -334,6 +419,12 @@ struct CollectiveMma<
{
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
if constexpr (IsTmaLoadSFA) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_sfa.get_tma_descriptor());
}
if constexpr (IsTmaLoadSFB) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_sfb.get_tma_descriptor());
}
}
/// Set up the data needed by this collective for load and mma.
@@ -357,10 +448,24 @@ struct CollectiveMma<
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
// Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
// gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
Tensor mSFA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_SFA), mainloop_params.layout_SFA); // (scale_m,k,l)
Tensor mSFB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_SFB), mainloop_params.layout_SFB); // (scale_n,k,l)
// Note that mSFA_mkl and mSFB_nkl are already block-tiled in the `m` mode on the host,
// so gScaleA_mkl and gScaleB_nkl in global memory `g` are the same as mSFA_mkl and mSFB_nkl.
auto mSFA_mkl = [&]() {
if constexpr (IsTmaLoadSFA) {
return mainloop_params.tma_load_sfa.get_tma_tensor(shape(filter_zeros(mainloop_params.layout_SFA)));
}
else {
return make_tensor(make_gmem_ptr(mainloop_params.ptr_SFA), mainloop_params.layout_SFA); // (scale_m,k,l)
}
}();
auto mSFB_nkl = [&]() {
if constexpr (IsTmaLoadSFB) {
return mainloop_params.tma_load_sfb.get_tma_tensor(shape(filter_zeros(mainloop_params.layout_SFB)));
}
else {
return make_tensor(make_gmem_ptr(mainloop_params.ptr_SFB), mainloop_params.layout_SFB); // (scale_n,k,l)
}
}();
return cute::make_tuple(gA_mkl, gB_nkl, mSFA_mkl, mSFB_nkl);
}
@@ -387,8 +492,8 @@ struct CollectiveMma<
// Blockscaling: TMA loads in load() (operands, plus scales when TMA-eligible) and CpAsync in load_auxiliary()
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), SmemLayoutSFA{}); // (BLK_M,BLK_K,PIPE)
Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), SmemLayoutSFB{}); // (BLK_M,BLK_K,PIPE)
Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), filter_zeros(SmemLayoutSFA{})); // (ScaleMsPerTile,PIPE)
Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), filter_zeros(SmemLayoutSFB{})); // (ScaleNsPerTile,PIPE)
//
// Prepare the TMA loads for A and B
@@ -399,6 +504,8 @@ struct CollectiveMma<
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor mSFA_mkl = get<2>(load_inputs);
Tensor mSFB_nkl = get<3>(load_inputs);
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
@@ -407,9 +514,120 @@ struct CollectiveMma<
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
Tensor gSFA = local_tile(
mSFA_mkl, make_tile(Int<ScaleMsPerTile>{}, Int<1>{}),
make_coord(m_coord,_,l_coord));
Tensor gSFB = local_tile(
mSFB_nkl, make_tile(Int<ScaleNsPerTile>{}, Int<1>{}),
make_coord(n_coord,_,l_coord));
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
auto [tAgA_SFA, tAsA_SFA] = [&]() {
if constexpr (IsTmaLoadSFA) {
auto block_tma_sfa = mainloop_params.tma_load_sfa.get_slice(cluster_local_block_id.y);
Tensor tAgA_SFA_ = block_tma_sfa.partition_S(gSFA);
Tensor tAsA_SFA_ = block_tma_sfa.partition_D(sSFA);
return cute::make_tuple(tAgA_SFA_, tAsA_SFA_);
}
else {
return cute::make_tuple(0, 0);
}
}();
auto [tBgB_SFB, tBsB_SFB] = [&]() {
if constexpr (IsTmaLoadSFB) {
auto block_tma_sfb = mainloop_params.tma_load_sfb.get_slice(cluster_local_block_id.y);
Tensor tBgB_SFB_ = block_tma_sfb.partition_S(gSFB);
Tensor tBsB_SFB_ = block_tma_sfb.partition_D(sSFB);
return cute::make_tuple(tBgB_SFB_, tBsB_SFB_);
}
else {
return cute::make_tuple(0, 0);
}
}();
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
uint16_t mcast_mask_sf = 0;
// Issue TmaLoads for GEMM operands A/B (and, when TMA-eligible, the scale tensors below)
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
}
}
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for ( ; k_tile_count > 0; --k_tile_count) {
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
int write_stage = smem_pipe_write.index();
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
// Copy operands A and B from global memory to shared memory
if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
// Copy scale tensors from global memory to shared memory
if constexpr (IsTmaLoadSFA) {
if (lane_predicate) {
copy(mainloop_params.tma_load_sfa.with(*tma_barrier, mcast_mask_sf), tAgA_SFA(_,_,_,*k_tile_iter), tAsA_SFA(_,_,_,write_stage));
}
}
if constexpr (IsTmaLoadSFB) {
if (lane_predicate) {
copy(mainloop_params.tma_load_sfb.with(*tma_barrier, mcast_mask_sf), tBgB_SFB(_,_,_,*k_tile_iter), tBsB_SFB(_,_,_,write_stage));
}
}
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
}
template <
class TensorA, class TensorB,
class TensorScaleA, class TensorScaleB,
class KTileIterator, class BlockCoord
>
CUTLASS_DEVICE void
load_auxiliary(
Params const& mainloop_params,
MainloopPipeline pipeline,
PipelineState smem_pipe_write,
cute::tuple<TensorA, TensorB, TensorScaleA, TensorScaleB> const& load_inputs,
BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count,
int thread_idx,
uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors) {
// Block scaling: load_auxiliary reads scale tensors from global memory, which are not tiled
Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), SmemLayoutSFA{}); // (ScaleMsPerTile,k)
Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), SmemLayoutSFB{}); // (ScaleNsPerTile,k)
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor mSFA_mkl = get<2>(load_inputs);
Tensor mSFB_nkl = get<3>(load_inputs);
@@ -441,35 +659,9 @@ struct CollectiveMma<
Tensor tSFBcSFB_k = thr_scale_copy_b.partition_S(cSFB_k);
Tensor tSFBsSFB = thr_scale_copy_b.partition_D(sSFB);
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
Tensor tSFApSFA = make_tensor<bool>(shape(filter_zeros(tSFAsSFA(_,_,_,_0{})))); // (CPY,CPY_M,CPY_K)
Tensor tSFBpSFB = make_tensor<bool>(shape(filter_zeros(tSFBsSFB(_,_,_,_0{})))); // (CPY,CPY_N,CPY_K)
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
// Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
}
}
auto SFA_shape = shape(mainloop_params.layout_SFA);
auto SFB_shape = shape(mainloop_params.layout_SFB);
@@ -480,9 +672,9 @@ struct CollectiveMma<
pipeline.producer_acquire(smem_pipe_write);
// Since the scale granularity in K is a multiple of BLK_K, we do not have to consider whether it is OOB
bool load_sfa = thread_idx < ScaleMsPerTile;
Tensor tSFAcSFA = tSFAcSFA_k(_,_,_,*k_tile_iter);
Tensor tSFAcSFA_compact = filter_zeros(tSFAcSFA);
bool load_sfa = thread_idx < ScaleMsPerTile;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tSFApSFA); ++i) {
tSFApSFA(i) = load_sfa && elem_less(get<0>(tSFAcSFA_compact(i)), get<0>(SFA_shape));
@ -495,22 +687,17 @@ struct CollectiveMma<
for (int i = 0; i < size(tSFBpSFB); ++i) {
tSFBpSFB(i) = load_sfb && elem_less(get<0>(tSFBcSFB_compact(i)), get<0>(SFB_shape));
}
//
// Copy gmem to smem for *k_tile_iter
//
int write_stage = smem_pipe_write.index();
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
// Copy operands A and B from global memory to shared memory
if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
// Copy scale tensors from global memory to shared memory
copy_if(scale_copy_a, tSFApSFA, filter_zeros(tSFAgSFA_k(_,_,_,*k_tile_iter)), filter_zeros(tSFAsSFA(_,_,_,write_stage)));
copy_if(scale_copy_b, tSFBpSFB, filter_zeros(tSFBgSFB_k(_,_,_,*k_tile_iter)), filter_zeros(tSFBsSFB(_,_,_,write_stage)));
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
if constexpr (!IsTmaLoadSFA) {
copy_if(scale_copy_a, tSFApSFA, filter_zeros(tSFAgSFA_k(_,_,_,*k_tile_iter)), filter_zeros(tSFAsSFA(_,_,_,write_stage)));
}
if constexpr (!IsTmaLoadSFB) {
copy_if(scale_copy_b, tSFBpSFB, filter_zeros(tSFBgSFB_k(_,_,_,*k_tile_iter)), filter_zeros(tSFBsSFB(_,_,_,write_stage)));
}
if constexpr (!IsTmaLoadSFA || !IsTmaLoadSFB) {
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
}
++k_tile_iter;
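The conditional `producer_commit` is the subtle part: TMA copies signal the stage barrier through their transaction bytes, while the `cp.async` path must arrive explicitly, so the commit is skipped only when no scale goes through `cp.async`. A toy model of that arbitration (simplified; the real mechanism is the cluster transaction barrier in `cutlass::arch`):

```cpp
#include <cassert>
#include <cstdint>

// Toy model of one stage's transaction barrier: the consumer may proceed once
// the expected TMA bytes have landed AND any cp.async producer has committed.
struct StageBarrier {
  uint32_t expected_bytes  = 0;
  uint32_t arrived_bytes   = 0;
  bool     cpasync_arrived = false;

  void tma_arrive(uint32_t bytes) { arrived_bytes += bytes; }
  void cpasync_commit()           { cpasync_arrived = true; }
  bool ready() const { return arrived_bytes >= expected_bytes && cpasync_arrived; }
};

int main() {
  // SFA via TMA, SFB via cp.async: the scale warp must still producer_commit.
  StageBarrier s;
  s.expected_bytes = 16384 + 16384 + 512;  // A + B + SFA (illustrative sizes)
  s.tma_arrive(16384);  // operand A
  s.tma_arrive(16384);  // operand B
  s.tma_arrive(512);    // scale A
  assert(!s.ready());   // still waiting on the cp.async commit for scale B
  s.cpasync_commit();
  assert(s.ready());
  return 0;
}
```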
@@ -669,100 +856,32 @@ struct CollectiveMma<
Tensor tCrSFB = make_tensor_like<ElementBlockScale>(tCsSFB(_, _, _, _0{})); // (MMA,MMA_M,MMA_N)
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
GmmaFP8Accumulation accumulation(accum, ScalePromotionInterval, size<2>(tCrA));
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
if constexpr (ScalePromotionInterval != 4) {
if (accumulation.prepare_if_needed()) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
}
else {
// Always zero out the accumulator for finest granularity
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
int read_stage = smem_pipe_read.index();
// Load per block scale values from shared memory to registers
copy(tCsSFA(_,_,_,make_coord(_0{},read_stage)), tCrSFA);
copy(tCsSFB(_,_,_,make_coord(_0{},read_stage)), tCrSFB);
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{});
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_b = tCrSFB(_0{});
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(filter_zeros(tCrSFA)); i++) {
filter_zeros(tCrSFA)(i) = filter_zeros(tCrSFA)(i) * scale_b;
}
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
ElementBlockScale scale_a = tCrSFA(_0{});
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(filter_zeros(tCrSFB)); i++) {
filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a;
}
}
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
// Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB`
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_ab = tCrSFA(_0{});
scale_if_needed(accumulation, scale_ab);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
scale_if_needed(accumulation, tCrSFA);
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
scale_if_needed(accumulation, tCrSFB);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
scale_if_needed(accumulation, tCrSFA, tCrSFB);
}
++smem_pipe_read;
}
warpgroup_fence_operand(accumulation());
CUTLASS_PRAGMA_UNROLL
for (int k_tile_prologue = prologue_mma_count - 1; k_tile_prologue > 0; --k_tile_prologue)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
if constexpr (ScalePromotionInterval != 4) {
if (accumulation.prepare_if_needed()) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
}
else {
// Always zero out the accumulator for finest granularity
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
int read_stage = smem_pipe_read.index();
// Load per block scale values from shared memory to registers
copy(tCsSFA(_,_,_,make_coord(_0{},read_stage)), tCrSFA);
copy(tCsSFB(_,_,_,make_coord(_0{},read_stage)), tCrSFB);
copy(tCsSFA(_,_,_,make_coord(_0{}, read_stage)), tCrSFA);
copy(tCsSFB(_,_,_,make_coord(_0{}, read_stage)), tCrSFB);
warpgroup_fence_operand(accumulation());
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
warpgroup_fence_operand(accumulation());
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{});
@@ -781,17 +900,9 @@ struct CollectiveMma<
filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a;
}
}
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
warpgroup_wait<0>();
++smem_pipe_read;
barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
// Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB`
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_ab = tCrSFA(_0{});
@@ -806,19 +917,15 @@ struct CollectiveMma<
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
scale_if_needed(accumulation, tCrSFA, tCrSFB);
}
++smem_pipe_read;
}
warpgroup_fence_operand(accumulation());
// Mainloop GMMAs
k_tile_count -= prologue_mma_count;
k_tile_count -= 1;
CUTLASS_PRAGMA_NO_UNROLL
for ( ; k_tile_count > 0; --k_tile_count)
for ( ; k_tile_count > 1; --k_tile_count)
{
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
//
@ -828,8 +935,32 @@ struct CollectiveMma<
int read_stage = smem_pipe_read.index();
// Load per block scale values from shared memory to registers (at most twice per block along M and/or N)
copy(tCsSFA(_,_,_,make_coord(_0{},read_stage)), tCrSFA);
copy(tCsSFB(_,_,_,make_coord(_0{},read_stage)), tCrSFB);
copy(tCsSFA(_,_,_,make_coord(_0{}, read_stage)), tCrSFA);
copy(tCsSFB(_,_,_,make_coord(_0{}, read_stage)), tCrSFB);
if constexpr (ScalePromotionInterval != 4) {
if (accumulation.prepare_if_needed()) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
}
else {
// Always zero out the accumulator for finest granularity
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
warpgroup_fence_operand(accumulation());
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
/// The GMMA barrier is waited on (warpgroup_wait<0> below) before the stage is released
warpgroup_fence_operand(accumulation());
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{});
@@ -848,7 +979,40 @@ struct CollectiveMma<
filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a;
}
}
warpgroup_wait<0>();
pipeline.consumer_release(smem_pipe_release); // Unlock previous tile
++smem_pipe_read;
barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
// Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB`
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_ab = tCrSFA(_0{});
scale_if_needed(accumulation, scale_ab);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
scale_if_needed(accumulation, tCrSFA);
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
scale_if_needed(accumulation, tCrSFB);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
scale_if_needed(accumulation, tCrSFA, tCrSFB);
}
// Advance smem_pipe_read and smem_pipe_release
++smem_pipe_release;
}
{
pipeline.consumer_wait(smem_pipe_read, barrier_token);
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
// Load per block scale values from shared memory to registers (at most twice per block along M and/or N)
copy(tCsSFA(_,_,_,make_coord(_0{}, read_stage)), tCrSFA);
copy(tCsSFB(_,_,_,make_coord(_0{}, read_stage)), tCrSFB);
if constexpr (ScalePromotionInterval != 4) {
if (accumulation.prepare_if_needed()) {
@@ -865,16 +1029,34 @@ struct CollectiveMma<
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M,K) x (V,N,K) => (V,M,N)
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
warpgroup_wait<K_PIPE_MMAS>();
warpgroup_fence_operand(accumulation());
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{});
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_b = tCrSFB(_0{});
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(filter_zeros(tCrSFA)); i++) {
filter_zeros(tCrSFA)(i) = filter_zeros(tCrSFA)(i) * scale_b;
}
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
ElementBlockScale scale_a = tCrSFA(_0{});
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(filter_zeros(tCrSFB)); i++) {
filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a;
}
}
warpgroup_wait<0>();
pipeline.consumer_release(smem_pipe_release); // Unlock previous tile
// Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB`
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_ab = tCrSFA(_0{});
@@ -889,28 +1071,21 @@ struct CollectiveMma<
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
scale_if_needed(accumulation, tCrSFA, tCrSFB);
}
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
// Advance smem_pipe_read and smem_pipe_release
++smem_pipe_read;
++smem_pipe_release;
}
if constexpr (ScalePromotionInterval != 4) {
// residues only exist when the granularity is not the finest
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
ElementBlockScale scale_ab = tCrSFA(_0{});
scale_if_needed(accumulation, scale_ab);
accumulation.scale_residue_if_needed(scale_ab);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
scale_if_needed(accumulation, tCrSFA);
accumulation.scale_residue_if_needed(tCrSFA);
}
if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
scale_if_needed(accumulation, tCrSFB);
accumulation.scale_residue_if_needed(tCrSFB);
}
if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
scale_if_needed(accumulation, tCrSFA, tCrSFB);
accumulation.scale_residue_if_needed(tCrSFA, tCrSFB);
}
}
@@ -920,18 +1095,9 @@ struct CollectiveMma<
/// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void
mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count) {
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
++smem_pipe_release;
}
// The pipeline is not released in the first iteration
smem_pipe_release.advance(k_tile_count - 1);
pipeline.consumer_release(smem_pipe_release);
}
};
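A counting sketch of the reworked release schedule (N is an arbitrary tile count): the prologue holds its stage, every later tile releases its predecessor inside the mainloop, and `mma_tail` fast-forwards past the already-released stages to free the final one:

```cpp
#include <cassert>

int main() {
  int const N = 7;  // k_tile_count (arbitrary)
  int released = 0;

  for (int tile = 0; tile < N; ++tile) {
    if (tile > 0) ++released;  // "Unlock previous tile" inside the mainloop
  }
  assert(released == N - 1);   // the last acquired stage is still held

  // mma_tail: smem_pipe_release.advance(k_tile_count - 1), then one
  // consumer_release for the stage held by the final tile.
  ++released;
  assert(released == N);       // every acquired stage returns to the producer
  return 0;
}
```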

View File

@@ -94,6 +94,13 @@ struct Has_SwapAB <T, CUTE_STL_NAMESPACE::void_t<decltype(T::SwapAB)>>
template <typename T>
static constexpr bool Has_SwapAB_v = Has_SwapAB<T>::value;
// additional producer warp role check for block scaling mainloop
template<typename T>
struct HasAuxiliaryLoad : cute::false_type{};
template <typename T>
static constexpr bool HasAuxiliaryLoad_v = HasAuxiliaryLoad<T>::value;
} // namespace kernel::detail
//////////////////////////////////////////////////////////////////////////////
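A compact sketch of the relocated trait (hypothetical policy types; the real specializations appear in the hunks below):

```cpp
#include <type_traits>

namespace kernel::detail {

// Primary template: policies have no auxiliary load unless they opt in.
template <typename T>
struct HasAuxiliaryLoad : std::false_type {};

template <typename T>
inline constexpr bool HasAuxiliaryLoad_v = HasAuxiliaryLoad<T>::value;

} // namespace kernel::detail

// Hypothetical mainloop policies for illustration.
struct PlainMainloop {};
struct BlockScaledMainloop {};

// A policy opts in exactly the way the later hunks specialize the trait.
namespace kernel::detail {
template <>
struct HasAuxiliaryLoad<BlockScaledMainloop> : std::true_type {};
} // namespace kernel::detail

static_assert(!kernel::detail::HasAuxiliaryLoad_v<PlainMainloop>);
static_assert(kernel::detail::HasAuxiliaryLoad_v<BlockScaledMainloop>);
```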
@@ -119,6 +126,7 @@ struct KernelPtrArrayTmaWarpSpecializedPingpong { };
// FP8 related policies (including Blocked Scaled Accumulation)
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelTmaWarpSpecializedCooperative { };
struct KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum: KernelTmaWarpSpecializedPingpong { };
struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelPtrArrayTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum: KernelPtrArrayTmaWarpSpecializedPingpong { };
@@ -302,13 +310,14 @@ struct MainloopSm90TmaGmmaWarpSpecializedFP8
template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>,
class KernelSchedule = KernelTmaWarpSpecialized
class KernelSchedule = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum
>
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
static_assert(
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>,
"KernelSchedule must be one of the warp specialized policies");
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum> ||
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum>,
"KernelSchedule must be one of the warp specialized FP8 block scale policies");
};
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule for Ptr-Array and Grouped Gemm
@@ -389,7 +398,7 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput {
template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>,
class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperative
class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum
>
struct MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling
: MainloopSm90ArrayTmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
@@ -399,7 +408,7 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling
KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum,
KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum
>,
"KernelSchedule must be one of the warp specialized policies");
"KernelSchedule must be one of the warp specialized FP8 block scale policies");
};
@@ -559,15 +568,14 @@ struct KernelTmaWarpSpecializedCooperativeSparseBlockScaledSm120 {
// Auxiliary Load Tag.
template<class Policy>
struct IsAuxiliaryLoadNeeded : cute::false_type{};
namespace kernel::detail {
template<
int Stages,
class ClusterShape,
class KernelSchedule
>
struct IsAuxiliaryLoadNeeded<
struct HasAuxiliaryLoad<
MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling<
Stages,
ClusterShape,
@@ -575,6 +583,21 @@ struct IsAuxiliaryLoadNeeded<
>
> : cute::true_type{};
template<
int Stages,
class ClusterShape,
class KernelSchedule
>
struct HasAuxiliaryLoad<
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<
Stages,
ClusterShape,
KernelSchedule
>
> : cute::true_type{};
} // namespace kernel::detail
//////////////////////////////////////////////////////////////////////////////
//

View File

@@ -165,7 +165,7 @@ public:
static constexpr uint32_t MaxThreadsPerBlock = NumMmaThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents;
static constexpr bool IsMainloopAuxiliaryLoadNeeded = IsAuxiliaryLoadNeeded<typename CollectiveMainloop::DispatchPolicy>::value;
static constexpr bool IsMainloopAuxiliaryLoadNeeded = detail::HasAuxiliaryLoad_v<typename CollectiveMainloop::DispatchPolicy>;
/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;

View File

@@ -166,7 +166,7 @@ public:
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents;
static constexpr bool IsMainloopAuxiliaryLoadNeeded = IsAuxiliaryLoadNeeded<typename CollectiveMainloop::DispatchPolicy>::value;
static constexpr bool IsMainloopAuxiliaryLoadNeeded = detail::HasAuxiliaryLoad_v<typename CollectiveMainloop::DispatchPolicy>;
/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;

View File

@@ -126,6 +126,7 @@ public:
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
static constexpr uint32_t NumFixupBarriers = NumMmaWarpGroups;
static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents;
static constexpr bool IsMainloopAuxiliaryLoadNeeded = detail::HasAuxiliaryLoad_v<typename CollectiveMainloop::DispatchPolicy>;
/// Register requirement for Load and Math WGs
static constexpr int RegsPerThread =
@@ -369,7 +370,7 @@ public:
Mainloop = 0,
Warp1 = 1,
Epilogue = 2,
Warp3 = 3
MainloopAux = 3
};
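The rename gives the fourth producer warp a real job; a sketch of the role dispatch, assuming the usual CUTLASS convention of deriving the role from the warp index within the producer warp group:

```cpp
enum class ProducerWarpRole {
  Mainloop    = 0,
  Warp1       = 1,
  Epilogue    = 2,
  MainloopAux = 3   // was Warp3: now drives load_auxiliary for the block scales
};

// Hypothetical mapping from a warp's index within the producer warp group
// (4 warps per warp group) to its role.
constexpr ProducerWarpRole role_of(int warp_idx_in_warp_group) {
  return static_cast<ProducerWarpRole>(warp_idx_in_warp_group % 4);
}

static_assert(role_of(3) == ProducerWarpRole::MainloopAux);
```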
@@ -643,7 +644,53 @@ public:
// Make sure all Consumer Warp Groups have been waited upon
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
} // Mainloop Producer Warp End
}
else if (producer_warp_role == ProducerWarpRole::MainloopAux) {
if constexpr (IsMainloopAuxiliaryLoadNeeded) {
while (work_tile_info.is_valid()) {
if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info);
work_tile_info = next_work_tile_info;
continue;
}
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
// Get the number of K tiles to compute for this work as well as the starting K tile offset of the work.
auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info);
auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));
collective_mainloop.load_auxiliary(
params.mainloop,
mainloop_pipeline,
mainloop_pipe_producer_state,
load_inputs,
blk_coord,
k_tile_iter, work_k_tile_count,
lane_idx,
block_rank_in_cluster,
shared_storage.tensors.mainloop
);
// Update starting pipeline state for the next tile
mainloop_pipe_producer_state.advance(work_k_tile_count);
// Get next work tile
auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(
work_tile_info,
scheduler_pipeline,
scheduler_pipe_consumer_state
);
work_tile_info = next_work_tile_info;
} // Scheduler work fetch loop
}
}
// Epilogue Producer Warp
else if (producer_warp_role == ProducerWarpRole::Epilogue && is_epi_load_needed) {

View File

@@ -130,9 +130,11 @@ public:
static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp for C
static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaWarpGroups = 2;
static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents;
static constexpr uint32_t NumMMAThreads = size(TiledMma{}); // 4 warps
static constexpr uint32_t MaxThreadsPerBlock = NumMMAThreads * NumMmaWarpGroups + (NumLoadWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
static constexpr bool IsMainloopAuxiliaryLoadNeeded = detail::HasAuxiliaryLoad_v<typename CollectiveMainloop::DispatchPolicy>;
static_assert(NumMMAThreads == 128, "Pingpong kernel must have TiledMMA operating using 128 threads.");
static_assert(MaxThreadsPerBlock == 384, "Pingpong kernel must have 384 threads in total.");
@@ -375,7 +377,7 @@ public:
Mainloop = 0,
Warp1 = 1,
Epilogue = 2,
Warp3 = 3
MainloopAux = 3
};
// Kernel level shared memory storage
@@ -453,6 +455,7 @@ public:
}
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
mainloop_pipeline_params.num_producers = NumProducerThreads;
mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes;
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{});
@@ -684,6 +687,52 @@ public:
} // Mainloop Producer Warp End
else if (producer_warp_role == ProducerWarpRole::MainloopAux) {
if constexpr (IsMainloopAuxiliaryLoadNeeded) {
// Ensure that the prefetched kernel does not touch
// unflushed global memory prior to this instruction
cutlass::arch::wait_on_dependent_grids();
while (work_tile_info.is_valid()) {
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl));
collective_mainloop.load_auxiliary(
params.mainloop,
mainloop_pipeline,
mainloop_pipe_producer_state,
load_inputs,
blk_coord,
k_tile_iter, k_tile_count,
lane_idx,
block_rank_in_cluster,
shared_storage.tensors.mainloop
);
// Update starting pipeline state for the next tile
mainloop_pipe_producer_state.advance(k_tile_count);
scheduler.advance_to_next_work();
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state);
if constexpr (IsSchedDynamicPersistent) {
auto [next_work_tile_info, increment_pipe] =
scheduler.fetch_next_work(
work_tile_info,
scheduler_pipeline,
scheduler_pipe_consumer_state
);
}
}
}
// Epilogue Producer Warp
else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) {