@@ -105,6 +105,7 @@ struct CollectiveMma<
   using ElementBlockScale = ElementAccumulator;
   using GmemTiledCopyA = GmemTiledCopyA_;
   using GmemTiledCopyB = GmemTiledCopyB_;
+  using GmemTiledCopyScaleTMA = cute::SM90_TMA_LOAD;
   using SmemLayoutAtomA = SmemLayoutAtomA_;
   using SmemLayoutAtomB = SmemLayoutAtomB_;
   using SmemCopyAtomA = SmemCopyAtomA_;
@@ -118,9 +119,6 @@ struct CollectiveMma<
   using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
   using PipelineParams = typename MainloopPipeline::Params;
 
-  // 33 threads per CTA are producers (1 for operand tile `tma`, and 32 for scales `cp.async`)
-  static constexpr int NumProducerThreadEvents = 33;
-
   static constexpr int ScaleGranularityM = size<0,0>(LayoutSFA{});
   static constexpr int ScaleGranularityN = size<0,0>(LayoutSFB{});
   static constexpr int ScaleGranularityK = size<1,0>(LayoutSFA{});
@@ -133,6 +131,12 @@ struct CollectiveMma<
   static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
   static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
 
+  static constexpr int ScaleTmaThreshold = 32;
+  static constexpr bool IsTmaLoadSFA = ScaleMsPerTile >= ScaleTmaThreshold && ScaleNsPerTile < ScaleTmaThreshold;
+  static constexpr bool IsTmaLoadSFB = ScaleNsPerTile >= ScaleTmaThreshold && ScaleMsPerTile < ScaleTmaThreshold;
+  // One producer thread when both scale tensors are loaded by `tma`; otherwise 33 (1 for the operand tile `tma`, and 32 for scales `cp.async`)
+  static constexpr int NumProducerThreadEvents = ((IsTmaLoadSFA && IsTmaLoadSFB) ? 1 : 33);
+
   static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
   static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
   static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
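For intuition, the producer-selection logic above reduces to a small constexpr predicate: a scale tensor rides on TMA only when its own mode carries at least `ScaleTmaThreshold` scales per tile while the other mode stays below it, and the producer thread count drops to 1 only when *both* scales go through TMA. A standalone sketch of the same logic (the tile/granularity numbers below are assumptions for illustration, not values from the library):

```cpp
#include <cstdio>

// Minimal sketch of the IsTmaLoadSFA/SFB predicate above; values illustrative.
constexpr int ScaleTmaThreshold = 32;

constexpr bool is_tma(int scales_this_mode, int scales_other_mode) {
  // TMA is chosen only when this mode carries many scales per tile while the
  // other mode stays small; otherwise the 32-thread cp.async path is used.
  return scales_this_mode >= ScaleTmaThreshold && scales_other_mode < ScaleTmaThreshold;
}

int main() {
  // e.g. a 128-wide M tile with 1x128 scale granularity: 128 scales along M, 1 along N
  constexpr int ScaleMsPerTile = 128, ScaleNsPerTile = 1;
  constexpr bool IsTmaLoadSFA = is_tma(ScaleMsPerTile, ScaleNsPerTile);  // true
  constexpr bool IsTmaLoadSFB = is_tma(ScaleNsPerTile, ScaleMsPerTile);  // false
  // 1 producer thread only if *both* scale tensors ride on TMA, else 33.
  constexpr int NumProducerThreadEvents = (IsTmaLoadSFA && IsTmaLoadSFB) ? 1 : 33;
  std::printf("SFA via TMA: %d, SFB via TMA: %d, producer events: %d\n",
              IsTmaLoadSFA, IsTmaLoadSFB, NumProducerThreadEvents);
}
```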
@@ -191,10 +195,10 @@ struct CollectiveMma<
   struct SharedStorage
   {
     struct TensorStorage : cute::aligned_struct<128> {
-      cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;      // mxk
-      cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;      // nxk
-      cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutSFA>> smem_SFA;            // ScaleMsPerTile x k
-      cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutSFB>> smem_SFB;            // ScaleNsPerTile x k
+      cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;      // TILE_M x PIPE_K
+      cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;      // TILE_N x PIPE_K
+      CUTE_ALIGNAS(128) cute::array<ElementBlockScale, cute::cosize_v<SmemLayoutSFA>> smem_SFA;  // ScaleMsPerTile x PIPE_K
+      CUTE_ALIGNAS(128) cute::array<ElementBlockScale, cute::cosize_v<SmemLayoutSFB>> smem_SFB;  // ScaleNsPerTile x PIPE_K
     } tensors;
 
     using PipelineStorage = typename MainloopPipeline::SharedStorage;
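The switch to `CUTE_ALIGNAS(128)` matters because TMA writes into shared memory must land on 128-byte-aligned buffers; the element type's natural alignment is not enough once the scale arrays become TMA destinations. A minimal plain-C++ sketch of the same idea (member names are stand-ins, not the real `TensorStorage`):

```cpp
#include <cstdint>
#include <cstdio>

// Sketch only: TMA-written buffers need an explicit 128-byte alignment,
// which alignas(128) / CUTE_ALIGNAS(128) provides.
struct TensorStorageSketch {
  alignas(128) float smem_SFA[128];  // stand-in for cute::array<ElementBlockScale, ...>
  alignas(128) float smem_SFB[128];
};

int main() {
  TensorStorageSketch s;
  auto addr = reinterpret_cast<std::uintptr_t>(&s.smem_SFA[0]);
  std::printf("SFA offset mod 128 = %llu\n", static_cast<unsigned long long>(addr % 128));
}
```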
@@ -218,29 +222,60 @@ struct CollectiveMma<
 
   // Device side kernel params
   struct Params {
+    static auto getTmaSFA() {
+      if constexpr (IsTmaLoadSFA) {
+        return make_tma_copy(
+            GmemTiledCopyScaleTMA{},
+            make_tensor(static_cast<ElementBlockScale const*>(nullptr), filter_zeros(LayoutSFA{})),
+            filter_zeros(SmemLayoutSFA{}(_,_,_0{})),
+            Shape<Int<ScaleMsPerTile>, Int<1>>{},
+            _1{});
+      }
+      else {
+        return nullptr;
+      }
+    }
+    static auto getTmaSFB() {
+      if constexpr (IsTmaLoadSFB) {
+        return make_tma_copy(
+            GmemTiledCopyScaleTMA{},
+            make_tensor(static_cast<ElementBlockScale const*>(nullptr), filter_zeros(LayoutSFB{})),
+            filter_zeros(SmemLayoutSFB{}(_,_,_0{})),
+            Shape<Int<ScaleNsPerTile>, Int<1>>{},
+            _1{});
+      }
+      else {
+        return nullptr;
+      }
+    }
     // Assumption: StrideA is congruent with Problem_MK
     using TMA_A = decltype(make_tma_copy_A_sm90(
         GmemTiledCopyA{},
         make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
-        SmemLayoutA{}(_,_,0),
+        SmemLayoutA{}(_,_,_0{}),
         TileShape{},
         ClusterShape{}));
     // Assumption: StrideB is congruent with Problem_NK
     using TMA_B = decltype(make_tma_copy_B_sm90(
         GmemTiledCopyB{},
         make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
-        SmemLayoutB{}(_,_,0),
+        SmemLayoutB{}(_,_,_0{}),
         TileShape{},
         ClusterShape{}));
+    // NOTE: Does make_tma_copy support a stride of 0?
+    using TMA_SFA = decltype(getTmaSFA());
+    using TMA_SFB = decltype(getTmaSFB());
     TMA_A tma_load_a;
     TMA_B tma_load_b;
+    TMA_SFA tma_load_sfa;
+    TMA_SFB tma_load_sfb;
     uint32_t tma_transaction_bytes = TmaTransactionBytes;
     uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
     uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
     // Block scaling factors for A and B
-    ElementBlockScale const* ptr_SFA;
-    LayoutSFA layout_SFA;
-    ElementBlockScale const* ptr_SFB;
-    LayoutSFB layout_SFB;
+    ElementBlockScale const* ptr_SFA;
+    ElementBlockScale const* ptr_SFB;
+    LayoutSFA layout_SFA;
+    LayoutSFB layout_SFB;
   };
 
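`getTmaSFA`/`getTmaSFB` lean on a standard C++17 idiom: in an `auto`-returning function, the two arms of `if constexpr` live in different instantiations, so each specialization deduces a different return type, and `TMA_SFA` collapses to `std::nullptr_t` whenever the cp.async path is taken. A minimal sketch of the idiom (all names below are stand-ins):

```cpp
#include <cstddef>
#include <type_traits>

struct TmaDescriptor {};  // stand-in for the type make_tma_copy would return

// With `if constexpr`, only one return statement exists per instantiation,
// so the deduced return type is either the descriptor or std::nullptr_t.
template <bool UseTma>
auto get_tma() {
  if constexpr (UseTma) {
    return TmaDescriptor{};
  }
  else {
    return nullptr;  // deduced as std::nullptr_t; essentially free to store
  }
}

using TMA_SF_on  = decltype(get_tma<true>());   // TmaDescriptor
using TMA_SF_off = decltype(get_tma<false>());  // std::nullptr_t
static_assert(std::is_same_v<TMA_SF_off, std::nullptr_t>);

int main() {}
```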
@@ -259,7 +294,11 @@ struct CollectiveMma<
 
     auto ptr_A = reinterpret_cast<ElementA const*>(args.ptr_A);
     auto ptr_B = reinterpret_cast<ElementB const*>(args.ptr_B);
+    auto ptr_SFA = reinterpret_cast<ElementBlockScale const*>(args.ptr_SFA);
+    auto ptr_SFB = reinterpret_cast<ElementBlockScale const*>(args.ptr_SFB);
 
+    Tensor tensor_sfa = make_tensor(ptr_SFA, filter_zeros(args.layout_SFA));
+    Tensor tensor_sfb = make_tensor(ptr_SFB, filter_zeros(args.layout_SFB));
     Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
     Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
     typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90(
@@ -274,20 +313,42 @@ struct CollectiveMma<
         SmemLayoutB{}(_,_,cute::Int<0>{}),
         TileShape{},
         ClusterShape{});
+    typename Params::TMA_SFA tma_load_sfa{};
+    if constexpr (IsTmaLoadSFA) {
+      tma_load_sfa = make_tma_copy(
+          GmemTiledCopyScaleTMA{},
+          tensor_sfa,
+          filter_zeros(SmemLayoutSFA{})(_,_,cute::Int<0>{}),
+          Shape<Int<ScaleMsPerTile>, Int<1>>{},
+          _1{});
+    }
+    typename Params::TMA_SFB tma_load_sfb{};
+    if constexpr (IsTmaLoadSFB) {
+      tma_load_sfb = make_tma_copy(
+          GmemTiledCopyScaleTMA{},
+          tensor_sfb,
+          filter_zeros(SmemLayoutSFB{})(_,_,cute::Int<0>{}),
+          Shape<Int<ScaleNsPerTile>, Int<1>>{},
+          _1{});
+    }
     uint32_t transaction_bytes_mk = TmaTransactionBytesMK;
     uint32_t transaction_bytes_nk = TmaTransactionBytesNK;
-    uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk;
+    uint32_t transaction_bytes_sfa = TmaTransactionBytesSFA;
+    uint32_t transaction_bytes_sfb = TmaTransactionBytesSFB;
+    uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk + transaction_bytes_sfa + transaction_bytes_sfb;
 
     return {
       tma_load_a,
       tma_load_b,
+      tma_load_sfa,
+      tma_load_sfb,
       transaction_bytes,
       transaction_bytes_mk,
       transaction_bytes_nk,
       args.ptr_SFA,
-      args.layout_SFA,
       args.ptr_SFB,
-      args.layout_SFB
+      args.layout_SFA,
+      args.layout_SFB,
     };
   }
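As a sanity check on the byte accounting: only TMA traffic counts against the barrier's expect-tx, which is why a scale tensor contributes zero bytes when it stays on `cp.async`. The arithmetic for one hypothetical configuration (all sizes below are assumptions for illustration):

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical: BLK_M = BLK_N = BLK_K = 128, 1-byte fp8 operands, 4-byte float
// scales, ScaleMsPerTile = 128 (TMA), ScaleNsPerTile = 1 (cp.async).
int main() {
  constexpr uint32_t bytes_mk  = 128u * 128u * 1u;          // 16384 B of A per stage
  constexpr uint32_t bytes_nk  = 128u * 128u * 1u;          // 16384 B of B per stage
  constexpr bool     tma_sfa   = true, tma_sfb = false;
  constexpr uint32_t bytes_sfa = tma_sfa ? 128u * 4u : 0u;  // counted only on the TMA path
  constexpr uint32_t bytes_sfb = tma_sfb ? 1u * 4u : 0u;    // cp.async never hits expect-tx
  std::printf("expect-tx = %u bytes\n", bytes_mk + bytes_nk + bytes_sfa + bytes_sfb);
}
```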
@@ -302,20 +363,39 @@ struct CollectiveMma<
 
     bool implementable = true;
     constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
-    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
+    if (!cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{})) {
+      implementable = false;
+      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum alignment requirements for using TMA to load tensor A.\n");
+    }
     constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
-    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
+    if (!cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{})) {
+      implementable = false;
+      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum alignment requirements for using TMA to load tensor B.\n");
+    }
+    constexpr int min_tma_aligned_elements_S = tma_alignment_bits / cutlass::sizeof_bits<ElementBlockScale>::value;
+    if (IsTmaLoadSFA && !cutlass::detail::check_alignment<min_tma_aligned_elements_S>(args.layout_SFA)) {
+      implementable = false;
+      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum alignment requirements for using TMA to load scale A.\n");
+    }
+    if (IsTmaLoadSFB && !cutlass::detail::check_alignment<min_tma_aligned_elements_S>(args.layout_SFB)) {
+      implementable = false;
+      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum alignment requirements for using TMA to load scale B.\n");
+    }
 
     /* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA instructions. */
     constexpr int pipe_k = size<2>(TileShape{}) / tile_size<2>(TiledMma{});
-    implementable = implementable && (args.mma_promotion_interval % 4 == 0) && (args.mma_promotion_interval == ScalePromotionInterval);
-    implementable = implementable && (pipe_k % 4 == 0) && (pipe_k <= args.mma_promotion_interval);
+    if (args.mma_promotion_interval % 4 != 0 ||
+        args.mma_promotion_interval != ScalePromotionInterval ||
+        args.mma_promotion_interval % pipe_k != 0 ||
+        pipe_k > args.mma_promotion_interval) {
+      implementable = false;
+      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Argument mma_promotion_interval is invalid.\n");
+    }
 
     // We expect full tiles in K
-    implementable = implementable && (K % size<2>(TileShape{}) == 0);
-
-    if (!implementable) {
-      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
+    if (K % size<2>(TileShape{}) != 0) {
+      implementable = false;
+      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size K is incompatible with tile size.\n");
     }
     return implementable;
   }
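The rewritten check makes the constraints explicit: the promotion interval must be a multiple of 4 (each mainloop iteration issues 4 MMAs), must equal the scale promotion interval, and must cover a whole number of `pipe_k` k-blocks. A small standalone sketch of that predicate (concrete numbers assumed for illustration):

```cpp
#include <cstdio>

// Sketch of the mma_promotion_interval validity predicate; values illustrative.
bool promotion_interval_ok(int interval, int scale_promotion_interval, int pipe_k) {
  return interval % 4 == 0                    // 4 MMA instructions per mainloop iteration
      && interval == scale_promotion_interval
      && interval % pipe_k == 0               // whole k-blocks per promotion
      && pipe_k <= interval;
}

int main() {
  // e.g. BLK_K = 128 with a 32-wide MMA K atom gives pipe_k = 4
  std::printf("interval=4, pipe_k=4: %d\n", promotion_interval_ok(4, 4, 4));  // 1
  std::printf("interval=6, pipe_k=4: %d\n", promotion_interval_ok(6, 6, 4));  // 0: not a multiple of 4
  std::printf("interval=4, pipe_k=8: %d\n", promotion_interval_ok(4, 4, 8));  // 0: pipe_k too large
}
```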
@@ -326,7 +406,12 @@ struct CollectiveMma<
       cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value));
   static constexpr uint32_t TmaTransactionBytesNK =
       cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
-  static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
 
+  static constexpr uint32_t TmaTransactionBytesSFA =
+      (IsTmaLoadSFA ? cutlass::bits_to_bytes(ScaleMsPerTile * static_cast<uint32_t>(sizeof_bits<ElementBlockScale>::value)) : 0);
+  static constexpr uint32_t TmaTransactionBytesSFB =
+      (IsTmaLoadSFB ? cutlass::bits_to_bytes(ScaleNsPerTile * static_cast<uint32_t>(sizeof_bits<ElementBlockScale>::value)) : 0);
+  static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK + TmaTransactionBytesSFA + TmaTransactionBytesSFB;
 
   /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
   CUTLASS_DEVICE
@@ -334,6 +419,12 @@ struct CollectiveMma<
   {
     cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
     cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
+    if constexpr (IsTmaLoadSFA) {
+      cute::prefetch_tma_descriptor(mainloop_params.tma_load_sfa.get_tma_descriptor());
+    }
+    if constexpr (IsTmaLoadSFB) {
+      cute::prefetch_tma_descriptor(mainloop_params.tma_load_sfb.get_tma_descriptor());
+    }
   }
 
   /// Set up the data needed by this collective for load and mma.
@@ -357,10 +448,24 @@ struct CollectiveMma<
     Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{});        // (BLK_M,BLK_K,m,k,l)
     Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{});        // (BLK_N,BLK_K,n,k,l)
 
-    // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
-    // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
-    Tensor mSFA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_SFA), mainloop_params.layout_SFA);  // (scale_m,k,l)
-    Tensor mSFB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_SFB), mainloop_params.layout_SFB);  // (scale_n,k,l)
+    // Note that mSFA_mkl and mSFB_nkl are already block-tiled on the host, so the
+    // global-memory tiles gScaleA_mkl and gScaleB_nkl coincide with mSFA_mkl and mSFB_nkl.
+    auto mSFA_mkl = [&]() {
+      if constexpr (IsTmaLoadSFA) {
+        return mainloop_params.tma_load_sfa.get_tma_tensor(shape(filter_zeros(mainloop_params.layout_SFA)));
+      }
+      else {
+        return make_tensor(make_gmem_ptr(mainloop_params.ptr_SFA), mainloop_params.layout_SFA);  // (scale_m,k,l)
+      }
+    }();
+    auto mSFB_nkl = [&]() {
+      if constexpr (IsTmaLoadSFB) {
+        return mainloop_params.tma_load_sfb.get_tma_tensor(shape(filter_zeros(mainloop_params.layout_SFB)));
+      }
+      else {
+        return make_tensor(make_gmem_ptr(mainloop_params.ptr_SFB), mainloop_params.layout_SFB);  // (scale_n,k,l)
+      }
+    }();
 
     return cute::make_tuple(gA_mkl, gB_nkl, mSFA_mkl, mSFB_nkl);
   }
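Because the TMA and gmem branches yield different tensor types, each scale tensor is initialized with an immediately invoked lambda rather than an `if`/`else` assignment, whose variable would need a single fixed type. A minimal illustration of the pattern (types below are stand-ins):

```cpp
#include <string>
#include <type_traits>

// Stand-ins for the two tensor types the branches above would produce.
struct TmaCoordTensor { int id; };
struct GmemTensor     { std::string name; };

template <bool UseTma>
auto make_scale_tensor() {
  // The immediately invoked lambda lets one declaration take either type.
  auto t = [&]() {
    if constexpr (UseTma) { return TmaCoordTensor{42}; }
    else                  { return GmemTensor{"sfa"}; }
  }();
  return t;
}

static_assert(std::is_same_v<decltype(make_scale_tensor<true>()),  TmaCoordTensor>);
static_assert(std::is_same_v<decltype(make_scale_tensor<false>()), GmemTensor>);

int main() {}
```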
@@ -387,8 +492,8 @@ struct CollectiveMma<
     // Blockscaling: Tma loads for load_input and CpAsync for load_scale
     Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{});        // (BLK_M,BLK_K,PIPE)
     Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{});        // (BLK_N,BLK_K,PIPE)
-    Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), SmemLayoutSFA{});                // (BLK_M,BLK_K,PIPE)
-    Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), SmemLayoutSFB{});                // (BLK_M,BLK_K,PIPE)
+    Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), filter_zeros(SmemLayoutSFA{}));  // (ScaleMsPerTile,PIPE)
+    Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), filter_zeros(SmemLayoutSFB{}));  // (ScaleNsPerTile,PIPE)
 
     //
     // Prepare the TMA loads for A and B
@@ -399,6 +504,8 @@ struct CollectiveMma<
 
     Tensor gA_mkl = get<0>(load_inputs);
     Tensor gB_nkl = get<1>(load_inputs);
+    Tensor mSFA_mkl = get<2>(load_inputs);
+    Tensor mSFB_nkl = get<3>(load_inputs);
 
     auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
     auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
@@ -407,9 +514,120 @@ struct CollectiveMma<
+    auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
+    Tensor gA = gA_mkl(_,_,m_coord,_,l_coord);                                                   // (BLK_M,BLK_K,k)
+    Tensor gB = gB_nkl(_,_,n_coord,_,l_coord);                                                   // (BLK_N,BLK_K,k)
+    Tensor gSFA = local_tile(
+        mSFA_mkl, make_tile(Int<ScaleMsPerTile>{}, Int<1>{}),
+        make_coord(m_coord,_,l_coord));
+    Tensor gSFB = local_tile(
+        mSFB_nkl, make_tile(Int<ScaleNsPerTile>{}, Int<1>{}),
+        make_coord(n_coord,_,l_coord));
+
+    // Applies the mapping from block_tma_a
+    Tensor tAgA = block_tma_a.partition_S(gA);                                                   // (TMA,TMA_M,TMA_K,k)
+    Tensor tAsA = block_tma_a.partition_D(sA);                                                   // (TMA,TMA_M,TMA_K,PIPE)
+
+    Tensor tBgB = block_tma_b.partition_S(gB);                                                   // (TMA,TMA_N,TMA_K,k)
+    Tensor tBsB = block_tma_b.partition_D(sB);                                                   // (TMA,TMA_N,TMA_K,PIPE)
+
+    auto [tAgA_SFA, tAsA_SFA] = [&]() {
+      if constexpr (IsTmaLoadSFA) {
+        auto block_tma_sfa = mainloop_params.tma_load_sfa.get_slice(cluster_local_block_id.y);
+        Tensor tAgA_SFA_ = block_tma_sfa.partition_S(gSFA);
+        Tensor tAsA_SFA_ = block_tma_sfa.partition_D(sSFA);
+        return cute::make_tuple(tAgA_SFA_, tAsA_SFA_);
+      }
+      else {
+        return cute::make_tuple(0, 0);
+      }
+    }();
+    auto [tBgB_SFB, tBsB_SFB] = [&]() {
+      if constexpr (IsTmaLoadSFB) {
+        auto block_tma_sfb = mainloop_params.tma_load_sfb.get_slice(cluster_local_block_id.y);
+        Tensor tBgB_SFB_ = block_tma_sfb.partition_S(gSFB);
+        Tensor tBsB_SFB_ = block_tma_sfb.partition_D(sSFB);
+        return cute::make_tuple(tBgB_SFB_, tBsB_SFB_);
+      }
+      else {
+        return cute::make_tuple(0, 0);
+      }
+    }();
+
+    uint16_t mcast_mask_a = 0;
+    uint16_t mcast_mask_b = 0;
+    uint16_t mcast_mask_sf = 0;
+
+    // Issue TmaLoads for GEMM operands A/B and scale tensors
+    // Maps the tile -> block, value
+    if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
+      auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{};                       // (m,n) -> block_id
+      for (int n = 0; n < size<1>(block_layout); ++n) {
+        mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
+      }
+    }
+
+    if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
+      auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{};                       // (m,n) -> block_id
+      for (int m = 0; m < size<0>(block_layout); ++m) {
+        mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
+      }
+    }
+
+    // Mainloop
+    CUTLASS_PRAGMA_NO_UNROLL
+    for ( ; k_tile_count > 0; --k_tile_count) {
+      // LOCK smem_pipe_write for _writing_
+      pipeline.producer_acquire(smem_pipe_write);
+
+      //
+      // Copy gmem to smem for *k_tile_iter
+      //
+      int write_stage = smem_pipe_write.index();
+      using BarrierType = typename MainloopPipeline::ProducerBarrierType;
+      BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
+
+      // Copy operands A and B from global memory to shared memory
+      if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
+      if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
+
+      // Copy scale tensors from global memory to shared memory
+      if constexpr (IsTmaLoadSFA) {
+        if (lane_predicate) {
+          copy(mainloop_params.tma_load_sfa.with(*tma_barrier, mcast_mask_sf), tAgA_SFA(_,_,_,*k_tile_iter), tAsA_SFA(_,_,_,write_stage));
+        }
+      }
+      if constexpr (IsTmaLoadSFB) {
+        if (lane_predicate) {
+          copy(mainloop_params.tma_load_sfb.with(*tma_barrier, mcast_mask_sf), tBgB_SFB(_,_,_,*k_tile_iter), tBsB_SFB(_,_,_,write_stage));
+        }
+      }
+      ++k_tile_iter;
+
+      // Advance smem_pipe_write
+      ++smem_pipe_write;
+    }
+  }
+
+  template <
+    class TensorA, class TensorB,
+    class TensorScaleA, class TensorScaleB,
+    class KTileIterator, class BlockCoord
+  >
+  CUTLASS_DEVICE void
+  load_auxiliary(
+      Params const& mainloop_params,
+      MainloopPipeline pipeline,
+      PipelineState smem_pipe_write,
+      cute::tuple<TensorA, TensorB, TensorScaleA, TensorScaleB> const& load_inputs,
+      BlockCoord const& blk_coord,
+      KTileIterator k_tile_iter, int k_tile_count,
+      int thread_idx,
+      uint32_t block_rank_in_cluster,
+      TensorStorage& shared_tensors) {
     // Block scaling: load_scale has scaling tensors in global memory which are not tiled
     Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFA.data()), SmemLayoutSFA{});   // (ScaleMsPerTile,k)
     Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_SFB.data()), SmemLayoutSFB{});   // (ScaleNsPerTile,k)
 
     auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
 
     Tensor mSFA_mkl = get<2>(load_inputs);
     Tensor mSFB_nkl = get<3>(load_inputs);
 
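The multicast masks set one bit per participating CTA rank in the cluster: A-tiles are shared along the N mode, B-tiles along the M mode. A small host-side sketch of the same bit arithmetic (cluster shape and layout assumed for illustration):

```cpp
#include <cstdint>
#include <cstdio>

// Sketch of the multicast-mask construction, for an assumed 2x2 cluster laid
// out column-major: rank(m,n) = m + n * CLUSTER_M.
int main() {
  constexpr int CLUSTER_M = 2, CLUSTER_N = 2;
  int cta_m = 0, cta_n = 1;  // this CTA's coordinates within the cluster

  // A-tiles are multicast across N: every CTA in this row receives gA.
  uint16_t mcast_mask_a = 0;
  for (int n = 0; n < CLUSTER_N; ++n) {
    mcast_mask_a |= uint16_t(1) << (cta_m + n * CLUSTER_M);
  }
  // B-tiles are multicast across M: every CTA in this column receives gB.
  uint16_t mcast_mask_b = 0;
  for (int m = 0; m < CLUSTER_M; ++m) {
    mcast_mask_b |= uint16_t(1) << (m + cta_n * CLUSTER_M);
  }
  std::printf("mask_a = 0x%04x, mask_b = 0x%04x\n", mcast_mask_a, mcast_mask_b);  // 0x0005, 0x000c
}
```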
@@ -441,35 +659,9 @@ struct CollectiveMma<
     Tensor tSFBcSFB_k = thr_scale_copy_b.partition_S(cSFB_k);
     Tensor tSFBsSFB = thr_scale_copy_b.partition_D(sSFB);
 
-    // Applies the mapping from block_tma_a
-    Tensor tAgA = block_tma_a.partition_S(gA);                                                   // (TMA,TMA_M,TMA_K,k)
-    Tensor tAsA = block_tma_a.partition_D(sA);                                                   // (TMA,TMA_M,TMA_K,PIPE)
-
-    Tensor tBgB = block_tma_b.partition_S(gB);                                                   // (TMA,TMA_N,TMA_K,k)
-    Tensor tBsB = block_tma_b.partition_D(sB);                                                   // (TMA,TMA_N,TMA_K,PIPE)
-
     Tensor tSFApSFA = make_tensor<bool>(shape(filter_zeros(tSFAsSFA(_,_,_,_0{}))));             // (CPY,CPY_M,CPY_K)
     Tensor tSFBpSFB = make_tensor<bool>(shape(filter_zeros(tSFBsSFB(_,_,_,_0{}))));             // (CPY,CPY_N,CPY_K)
 
-    uint16_t mcast_mask_a = 0;
-    uint16_t mcast_mask_b = 0;
-
-    // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
-    // Maps the tile -> block, value
-    if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
-      auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{};                      // (m,n) -> block_id
-      for (int n = 0; n < size<1>(block_layout); ++n) {
-        mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
-      }
-    }
-
-    if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
-      auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{};                      // (m,n) -> block_id
-      for (int m = 0; m < size<0>(block_layout); ++m) {
-        mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
-      }
-    }
-
     auto SFA_shape = shape(mainloop_params.layout_SFA);
     auto SFB_shape = shape(mainloop_params.layout_SFB);
 
@@ -480,9 +672,9 @@ struct CollectiveMma<
       pipeline.producer_acquire(smem_pipe_write);
 
       // Since scale granularity K is multiple of BLK_K we do not have to consider if that is OOB
-      bool load_sfa = thread_idx < ScaleMsPerTile;
       Tensor tSFAcSFA = tSFAcSFA_k(_,_,_,*k_tile_iter);
       Tensor tSFAcSFA_compact = filter_zeros(tSFAcSFA);
+      bool load_sfa = thread_idx < ScaleMsPerTile;
       CUTLASS_PRAGMA_UNROLL
       for (int i = 0; i < size(tSFApSFA); ++i) {
         tSFApSFA(i) = load_sfa && elem_less(get<0>(tSFAcSFA_compact(i)), get<0>(SFA_shape));
@@ -495,22 +687,17 @@ struct CollectiveMma<
       for (int i = 0; i < size(tSFBpSFB); ++i) {
         tSFBpSFB(i) = load_sfb && elem_less(get<0>(tSFBcSFB_compact(i)), get<0>(SFB_shape));
       }
 
-      //
-      // Copy gmem to smem for *k_tile_iter
-      //
       int write_stage = smem_pipe_write.index();
-      using BarrierType = typename MainloopPipeline::ProducerBarrierType;
-      BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
-
-      // Copy operands A and B from global memory to shared memory
-      if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
-      if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
-
       // Copy scale tensors from global memory to shared memory
-      copy_if(scale_copy_a, tSFApSFA, filter_zeros(tSFAgSFA_k(_,_,_,*k_tile_iter)), filter_zeros(tSFAsSFA(_,_,_,write_stage)));
-      copy_if(scale_copy_b, tSFBpSFB, filter_zeros(tSFBgSFB_k(_,_,_,*k_tile_iter)), filter_zeros(tSFBsSFB(_,_,_,write_stage)));
-      pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
+      if constexpr (!IsTmaLoadSFA) {
+        copy_if(scale_copy_a, tSFApSFA, filter_zeros(tSFAgSFA_k(_,_,_,*k_tile_iter)), filter_zeros(tSFAsSFA(_,_,_,write_stage)));
+      }
+      if constexpr (!IsTmaLoadSFB) {
+        copy_if(scale_copy_b, tSFBpSFB, filter_zeros(tSFBgSFB_k(_,_,_,*k_tile_iter)), filter_zeros(tSFBsSFB(_,_,_,write_stage)));
+      }
+      if constexpr (!IsTmaLoadSFA || !IsTmaLoadSFB) {
+        pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
+      }
 
       ++k_tile_iter;
 
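For intuition, the predicate tensor built above gates each `cp.async` element on two conditions: the thread owns a scale slot, and the slot's global coordinate is in bounds (the last tile may be partial). A plain-C++ analogue of that predicated copy (all sizes are assumptions for illustration):

```cpp
#include <array>
#include <cstdio>

// Plain-C++ analogue of the predicated scale copy: one warp (32 threads)
// covers ScaleMsPerTile slots, and the tail of the problem may be partial.
int main() {
  constexpr int ScaleMsPerTile = 8;
  int scale_m_total = 6;  // only 6 of the 8 slots are in bounds for this tile
  std::array<float, ScaleMsPerTile> gmem{{1, 2, 3, 4, 5, 6, 0, 0}};
  std::array<float, ScaleMsPerTile> smem{};

  for (int thread_idx = 0; thread_idx < 32; ++thread_idx) {
    bool load_sfa  = thread_idx < ScaleMsPerTile;  // analogue of `thread_idx < ScaleMsPerTile`
    bool in_bounds = thread_idx < scale_m_total;   // analogue of elem_less(coord, shape)
    if (load_sfa && in_bounds) {                   // analogue of copy_if's predicate tensor
      smem[thread_idx] = gmem[thread_idx];
    }
  }
  std::printf("smem[5]=%g smem[6]=%g\n", smem[5], smem[6]);  // 6 and 0
}
```

The conditional `producer_commit` below the copies reflects who signals the barrier: TMA loads count their bytes against the barrier's expect-tx automatically, while `cp.async` producers must arrive explicitly, so the commit is compiled in only when at least one scale tensor still rides on `cp.async`.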
@@ -669,100 +856,32 @@ struct CollectiveMma<
     Tensor tCrSFB = make_tensor_like<ElementBlockScale>(tCsSFB(_, _, _, _0{}));                 // (MMA,MMA_M,MMA_N)
 
     // Prologue GMMAs
-    int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
-
-    tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
-
-    // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
-    auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
-    pipeline.consumer_wait(smem_pipe_read, barrier_token);
+    GmmaFP8Accumulation accumulation(accum, ScalePromotionInterval, size<2>(tCrA));
+    {
+      // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
+      auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
+      pipeline.consumer_wait(smem_pipe_read, barrier_token);
+
+      if constexpr (ScalePromotionInterval != 4) {
+        if (accumulation.prepare_if_needed()) {
+          tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
+        }
+      }
+      else {
+        // Always zero out the accumulator for finest granularity
+        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
+      }
 
       int read_stage = smem_pipe_read.index();
 
       // Load per block scale values from shared memory to registers
       copy(tCsSFA(_,_,_,make_coord(_0{},read_stage)), tCrSFA);
       copy(tCsSFB(_,_,_,make_coord(_0{},read_stage)), tCrSFB);
 
       if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
         tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{});
       }
       if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
         ElementBlockScale scale_b = tCrSFB(_0{});
         CUTLASS_PRAGMA_UNROLL
         for (int i = 0; i < size(filter_zeros(tCrSFA)); i++) {
           filter_zeros(tCrSFA)(i) = filter_zeros(tCrSFA)(i) * scale_b;
         }
       }
       if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
         ElementBlockScale scale_a = tCrSFA(_0{});
         CUTLASS_PRAGMA_UNROLL
         for (int i = 0; i < size(filter_zeros(tCrSFB)); i++) {
           filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a;
         }
       }
 
       warpgroup_arrive();
       // Unroll the K mode manually to set scale D to 1
       CUTLASS_PRAGMA_UNROLL
       for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
         // (V,M,K) x (V,N,K) => (V,M,N)
         cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
         tiled_mma.accumulate_ = GMMA::ScaleOut::One;
       }
       warpgroup_commit_batch();
 
+      // Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB`
+      if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
+        ElementBlockScale scale_ab = tCrSFA(_0{});
+        scale_if_needed(accumulation, scale_ab);
+      }
+      if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
+        scale_if_needed(accumulation, tCrSFA);
+      }
+      if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
+        scale_if_needed(accumulation, tCrSFB);
+      }
+      if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
+        scale_if_needed(accumulation, tCrSFA, tCrSFB);
+      }
+
+      ++smem_pipe_read;
+    }
     warpgroup_fence_operand(accumulation());
-    CUTLASS_PRAGMA_UNROLL
-    for (int k_tile_prologue = prologue_mma_count - 1; k_tile_prologue > 0; --k_tile_prologue)
-    {
-      // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
-      auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
-      pipeline.consumer_wait(smem_pipe_read, barrier_token);
-
-      if constexpr (ScalePromotionInterval != 4) {
-        if (accumulation.prepare_if_needed()) {
-          tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
-        }
-      }
-      else {
-        // Always zero out the accumulator for finest granularity
-        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
-      }
-
-      int read_stage = smem_pipe_read.index();
-
-      // Load per block scale values from shared memory to registers
-      copy(tCsSFA(_,_,_,make_coord(_0{},read_stage)), tCrSFA);
-      copy(tCsSFB(_,_,_,make_coord(_0{},read_stage)), tCrSFB);
+      copy(tCsSFA(_,_,_,make_coord(_0{}, read_stage)), tCrSFA);
+      copy(tCsSFB(_,_,_,make_coord(_0{}, read_stage)), tCrSFB);
-
-      warpgroup_fence_operand(accumulation());
-      warpgroup_arrive();
-      // Unroll the K mode manually to set scale D to 1
-      CUTLASS_PRAGMA_UNROLL
-      for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
-        // (V,M) x (V,N) => (V,M,N)
-        cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
-        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
-      }
-      warpgroup_commit_batch();
-      warpgroup_fence_operand(accumulation());
-
-      if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
-        tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{});
@@ -781,17 +900,9 @@ struct CollectiveMma<
           filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a;
         }
       }
 
-      warpgroup_arrive();
-      // Unroll the K mode manually to set scale D to 1
-      CUTLASS_PRAGMA_UNROLL
-      for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
-        // (V,M,K) x (V,N,K) => (V,M,N)
-        cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
-        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
-      }
-      warpgroup_commit_batch();
-
-      warpgroup_wait<0>();
-      ++smem_pipe_read;
-      barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
+      // Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB`
+      if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
+        ElementBlockScale scale_ab = tCrSFA(_0{});
@@ -806,19 +917,15 @@ struct CollectiveMma<
       if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
         scale_if_needed(accumulation, tCrSFA, tCrSFB);
       }
 
-      ++smem_pipe_read;
-    }
-
     warpgroup_fence_operand(accumulation());
     // Mainloop
-    k_tile_count -= prologue_mma_count;
+    k_tile_count -= 1;
 
     CUTLASS_PRAGMA_NO_UNROLL
-    for ( ; k_tile_count > 0; --k_tile_count)
+    for ( ; k_tile_count > 1; --k_tile_count)
     {
       // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
       auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
       pipeline.consumer_wait(smem_pipe_read, barrier_token);
 
       //
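Peeling the first and last k-tiles out of the mainloop (hence `k_tile_count -= 1` and the `> 1` bound) is a standard software-pipelining move: the first tile initializes the accumulator, and the last tile must not release a pipeline stage it is still reading. A shape-only sketch of that structure (wait/compute/release are stand-ins, not the pipeline API):

```cpp
#include <cstdio>

// Shape-only sketch of the peeled pipeline structure; not a real implementation.
void wait_tile(int t)    { std::printf("wait %d\n", t); }
void compute_tile(int t) { std::printf("mma  %d\n", t); }
void release_tile(int t) { std::printf("free %d\n", t); }

int main() {
  int k_tile_count = 4;
  int t = 0;
  wait_tile(t); compute_tile(t); ++t;          // peeled first tile: nothing to release yet
  k_tile_count -= 1;
  for (; k_tile_count > 1; --k_tile_count) {   // steady state: release lags compute by one
    wait_tile(t); compute_tile(t);
    release_tile(t - 1);
    ++t;
  }
  wait_tile(t); compute_tile(t);               // peeled last tile
  release_tile(t - 1); release_tile(t);        // drain the remaining stages
}
```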
@@ -828,8 +935,32 @@ struct CollectiveMma<
       int read_stage = smem_pipe_read.index();
 
       // Load per block scale values from shared memory to registers (at most twice per block along M and/or N)
-      copy(tCsSFA(_,_,_,make_coord(_0{},read_stage)), tCrSFA);
-      copy(tCsSFB(_,_,_,make_coord(_0{},read_stage)), tCrSFB);
+      copy(tCsSFA(_,_,_,make_coord(_0{}, read_stage)), tCrSFA);
+      copy(tCsSFB(_,_,_,make_coord(_0{}, read_stage)), tCrSFB);
+
+      if constexpr (ScalePromotionInterval != 4) {
+        if (accumulation.prepare_if_needed()) {
+          tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
+        }
+      }
+      else {
+        // Always zero out the accumulator for finest granularity
+        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
+      }
+
+      warpgroup_fence_operand(accumulation());
+      warpgroup_arrive();
+      // Unroll the K mode manually to set scale D to 1
+      CUTLASS_PRAGMA_UNROLL
+      for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
+        // (V,M) x (V,N) => (V,M,N)
+        cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
+        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
+      }
+      warpgroup_commit_batch();
 
-      /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
+      warpgroup_fence_operand(accumulation());
 
       if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
         tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{});
@@ -848,7 +979,40 @@ struct CollectiveMma<
           filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a;
         }
       }
       warpgroup_wait<0>();
       pipeline.consumer_release(smem_pipe_release);  // Unlock previous tile
       ++smem_pipe_read;
-      barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
+
+      // Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB`
+      if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
+        ElementBlockScale scale_ab = tCrSFA(_0{});
+        scale_if_needed(accumulation, scale_ab);
+      }
+      if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
+        scale_if_needed(accumulation, tCrSFA);
+      }
+      if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
+        scale_if_needed(accumulation, tCrSFB);
+      }
+      if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
+        scale_if_needed(accumulation, tCrSFA, tCrSFB);
+      }
+
+      // Advance smem_pipe_read and smem_pipe_release
+      ++smem_pipe_release;
+    }
+    {
+      pipeline.consumer_wait(smem_pipe_read, barrier_token);
+
+      //
+      // Compute on k_tile
+      //
+
+      int read_stage = smem_pipe_read.index();
+
+      // Load per block scale values from shared memory to registers (at most twice per block along M and/or N)
+      copy(tCsSFA(_,_,_,make_coord(_0{}, read_stage)), tCrSFA);
+      copy(tCsSFB(_,_,_,make_coord(_0{}, read_stage)), tCrSFB);
+
+      if constexpr (ScalePromotionInterval != 4) {
+        if (accumulation.prepare_if_needed()) {
@@ -865,16 +1029,34 @@ struct CollectiveMma<
       // Unroll the K mode manually to set scale D to 1
       CUTLASS_PRAGMA_UNROLL
       for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
-        // (V,M,K) x (V,N,K) => (V,M,N)
+        // (V,M) x (V,N) => (V,M,N)
         cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
         tiled_mma.accumulate_ = GMMA::ScaleOut::One;
       }
       warpgroup_commit_batch();
 
-      /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
-      warpgroup_wait<K_PIPE_MMAS>();
       warpgroup_fence_operand(accumulation());
 
       if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
         tCrSFA(_0{}) = tCrSFA(_0{}) * tCrSFB(_0{});
       }
+      if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
+        ElementBlockScale scale_b = tCrSFB(_0{});
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < size(filter_zeros(tCrSFA)); i++) {
+          filter_zeros(tCrSFA)(i) = filter_zeros(tCrSFA)(i) * scale_b;
+        }
+      }
+      if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
+        ElementBlockScale scale_a = tCrSFA(_0{});
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < size(filter_zeros(tCrSFB)); i++) {
+          filter_zeros(tCrSFB)(i) = filter_zeros(tCrSFB)(i) * scale_a;
+        }
+      }
+      warpgroup_wait<0>();
+      pipeline.consumer_release(smem_pipe_release);  // Unlock previous tile
+
+      // Block scale the accumulators with reg tensor `tCrSFA` and `tCrSFB`
+      if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
+        ElementBlockScale scale_ab = tCrSFA(_0{});
@@ -889,28 +1071,21 @@ struct CollectiveMma<
       if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
         scale_if_needed(accumulation, tCrSFA, tCrSFB);
       }
 
-      pipeline.consumer_release(smem_pipe_release);  // UNLOCK smem_pipe_release, done _computing_ on it
-
       // Advance smem_pipe_read and smem_pipe_release
-      ++smem_pipe_read;
       ++smem_pipe_release;
     }
 
+    if constexpr (ScalePromotionInterval != 4) {
+      // residue only exists when the granularity is not the finest
       if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) {
         ElementBlockScale scale_ab = tCrSFA(_0{});
-        scale_if_needed(accumulation, scale_ab);
+        accumulation.scale_residue_if_needed(scale_ab);
       }
       if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) {
-        scale_if_needed(accumulation, tCrSFA);
+        accumulation.scale_residue_if_needed(tCrSFA);
       }
       if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) {
-        scale_if_needed(accumulation, tCrSFB);
+        accumulation.scale_residue_if_needed(tCrSFB);
       }
       if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) {
-        scale_if_needed(accumulation, tCrSFA, tCrSFB);
+        accumulation.scale_residue_if_needed(tCrSFA, tCrSFB);
       }
+    }
 
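With a promotion interval other than 4, the fp8 accumulator is promoted every `ScalePromotionInterval` k-blocks, so a trailing group of fewer blocks (the "residue") is still sitting unscaled when the mainloop exits; `scale_residue_if_needed` folds that tail in. A scalar sketch of the bookkeeping (numbers are illustrative):

```cpp
#include <cstdio>

// Scalar sketch of interval-based scale promotion with a residue, mirroring
// the role of prepare_if_needed / scale_if_needed / scale_residue_if_needed.
int main() {
  const int k_blocks = 10, interval = 4;
  double scale = 0.5, final_acc = 0.0, temp_acc = 0.0;
  int in_flight = 0;

  for (int k = 0; k < k_blocks; ++k) {
    temp_acc += 1.0;                    // stand-in for one k-block of MMAs
    if (++in_flight == interval) {      // promotion point: scale and fold
      final_acc += temp_acc * scale;
      temp_acc = 0.0; in_flight = 0;
    }
  }
  if (in_flight > 0) {                  // residue: 10 % 4 = 2 blocks remain
    final_acc += temp_acc * scale;
  }
  std::printf("final = %g (expect %g)\n", final_acc, k_blocks * scale);
}
```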
@@ -920,18 +1095,9 @@ struct CollectiveMma<
   /// Perform a Consumer Epilogue to release all buffers
   CUTLASS_DEVICE void
   mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
-    // Prologue GMMAs
-    int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
-    k_tile_count -= prologue_mma_count;
-
-    smem_pipe_release.advance(k_tile_count);
     // Wait on all GMMAs to complete
     warpgroup_wait<0>();
-
-    for (int count = 0; count < prologue_mma_count; ++count) {
-      pipeline.consumer_release(smem_pipe_release);  // UNLOCK smem_pipe_release, done _computing_ on it
-      ++smem_pipe_release;
-    }
+    // The pipeline is not released in the first iteration
+    smem_pipe_release.advance(k_tile_count - 1);
+    pipeline.consumer_release(smem_pipe_release);
   }
 };