v4.2 tag release. (#2638)

This commit is contained in:
Junkai-Wu
2025-09-16 00:21:53 +08:00
committed by GitHub
parent 56f0718a97
commit 6a35b4d22f
161 changed files with 14056 additions and 3793 deletions

View File

@ -539,7 +539,8 @@ make_cotiled_copy(Copy_Atom<Args...> const& copy_atom,
auto layout_tv_data = composition(inv_data_layout, atom_tv_layout);
// Check validity
CUTE_STATIC_ASSERT_V(coalesce(composition(data_layout, layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)),
// Append 1:0 to data_layout so that OOB coordinates map to stride-0
CUTE_STATIC_ASSERT_V(coalesce(composition(make_layout(data_layout, Layout<_1,_0>{}), layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)),
"The memory pointed to by AtomTVLayout does not exist in the DataLayout.");
//
// Tiler -- Find the active elements in the DATA tensor and generate a tiler to extract them
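For reference on the change above: appending a 1:0 mode to data_layout gives out-of-bounds coordinates somewhere to land that contributes nothing to the offset. A minimal host-side CuTe sketch, with assumed shapes that are not part of the diff:

#include <cute/tensor.hpp>
using namespace cute;

int main() {
  // An 8x4 column-major data layout: (8,4):(1,8)
  auto data_layout = make_layout(make_shape(Int<8>{}, Int<4>{}));
  // Concatenate a 1:0 mode: ((8,4),1):((1,8),0). Coordinates routed into the
  // appended mode add offset 0, so OOB positions resolve to stride-0 instead of
  // addressing memory outside data_layout.
  auto padded = make_layout(data_layout, Layout<_1,_0>{});
  print(padded); print("\n");
}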

View File

@ -705,7 +705,7 @@ public:
auto smem_tiled_copy_S = cute::get<0>(partitioned_transform_extra_info);
auto&& scales = cute::get<1>(partitioned_transform_extra_info);
using ScaleType = decltype(scales);
auto tSrS = make_tensor(static_cast<ScaleType&&>(scales).data(), scales.layout());
auto tSrS = make_tensor(scales.data(), scales.layout());
auto tSsS = cute::get<2>(partitioned_transform_extra_info);
copy(smem_tiled_copy_S, tSsS(_,_,_,_,load2transform_consumer_index), tSrS);
@ -714,7 +714,7 @@ public:
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
auto&& zeros = cute::get<3>(partitioned_transform_extra_info);
using ZeroType = decltype(zeros);
auto tZrZ = make_tensor(static_cast<ZeroType&&>(zeros).data(), zeros.layout());
auto tZrZ = make_tensor(zeros.data(), zeros.layout());
auto tZsZ = cute::get<4>(partitioned_transform_extra_info);
copy(smem_tiled_copy_S, tZsZ(_,_,_,_,load2transform_consumer_index), tZrZ);
@ -1002,7 +1002,6 @@ public:
auto src_arr = recast<SrcArray>(src);
auto dst_arr = recast<DstArray>(dst);
Tensor src_vm = cute::group_modes<1,-1>(cute::zipped_divide(src, pack));
Tensor dst_vm = cute::group_modes<1,-1>(cute::zipped_divide(dst, pack));
cute::transform(src_arr, dst_arr, Converter::convert);
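Context for the pack loops in this hunk: zipped_divide by a pack shape followed by group_modes<1,-1> puts one conversion pack in mode 0 and enumerates the packs in mode 1, so tensor(_, i) is the i-th pack. A minimal sketch with assumed sizes:

#include <cute/tensor.hpp>
using namespace cute;

CUTE_HOST_DEVICE void pack_iteration_sketch() {
  // A 32-element register tensor split into packs of 8 (assumed sizes).
  Tensor t    = make_tensor<float>(make_shape(Int<32>{}));
  Tensor t_vm = group_modes<1,-1>(zipped_divide(t, Shape<_8>{}));  // ((8),(4))
  for (int i = 0; i < size<1>(t_vm); ++i) {
    auto pack_i = t_vm(_, i);  // the i-th 8-element pack, one converter call's worth
    (void) pack_i;
  }
}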
@ -1019,7 +1018,6 @@ public:
auto scale_arr = recast<ScaleArray>(filter_zeros(scales));
if constexpr (is_same_v<DstType, cutlass::bfloat16_t>){
Tensor dst_vm = cute::group_modes<1,-1>(cute::zipped_divide(dst, pack));
Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, pack));
for (int i = 0; i < size<1>(dst_vm); ++i){
@ -1194,13 +1192,7 @@ public:
Tensor tCsS = cta_mma.partition_A(sS);
Tensor tSsS = smem_thr_copy_S.partition_S(tCsS);
Tensor tSrS = make_tensor<ElementScale>(tSsS(_,_,_,_,0).shape());
#if 0
if(cute::thread(128, 0)){
print("sS: ");print(sS);print("\n");
print("tSsS: ");print(tSsS);print("\n");
print("tSrS: ");print(tSrS);print("\n");
}
#endif
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(smem_tiled_copy_S, tSrS, tSsS);
}
@ -1209,16 +1201,6 @@ public:
Tensor tCsZ = cta_mma.partition_A(sZ);
Tensor tZsZ = smem_thr_copy_S.partition_S(tCsZ);
Tensor tZrZ = make_tensor<ElementZero>(tZsZ(_,_,_,_,0).shape());
#if 0
if(cute::thread(128, 0)){
print("sS: ");print(sS);print("\n");
print("tSsS: ");print(tSsS);print("\n");
print("tSrS: ");print(tSrS);print("\n");
print("sZ: ");print(sZ);print("\n");
print("tZsZ: ");print(tZsZ);print("\n");
print("tZrZ: ");print(tZrZ);print("\n");
}
#endif
return cute::make_tuple(smem_tiled_copy_S, tSrS, tSsS, tZrZ, tZsZ);
}
else {

View File

@ -582,7 +582,13 @@ sm100_sparse_get_tma_dispatch_policy() {
* Selected op also maximizes the TMEM_LOAD shape in order to minimize TMEM_LOADs issued,
* subject to the constraint of the provided per-warp tmem subpartition shape
**/
template<class GmemStrideTypeD, class ElementAccumulator, class ElementD, class TmemShape_MN, class FusionOp>
template<
class GmemStrideTypeD,
class ElementAccumulator,
class ElementD,
class TmemShape_MN,
bool IsBlockScaleSupported
>
constexpr auto
sm100_get_tmem_load_op() {
using namespace cute;
@ -958,6 +964,172 @@ struct CallbacksBuilder<
>;
};
// Attempts to compute a reasonably performant epilogue tile or allows the user to provide one.
template<
class OpClass,
class CtaTileShape_MNK,
class EpilogueTileType,
class TmemWarpShape_MN,
class ElementC_,
class GmemStrideTypeC,
class ElementD,
class GmemStrideTypeD,
bool IsPerColScaleSupported
>
static constexpr auto
sm100_dense_compute_tile_shape_or_override() {
using namespace cute;
static_assert(!cute::is_same_v<OpClass, arch::OpClassSparseTensorOp> && !cute::is_same_v<OpClass, arch::OpClassBlockScaledSparseTensorOp>);
constexpr bool DisableSource = cute::is_void_v<ElementC_>;
using ElementC = cute::conditional_t<DisableSource, ElementD, ElementC_>;
if constexpr (is_same_v<OpClass, arch::OpClassBlockScaledTensorOp> &&
is_same_v<EpilogueTileType, EpilogueTileAuto> &&
size<1>(CtaTileShape_MNK{}) == 256) {
constexpr int CtaM = size<0>(CtaTileShape_MNK{});
constexpr int WarpM = size<0>(TmemWarpShape_MN{});
constexpr int DpFull = 32;
constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load
// Note:
// Set Epi_Tile_N to 128 to support OverlappingAccum for the largest tile.
// This is a generally workable epi_tile_N; it does not promise the best perf.
return make_tile(Int<M>{}, Int<128>{});
}
else if constexpr (is_same_v<EpilogueTileType, EpilogueTileAuto>) {
constexpr int CtaM = size<0>(CtaTileShape_MNK{});
constexpr int CtaN = size<1>(CtaTileShape_MNK{});
constexpr int WarpM = size<0>(TmemWarpShape_MN{});
constexpr int WarpN = size<1>(TmemWarpShape_MN{});
constexpr int MaxBits = cute::max(sizeof_bits_v<ElementC>, sizeof_bits_v<ElementD>);
constexpr int DpFull = 32; // tmem datapaths in 1 subpartition
constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load
constexpr int N_perf = [&]() constexpr { // Known subtile sizes tested for perf
// Epilogues w/o residual load are less sensitive to smem allocation
// Target a fixed amount of compute per epilogue iteration
if (DisableSource) {
if (MaxBits == 4) {
// Make epilogue tile larger to reduce the epilogue iterations.
// 64 is the experimentally chosen value: it minimizes epilogue iterations while keeping the number of A/B buffers the same.
constexpr int ComputeElts = 8192;
return ComputeElts / M;
}
constexpr int ComputeElts = 4096;
return ComputeElts / M;
}
// Epilogues w/ residual load are more sensitive to smem allocation
// Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize
else {
if (MaxBits == 32) {
return (CtaM > 64 && CtaN <= 128) ? 16 : 32;
}
// Per-column scaling is high register pressure, reduce tile to prevent spills
else if (IsPerColScaleSupported) {
return 32;
}
else if (MaxBits == 16) {
return (CtaN <= 128) ? 32 : 64;
}
else {
return 64;
}
}
}();
constexpr int N_min_C = (DisableSource || detail::is_m_major<GmemStrideTypeC>()) ? 8 * WarpN
: (sizeof_bits_v<ElementC> == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type
: 128 / sizeof_bits_v<ElementC> * WarpN;
constexpr int N_min_D = (detail::is_m_major<GmemStrideTypeD>()) ? 8 * WarpN
: (sizeof_bits_v<ElementD> == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type
: 128 / sizeof_bits_v<ElementD> * WarpN;
constexpr int N = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D));
static_assert(CtaN >= N_min_C && CtaN >= N_min_D, "CTA tile too small");
// stride by tmem warp layout and return a by-mode tiler
auto tile_m = Layout<Int<M>>{};
auto tile_n = Layout<Shape <Int<N / WarpN>,Int< WarpN>>,
Stride<Int< 1>,Int<CtaN / WarpN>>>{};
return make_tile(tile_m, coalesce(tile_n));
}
else {
static_assert(cute::is_tuple<EpilogueTileType>::value && not is_layout<EpilogueTileType>::value,
"EpilogueTile must be a cute::Tile or cute::Shape");
EpilogueTileType epi_tile;
constexpr int M = size<0>(shape(epi_tile));
constexpr int N = size<1>(shape(epi_tile));
static_assert(N % 8 == 0, "Unsupported tile shape");
return epi_tile;
}
}
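A worked reading of the EpilogueTileAuto branch above, under assumed parameters (CTA tile 128x128, TmemWarpShape (4,1), no C source, 16-bit row-major D); the constants mirror the expressions in the function and are not taken from the diff:

constexpr int CtaM = 128, CtaN = 128, WarpM = 4, WarpN = 1;   // assumed configuration
constexpr int MaxBits = 16;                                   // half_t D, C source disabled
constexpr int M       = cute::min(CtaM, 32 * WarpM);          // 128: full 32-datapath TMEM load
constexpr int N_perf  = 4096 / M;                             // 32: fixed compute per epilogue iteration
constexpr int N_min_C = 8 * WarpN;                            // 8: no C source, warp-layout minimum only
constexpr int N_min_D = 128 / MaxBits * WarpN;                // 8: TMA-store minimum for 16-bit D
constexpr int N       = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D));  // 32
// -> the epilogue tile resolves to (128, 32) for this configuration.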
template<
bool Is2SmMma,
class MmaTileShape_MNK
>
static constexpr auto
sm100_tmem_warps() {
if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) {
return Shape<_2,_2>{};
}
else {
return Shape<_4,_1>{};
}
}
template<
bool Is2SmMma,
class MmaTileShape_MNK
>
static constexpr auto
sm100_cta_tile_shape() {
if constexpr (Is2SmMma) { // 2x1 threadblock shape
auto [mma_tile_m, mma_tile_n, mma_tile_k] = MmaTileShape_MNK{};
auto cta_tile_m = reverse(shape_div(reverse(mma_tile_m), _2{})); // first MmaTile_M/2 elements, preserve multimode
return make_shape(cta_tile_m, mma_tile_n, mma_tile_k);
}
else { // 1x1 threadblock shape
return MmaTileShape_MNK{};
}
}
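For example (assumed shapes, hypothetical call site): a 2SM MMA tile of (256,128,64) maps to a per-CTA tile of (128,128,64), since each CTA owns half of the MMA M extent.

using CtaTile = decltype(detail::sm100_cta_tile_shape</*Is2SmMma=*/true, Shape<_256,_128,_64>>());
static_assert(cute::size<0>(CtaTile{}) == 128 && cute::size<1>(CtaTile{}) == 128, "half of MMA M per CTA");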
template<
class EpilogueScheduleType,
class ElementC_,
class ElementD,
int EpiTiles,
int FragmentSize
>
static constexpr auto
sm100_dense_dispatch_policy() {
// 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation
constexpr bool ReuseSmem = sizeof_bits_v<ElementC_> > 8;
// TMA store delay performs worse with residual loads
constexpr bool DelayTmaStore = is_void_v<ElementC_>;
constexpr int StagesD = cute::min(EpiTiles, 2);
constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1)
: cute::min(EpiTiles, 4);
if constexpr (is_base_of_v<PtrArrayNoSmemWarpSpecialized1Sm, EpilogueScheduleType> ||
is_base_of_v<PtrArrayNoSmemWarpSpecialized2Sm, EpilogueScheduleType>) {
return Sm100PtrArrayNoSmemWarpSpecialized{};
}
else if constexpr (is_base_of_v<NoSmemWarpSpecialized1Sm, EpilogueScheduleType> || is_base_of_v<NoSmemWarpSpecialized2Sm, EpilogueScheduleType>) {
return Sm100NoSmemWarpSpecialized{};
}
else if constexpr (is_same_v<EpilogueScheduleType, PtrArrayTmaWarpSpecialized1Sm> ||
is_same_v<EpilogueScheduleType, PtrArrayTmaWarpSpecialized2Sm>) {
constexpr bool DelayTmaStore_ = false; // TMA store delay complicates tensormap updates for Ptr-Array GEMMs
return Sm100PtrArrayTmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore_>{};
}
else {
return Sm100TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore>{};
}
}
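A concrete reading of the staging logic above, with assumed values: EpiTiles = 4 and a 16-bit C residual enable smem reuse, giving

constexpr int EpiTiles = 4;                                               // assumed
constexpr int StagesD  = cute::min(EpiTiles, 2);                          // 2
constexpr int StagesC  = cute::max(cute::min(EpiTiles, 4), StagesD + 1);  // 4
// A TMA warp-specialized schedule then selects
// Sm100TmaWarpSpecialized<StagesC=4, StagesD=2, FragmentSize, ReuseSmem=true, DelayTmaStore=false>.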
// Helper for building TMA warp-specialized collective epilogues, specialized by
// the fusion operation performed and the dispatch policy to use.
template <
@ -1017,17 +1189,7 @@ private:
}
}
using CtaTileShape_MNK = decltype(cta_tile_shape());
static constexpr auto
tmem_warps() {
if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) {
return Shape<_2,_2>{};
}
else {
return Shape<_4,_1>{};
}
}
using TmemWarpShape_MN = decltype(tmem_warps());
using TmemWarpShape_MN = decltype(detail::sm100_tmem_warps<Is2SmMma, MmaTileShape_MNK>());
// Attempts to compute a reasonably performant epilogue tile or allows the user to provide one.
static constexpr auto
@ -1041,84 +1203,10 @@ private:
ElementC_, GmemStrideTypeC, ElementD, GmemStrideTypeD, Schedule,
FusionOp>();
}
else if constexpr (is_same_v<OpClass, arch::OpClassBlockScaledTensorOp> &&
is_same_v<EpilogueTileType, EpilogueTileAuto> &&
size<1>(CtaTileShape_MNK{}) == 256) {
constexpr int CtaM = size<0>(CtaTileShape_MNK{});
constexpr int WarpM = size<0>(TmemWarpShape_MN{});
constexpr int DpFull = 32;
constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load
// Note:
// Set Epi_Tile_N to 128 to support OverlappingAccum for the largest tile.
// This is a generally workable epi_tile_N; it does not promise the best perf.
return make_tile(Int<M>{}, Int<128>{});
}
else if constexpr (is_same_v<EpilogueTileType, EpilogueTileAuto>) {
constexpr int CtaM = size<0>(CtaTileShape_MNK{});
constexpr int CtaN = size<1>(CtaTileShape_MNK{});
constexpr int WarpM = size<0>(TmemWarpShape_MN{});
constexpr int WarpN = size<1>(TmemWarpShape_MN{});
constexpr int MaxBits = cute::max(sizeof_bits_v<ElementC>, sizeof_bits_v<ElementD>);
constexpr int DpFull = 32; // tmem datapaths in 1 subpartition
constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load
constexpr int N_perf = [&]() constexpr { // Known subtile sizes tested for perf
// Epilogues w/o residual load are less sensitive to smem allocation
// Target a fixed amount of compute per epilogue iteration
if (DisableSource) {
if (MaxBits == 4) {
// Make epilogue tile larger to reduce the epilogue iterations.
// 64 is the experimentally chosen value: it minimizes epilogue iterations while keeping the number of A/B buffers the same.
constexpr int ComputeElts = 8192;
return ComputeElts / M;
}
constexpr int ComputeElts = 4096;
return ComputeElts / M;
}
// Epilogues w/ residual load are more sensitive to smem allocation
// Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize
else {
if (MaxBits == 32) {
return (CtaM > 64 && CtaN <= 128) ? 16 : 32;
}
// Per-column scaling is high register pressure, reduce tile to prevent spills
else if (FusionOp::IsPerColScaleSupported) {
return 32;
}
else if (MaxBits == 16) {
return (CtaN <= 128) ? 32 : 64;
}
else {
return 64;
}
}
}();
constexpr int N_min_C = (DisableSource || detail::is_m_major<GmemStrideTypeC>()) ? 8 * WarpN
: (sizeof_bits_v<ElementC> == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type
: 128 / sizeof_bits_v<ElementC> * WarpN;
constexpr int N_min_D = (detail::is_m_major<GmemStrideTypeD>()) ? 8 * WarpN
: (sizeof_bits_v<ElementD> == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type
: 128 / sizeof_bits_v<ElementD> * WarpN;
constexpr int N = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D));
static_assert(CtaN >= N_min_C && CtaN >= N_min_D, "CTA tile too small");
// stride by tmem warp layout and return a by-mode tiler
auto tile_m = Layout<Int<M>>{};
auto tile_n = Layout<Shape <Int<N / WarpN>,Int< WarpN>>,
Stride<Int< 1>,Int<CtaN / WarpN>>>{};
return make_tile(tile_m, coalesce(tile_n));
}
else {
static_assert(cute::is_tuple<EpilogueTileType>::value && not is_layout<EpilogueTileType>::value,
"EpilogueTile must be a cute::Tile or cute::Shape");
EpilogueTileType epi_tile;
constexpr int M = size<0>(shape(epi_tile));
constexpr int N = size<1>(shape(epi_tile));
static_assert(N % 8 == 0, "Unsupported tile shape");
return epi_tile;
return sm100_dense_compute_tile_shape_or_override<
OpClass, CtaTileShape_MNK, EpilogueTileType, TmemWarpShape_MN,
ElementC_, GmemStrideTypeC, ElementD, GmemStrideTypeD, FusionOp::IsPerColScaleSupported>();
}
}
using EpilogueTile_MN = decltype(epilogue_tile());
@ -1129,30 +1217,18 @@ private:
using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTileShape_MN{}, TmemWarpShape_MN{}));
using AccLoadOp = decltype(detail::sm100_get_tmem_load_op<
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>());
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN,
FusionOp::IsBlockScaleSupported
>());
static constexpr auto
dispatch_policy() {
// 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation
constexpr bool ReuseSmem = sizeof_bits_v<ElementC_> > 8;
// TMA store delay performs worse with residual loads
constexpr bool DelayTmaStore = is_void_v<ElementC_>;
constexpr int StagesD = cute::min(EpiTiles, 2);
constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1)
: cute::min(EpiTiles, 4);
if constexpr (is_same_v<OpClass, arch::OpClassSparseTensorOp> ||
is_same_v<OpClass, arch::OpClassBlockScaledSparseTensorOp>) {
return detail::sparse::sm100_sparse_get_tma_dispatch_policy<CtaTileShape_MNK, EpilogueTile_MN, ElementC_, ElementD, Schedule>();
}
else if constexpr (is_same_v<Schedule, PtrArrayTmaWarpSpecialized1Sm> ||
is_same_v<Schedule, PtrArrayTmaWarpSpecialized2Sm>) {
constexpr bool DelayTmaStore_ = false; // TMA store delay complicates tensormap updates for Ptr-Array GEMMs
return Sm100PtrArrayTmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore_>{};
}
else {
return Sm100TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore>{};
return detail::sm100_dense_dispatch_policy<Schedule, ElementC_, ElementD, EpiTiles, FragmentSize>();
}
}
@ -1228,6 +1304,87 @@ public:
>;
};
template<
class OpClass,
class MmaTileShape_MNK,
class EpilogueTileType,
class ElementAccumulator_,
class ElementC,
class ElementD,
class Schedule,
class GmemStrideTypeC,
class GmemStrideTypeD,
bool IsPerColScaleSupported,
bool IsBlockScaleSupported
>
struct Sm100EpilogueDescriptor {
using ElementAccumulator = ElementAccumulator_;
static constexpr bool Is2SmMma = is_base_of_v<TmaWarpSpecialized2Sm, Schedule> || is_base_of_v<NoSmemWarpSpecialized2Sm, Schedule>;
using CtaTileShape_MNK = decltype(sm100_cta_tile_shape<Is2SmMma, MmaTileShape_MNK>());
using TileShape = CtaTileShape_MNK;
using TmemWarpShape_MN = decltype(detail::sm100_tmem_warps<Is2SmMma, MmaTileShape_MNK>());
using EpilogueTile = decltype(
sm100_dense_compute_tile_shape_or_override<OpClass, CtaTileShape_MNK, EpilogueTileType,
TmemWarpShape_MN, ElementC, GmemStrideTypeC, ElementD, GmemStrideTypeD, IsPerColScaleSupported>()
);
using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile{})));
static constexpr int EpiTiles = size(shape_div(take<0,2>(CtaTileShape_MNK{}), EpilogueTileShape_MN{}));
static constexpr int FragmentSize = size(EpilogueTileShape_MN{}) / NumThreadsPerWarpGroup;
using DispatchPolicy = decltype(sm100_dense_dispatch_policy<Schedule, ElementC, ElementD, EpiTiles, FragmentSize>());
constexpr static int StagesC = DispatchPolicy::StagesC;
constexpr static int StagesD = DispatchPolicy::StagesD;
using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTileShape_MN{}, TmemWarpShape_MN{}));
using AccLoadOp = decltype(detail::sm100_get_tmem_load_op<
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN,
IsBlockScaleSupported
>());
};
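A hypothetical instantiation of the descriptor; every template argument below (including the unqualified schedule and tile-type names) is assumed for illustration and not taken from the diff:

using ExampleDesc = detail::Sm100EpilogueDescriptor<
    cutlass::arch::OpClassTensorOp,      // OpClass
    Shape<_128,_128,_64>,                // MmaTileShape_MNK
    EpilogueTileAuto,                    // EpilogueTileType
    float,                               // ElementAccumulator
    cutlass::half_t,                     // ElementC
    cutlass::half_t,                     // ElementD
    TmaWarpSpecialized1Sm,               // Schedule
    cutlass::detail::TagToStrideC_t<cutlass::layout::RowMajor>,  // GmemStrideTypeC
    cutlass::detail::TagToStrideC_t<cutlass::layout::RowMajor>,  // GmemStrideTypeD
    /*IsPerColScaleSupported=*/ false,
    /*IsBlockScaleSupported=*/  false>;
// Downstream code can then query ExampleDesc::EpilogueTile, ExampleDesc::StagesC,
// ExampleDesc::StagesD, and ExampleDesc::AccLoadOp without the full collective builder.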
// Get Stride, SmemLayout, and CopyOpS2R for AuxLoad node
template<
typename EpilogueDescriptor,
typename StrideOrLayoutTag,
typename ElementAux
>
struct Sm100AuxLoadDescriptor {
constexpr static int Stages = EpilogueDescriptor::StagesC;
using EpilogueTile = typename EpilogueDescriptor::EpilogueTile;
using Element = ElementAux;
using Stride = cutlass::detail::TagToStrideC_t<StrideOrLayoutTag>;
using SmemLayoutAtom = decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom<
Stride, ElementAux, EpilogueTile>());
using CopyOpS2R = decltype(detail::sm100_get_smem_load_op<
Stride, ElementAux, typename EpilogueDescriptor::ElementAccumulator, typename EpilogueDescriptor::AccLoadOp>());
};
// Get Stride, SmemLayout, and CopyOpS2R for AuxStore node
template<
typename EpilogueDescriptor,
typename StrideOrLayoutTag,
typename ElementAux
>
struct Sm100AuxStoreDescriptor {
constexpr static int Stages = EpilogueDescriptor::StagesD;
using EpilogueTile = typename EpilogueDescriptor::EpilogueTile;
using Element = ElementAux;
using Stride = cutlass::detail::TagToStrideC_t<StrideOrLayoutTag>;
using SmemLayoutAtom = decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom<
Stride, ElementAux, EpilogueTile>());
using CopyOpR2S = decltype(detail::sm100_get_smem_store_op<
Stride, ElementAux, typename EpilogueDescriptor::ElementAccumulator, typename EpilogueDescriptor::AccLoadOp>());
};
} // namespace detail
///////////////////////////////////////////////////////////////////////////////
@ -1304,17 +1461,7 @@ private:
}
}
using CtaTileShape_MNK = decltype(cta_tile_shape());
static constexpr auto
tmem_warps() {
if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) {
return Shape<_2,_2>{};
}
else {
return Shape<_4,_1>{};
}
}
using TmemWarpShape_MN = decltype(tmem_warps());
using TmemWarpShape_MN = decltype(detail::sm100_tmem_warps<Is2SmMma, MmaTileShape_MNK>());
static constexpr auto
epilogue_tile() {
@ -1338,20 +1485,15 @@ private:
using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTile{}, TmemWarpShape_MN{}));
using AccLoadOp = decltype(detail::sm100_get_tmem_load_op<
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>());
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN,
FusionOp::IsBlockScaleSupported
>());
static constexpr int FragmentSize = size(EpilogueTile{}) / NumThreadsPerWarpGroup;
static constexpr auto
dispatch_policy() {
if constexpr (std::is_base_of_v<PtrArrayNoSmemWarpSpecialized1Sm, EpilogueScheduleType> ||
std::is_base_of_v<PtrArrayNoSmemWarpSpecialized2Sm, EpilogueScheduleType>) {
return Sm100PtrArrayNoSmemWarpSpecialized{};
}
else {
return Sm100NoSmemWarpSpecialized{};
}
}
using DispatchPolicy = decltype(dispatch_policy());
using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile{})));
static constexpr int EpiTiles = size(shape_div(take<0,2>(CtaTileShape_MNK{}), EpilogueTileShape_MN{}));
using DispatchPolicy = decltype(detail::sm100_dense_dispatch_policy<EpilogueScheduleType, ElementC_, ElementD, EpiTiles, FragmentSize>());
static constexpr auto
fusion_callbacks() {

View File

@ -507,8 +507,7 @@ public:
int thread_idx,
TensorStorage& shared_tensors,
TensorMapC const& load_tensormap,
int subtile_idx=-1,
bool wait_until_load_finishes = false) {
int subtile_idx=-1) {
using namespace cute;
// Indexing variables
@ -595,12 +594,6 @@ public:
// Post-loop fusion callback entry point
pld_callbacks.end();
if (wait_until_load_finishes && did_load) {
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state =
{last_load_producer_state.index(), !last_load_producer_state.phase(), last_load_producer_state.count()};
load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state);
}
return load_pipe_producer_state;
}

View File

@ -0,0 +1,274 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/gemm/collective/builders/sm100_common.inl"
#include "cutlass/gemm/collective/collective_builder_decl.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
namespace detail {
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template <
int CapacityBytes,
class ElementA,
class ElementB,
class TileShapeMNK,
class TileShapeSFA,
class TileShapeSFB,
int stages
>
constexpr int
sm100_compute_stage_count_or_override_blockscaled_mixed_tma_cpasync(StageCount<stages> stage_count) {
return stages;
}
template <
int CapacityBytes,
class ElementA,
class ElementB,
class TileShapeMNK,
class TileShapeSFA,
class TileShapeSFB,
int carveout_bytes
>
constexpr int
sm100_compute_stage_count_or_override_blockscaled_mixed_tma_cpasync(StageCountAutoCarveout<carveout_bytes> stage_count) {
// For MXF8F6F4 MMA, ElementA/B will be passed in as uint8_t
// Each stage includes (CollectiveMma::SharedStorage):
// 1. smem for A and smem for B (CollectiveMma::SharedStorage::TensorStorage)
// 2. the mainloop pipelines, PipelineTmaUmmaAsync and PipelineUmmaConsumerAsync (CollectiveMma::SharedStorage::PipelineStorage)
// 3. smem for SFA and smem for SFB (CollectiveMma::SharedStorage::TensorStorage, independent of input size because sizeof(sf) is fixed)
constexpr auto mainloop_pipeline_bytes =
sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage) +
sizeof(typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage);
constexpr auto a_bits = cute::sizeof_bits_v<ElementA>;
constexpr auto b_bits = cute::sizeof_bits_v<ElementB>;
constexpr auto stage_sfa_bytes = size(filter_zeros(TileShapeSFA{}));
constexpr auto stage_sfb_bytes = size(filter_zeros(TileShapeSFB{}));
constexpr int stage_bytes =
cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
static_cast<int>(mainloop_pipeline_bytes + stage_sfa_bytes + stage_sfb_bytes);
return (CapacityBytes - carveout_bytes) / stage_bytes;
}
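Rough arithmetic for the carveout path above, with assumed numbers that are not part of the diff:

// Assumed example: 128x128x128 tile, 8-bit A/B operands in smem.
constexpr int a_bytes        = 128 * 128 * 8 / 8;      // 16384
constexpr int b_bytes        = 128 * 128 * 8 / 8;      // 16384
constexpr int sf_and_pipe    = 1024;                   // SFA/SFB + pipeline storage, assumed order of magnitude
constexpr int stage_bytes    = a_bytes + b_bytes + sf_and_pipe;
constexpr int capacity_bytes = 200 * 1024;             // usable smem after epilogue carveout, assumed
constexpr int stages         = capacity_bytes / stage_bytes;   // ~6 with these numbers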
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
class ElementPairA,
class GmemLayoutATag,
int AlignmentA,
class ElementPairB,
class GmemLayoutBTag,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
class BuilderScheduleTag
>
struct CollectiveBuilder<
arch::Sm100,
arch::OpClassBlockScaledTensorOp,
ElementPairA,
GmemLayoutATag,
AlignmentA,
ElementPairB,
GmemLayoutBTag,
AlignmentB,
ElementAccumulator,
TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK)
ClusterShape_MNK, // Static cluster shape (_1, _1, _1)
StageCountType,
BuilderScheduleTag,
cute::enable_if_t<cute::is_same_v<KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100, BuilderScheduleTag> >
>
{
using ElementSFA = typename detail::blockscaled::blockscaled_type<BuilderScheduleTag, ElementPairA>::sf_type;
using ElementSFB = typename detail::blockscaled::blockscaled_type<BuilderScheduleTag, ElementPairB>::sf_type;
using ElementA = typename detail::blockscaled::blockscaled_type<BuilderScheduleTag, ElementPairA>::data_type;
using ElementB = typename detail::blockscaled::blockscaled_type<BuilderScheduleTag, ElementPairB>::data_type;
using ElementSF = ElementSFA;
static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A<GmemLayoutATag>();
static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B<GmemLayoutBTag>();
static_assert(cute::is_static_v<TileShape_MNK>, "TileShape has to be static");
static_assert(detail::blockscaled::check_input_datatypes<BuilderScheduleTag, ElementPairA, ElementPairB, UmmaMajorA, UmmaMajorB>(), "Incorrect input types");
static constexpr bool is_2sm = false; // detail::blockscaled::is_2sm<TileShape_MNK, ClusterShape_MNK, BuilderScheduleTag>();
static constexpr auto Instr = detail::blockscaled::select_instr<ElementPairA, ElementPairB, ElementAccumulator, UmmaMajorA, UmmaMajorB, BuilderScheduleTag>();
using TiledMma = typename cutlass::gemm::collective::detail::TrivialBlockscaledMma<ElementPairA, ElementPairB, ElementAccumulator,
TileShape_MNK, ClusterShape_MNK,
UmmaMajorA, UmmaMajorB, Instr, BuilderScheduleTag, is_2sm>::type;
static constexpr bool UseMxf8f6f4 = Instr == detail::blockscaled::BlockScaledInstr::MXF4F6F8;
static_assert(UseMxf8f6f4 || (cutlass::gemm::detail::is_k_major_A<GmemLayoutATag>() && cutlass::gemm::detail::is_k_major_B<GmemLayoutBTag>()), "Only MMA.MXF8F6F4 supports non-K major inputs");
// Data type used by MMA instruction
using ElementAMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element<ElementA, UseMxf8f6f4>());
using ElementBMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element<ElementB, UseMxf8f6f4>());
static_assert(detail::sm1xx_gemm_check_for_f8f6f4_mix8bit_requirement<ElementAMma, ElementBMma,
TileShape_MNK, ClusterShape_MNK,
GmemLayoutATag, GmemLayoutBTag, false /*is_sparse*/, is_2sm>(),
"TileSize and MNK Major does not met with MMA Mix 8-bit TMA load requirement" );
static constexpr uint32_t SFVectorSize = TiledMma::SFVecSize;
using ElementAMma_SmemAllocType = cute::conditional_t<UseMxf8f6f4, uint8_t, ElementAMma>;
using ElementBMma_SmemAllocType = cute::conditional_t<UseMxf8f6f4, uint8_t, ElementBMma>;
// using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma<
// ElementAMma, ElementBMma, ElementAccumulator,
// decltype(cute::product_each(TileShape_MNK{})), ClusterShape_MNK,
// UmmaMajorA, UmmaMajorB, BuilderScheduleTag>());
using AtomThrID = typename TiledMma::AtomThrID;
using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig<SFVectorSize>;
// ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K)
using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}),
cute::size<2>(TileShape_MNK{}))));
// ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K)
using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}),
cute::size<2>(TileShape_MNK{}))));
// Assigning 4 warps for mainloop load of B
static constexpr int NumLoadThreadsCpAsync = 128;
using SmemShapeA_M = decltype(shape_div(shape<0>(TileShape_MNK{}), shape_div(shape<0>(TileShape_MNK{}), size<0>(TileShape_MNK{}) / size(AtomThrID{}))));
using SmemShapeA_K = decltype(cute::get<2>(TileShape_MNK{}));
using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A(
ClusterShape_MNK{}, AtomThrID{}));
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
UmmaMajorA, ElementAMma_SmemAllocType, SmemShapeA_M, SmemShapeA_K>());
using AlignmentTypeB = cute::uint_byte_t<static_cast<int>(cutlass::sizeof_bits<ElementB>::value) * AlignmentB / 8>;
using GmemCopyAtomB = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<AlignmentTypeB>, ElementB>;
using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy<
GmemCopyAtomB, NumLoadThreadsCpAsync, AlignmentB, TagToStrideB_t<GmemLayoutBTag>,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
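// For example (assumed types, not from the diff): ElementB = float_e4m3_t (8 bits)
// with AlignmentB = 16 gives AlignmentTypeB = uint128_t, i.e. 16-byte cp.async ZFILL
// transfers, and the 128 cp.async threads are tiled over the (TileN, TileK) extent of B.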
using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{}));
using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{}));
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
UmmaMajorB, ElementBMma_SmemAllocType, BlockTileB_N, BlockTileB_K>());
using GmemTiledCopySFA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A(
ClusterShape_MNK{}, AtomThrID{}));
using GmemTiledCopySFB = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_SFB(
ClusterShape_MNK{}, AtomThrID{}));
using GmemTiledCopyPairA = decltype(cute::make_tuple(GmemTiledCopyA{}, GmemTiledCopySFA{}));
using GmemTiledCopyPairB = decltype(cute::make_tuple(GmemTiledCopyB{}, GmemTiledCopySFB{}));
using SmemLayoutAtomSFA = decltype(Sm1xxBlkScaledConfig::deduce_smem_layoutSFA(TiledMma{}, TileShape_MNK{}));
using SmemLayoutAtomSFB = decltype(Sm1xxBlkScaledConfig::deduce_smem_layoutSFB(TiledMma{}, TileShape_MNK{}));
using SmemLayoutAtomsA = decltype(cute::make_tuple(SmemLayoutAtomA{}, SmemLayoutAtomSFA{}));
using SmemLayoutAtomsB = decltype(cute::make_tuple(SmemLayoutAtomB{}, SmemLayoutAtomSFB{}));
using StrideA = cutlass::gemm::TagToStrideA_t<GmemLayoutATag>;
using StrideB = cutlass::gemm::TagToStrideB_t<GmemLayoutBTag>;
using InternalStrideA = cute::remove_pointer_t<StrideA>;
using InternalStrideB = cute::remove_pointer_t<StrideB>;
using InternalLayoutSFA = decltype(Sm1xxBlkScaledConfig::deduce_layoutSFA());
using InternalLayoutSFB = decltype(Sm1xxBlkScaledConfig::deduce_layoutSFB());
using LayoutSFA = cute::conditional_t<cute::is_same_v<InternalStrideA, StrideA>, InternalLayoutSFA, InternalLayoutSFA *>;
using LayoutSFB = cute::conditional_t<cute::is_same_v<InternalStrideB, StrideB>, InternalLayoutSFB, InternalLayoutSFB *>;
using StridePairA = decltype(cute::make_tuple(StrideA{}, LayoutSFA{}));
using StridePairB = decltype(cute::make_tuple(StrideB{}, LayoutSFB{}));
static constexpr uint32_t AccumulatorPipelineStageCount = 2;
// Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding.
static constexpr uint32_t SchedulerPipelineStageCount = AccumulatorPipelineStageCount + 1;
// AccumulatorPipeline = PipelineUmmaAsync
static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount>::SharedStorage);
// CLCPipeline = PipelineCLCFetchAsync
static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, ClusterShape_MNK>::SharedStorage);
// CLC (scheduler) response
static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize;
// Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage
static constexpr auto KernelSmemCarveout = static_cast<int>( AccumulatorPipelineStorage +
CLCPipelineStorage +
CLCResponseStorage);
// Reduce SMEM capacity available for buffers considering barrier allocations.
static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout;
using SmemTileShape = cute::Shape<SmemShapeA_M, BlockTileB_N, SmemShapeA_K>;
static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled_mixed_tma_cpasync<
Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{});
static_assert(PipelineStages > 0, "Smem usage is too high. Can't create any SMEM buffers for A, B, SFA, and SFB.");
using CollectiveOp = cutlass::gemm::collective::CollectiveMma<
cutlass::gemm::MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled<
PipelineStages,
SchedulerPipelineStageCount,
AccumulatorPipelineStageCount,
ClusterShape_MNK>,
TileShape_MNK,
cute::tuple<ElementA, ElementSF>,
StridePairA,
cute::tuple<ElementB, ElementSF>,
StridePairB,
TiledMma,
GmemTiledCopyPairA,
SmemLayoutAtomsA,
void,
cute::identity,
GmemTiledCopyPairB,
SmemLayoutAtomsB,
void,
cute::identity
>;
};
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -120,6 +120,7 @@ struct CollectiveBuilder<
BuilderScheduleTag,
cute::enable_if_t<
// Blockscaled Gemm
(not cute::is_same_v<KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100, BuilderScheduleTag>) &&
(cute::is_base_of_v<KernelScheduleBlockScaledGemmSm100, BuilderScheduleTag> ||
cute::is_same_v<KernelScheduleAuto, BuilderScheduleTag>)
&&

View File

@ -0,0 +1,171 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/gemm/collective/builders/sm100_common.inl"
#include "cutlass/gemm/collective/collective_builder_decl.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
class ElementA,
class GmemLayoutATag,
int AlignmentA,
class ElementB,
class GmemLayoutBTag,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
class BuilderScheduleTag
>
struct CollectiveBuilder<
arch::Sm100,
arch::OpClassTensorOp,
ElementA,
GmemLayoutATag,
AlignmentA,
ElementB,
GmemLayoutBTag,
AlignmentB,
ElementAccumulator,
TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK)
ClusterShape_MNK, // Static cluster shape (_1, _1, _1)
StageCountType,
BuilderScheduleTag,
cute::enable_if_t<cute::is_same_v<KernelMixedTmaCpAsyncWarpSpecialized1SmSm100, BuilderScheduleTag> >
>
{
static_assert(cute::is_static_v<TileShape_MNK>, "TileShape has to be static");
static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A<GmemLayoutATag>();
static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B<GmemLayoutBTag>();
// Data type used by MMA instruction
using ElementAMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element<ElementA>());
using ElementBMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element<ElementB>());
using ElementAMma_SmemAllocType = cute::conditional_t<cute::sizeof_bits_v<ElementAMma> < 8, uint8_t, ElementAMma>;
using ElementBMma_SmemAllocType = cute::conditional_t<cute::sizeof_bits_v<ElementBMma> < 8, uint8_t, ElementBMma>;
using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma<
ElementAMma, ElementBMma, ElementAccumulator,
decltype(cute::product_each(TileShape_MNK{})), ClusterShape_MNK,
UmmaMajorA, UmmaMajorB, BuilderScheduleTag>());
using AtomThrID = typename TiledMma::AtomThrID;
// ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K)
using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}),
cute::size<2>(TileShape_MNK{}))));
// ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K)
using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}),
cute::size<2>(TileShape_MNK{}))));
// Assigning 4 warps for mainloop load of B
static constexpr int NumLoadThreadsCpAsync = 128;
using SmemShapeA_M = decltype(shape_div(shape<0>(TileShape_MNK{}), shape_div(shape<0>(TileShape_MNK{}), size<0>(TileShape_MNK{}) / size(AtomThrID{}))));
using SmemShapeA_K = decltype(cute::get<2>(TileShape_MNK{}));
using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A(
ClusterShape_MNK{}, AtomThrID{}));
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
UmmaMajorA, ElementAMma_SmemAllocType, SmemShapeA_M, SmemShapeA_K>());
using AlignmentTypeB = cute::uint_byte_t<static_cast<int>(sizeof(ElementB)) * AlignmentB>;
using GmemCopyAtomB = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<AlignmentTypeB>, ElementB>;
using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy<
GmemCopyAtomB, NumLoadThreadsCpAsync, AlignmentB, TagToStrideB_t<GmemLayoutBTag>,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{}));
using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{}));
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
UmmaMajorB, ElementBMma_SmemAllocType, BlockTileB_N, BlockTileB_K>());
static constexpr uint32_t AccumulatorPipelineStageCount = 2;
// Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding.
static constexpr uint32_t SchedulerPipelineStageCount = AccumulatorPipelineStageCount + 1;
// AccumulatorPipeline = PipelineUmmaAsync
static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount>::SharedStorage);
// CLCPipeline = PipelineCLCFetchAsync
static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, ClusterShape_MNK>::SharedStorage);
// CLC (scheduler) response
static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize;
// Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage
static constexpr auto KernelSmemCarveout = static_cast<int>( AccumulatorPipelineStorage +
CLCPipelineStorage +
CLCResponseStorage);
// Reduce SMEM capacity available for buffers considering barrier allocations.
static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout;
using SmemTileShape = cute::Shape<SmemShapeA_M, BlockTileB_N, SmemShapeA_K>;
using MainloopPipelineStorage = typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage;
static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override<
Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, MainloopPipelineStorage>(StageCountType{});
using CollectiveOp = cutlass::gemm::collective::CollectiveMma<
cutlass::gemm::MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized<
PipelineStages,
SchedulerPipelineStageCount,
AccumulatorPipelineStageCount,
ClusterShape_MNK>,
TileShape_MNK,
ElementA,
cutlass::gemm::TagToStrideA_t<GmemLayoutATag>,
ElementB,
cutlass::gemm::TagToStrideB_t<GmemLayoutBTag>,
TiledMma,
GmemTiledCopyA,
SmemLayoutAtomA,
void,
cute::identity,
GmemTiledCopyB,
SmemLayoutAtomB,
void,
cute::identity
>;
};
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -184,6 +184,7 @@ struct CollectiveBuilder<
not cute::is_complex_v<ElementA> && not cute::is_complex_v<ElementB> &&
// Dense Gemm / PtrArrayDenseGemm
(
(not cute::is_same_v<KernelMixedTmaCpAsyncWarpSpecialized1SmSm100, BuilderScheduleTag>) &&
(not cute::is_same_v<KernelWarpSpecialized1SmSm100, BuilderScheduleTag>) &&
(cute::is_base_of_v<KernelScheduleSm100DenseGemm, BuilderScheduleTag> ||
cute::is_same_v<KernelScheduleAuto, BuilderScheduleTag>)) &&

View File

@ -502,6 +502,7 @@ check_input_datatypes() {
|| (cute::is_same_v<BuilderScheduleTag, KernelScheduleBlockScaledGemmSm100>)
|| (cute::is_same_v<BuilderScheduleTag, KernelTmaWarpSpecialized1SmBlockScaledSm100>)
|| (cute::is_same_v<BuilderScheduleTag, KernelTmaWarpSpecialized2SmBlockScaledSm100>)
|| (cute::is_same_v<BuilderScheduleTag, KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100>)
// SM100 BS ptr_array
|| (cute::is_same_v<BuilderScheduleTag, KernelSchedulePtrArrayBlockScaledGemmSm100>)
|| (cute::is_same_v<BuilderScheduleTag, KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100>)
@ -578,6 +579,8 @@ check_input_datatypes() {
((SfVectorSizeA == 32 && cute::is_same_v<KernelScheduleAuto, BuilderScheduleTag>)
|| (SfVectorSizeA == 32 && cute::is_same_v<KernelTmaWarpSpecializedPingpong, BuilderScheduleTag>)
|| (SfVectorSizeA == 32 && cute::is_same_v<KernelTmaWarpSpecializedCooperative, BuilderScheduleTag>)
|| (SfVectorSizeA == 32 && cute::is_same_v<KernelPtrArrayTmaWarpSpecializedPingpong, BuilderScheduleTag>)
|| (SfVectorSizeA == 32 && cute::is_same_v<KernelPtrArrayTmaWarpSpecializedCooperative, BuilderScheduleTag>)
|| (SfVectorSizeA == 32 && cute::is_base_of_v<KernelScheduleBlockScaledGemmSm100, BuilderScheduleTag>)
|| (SfVectorSizeA == 32 && cute::is_base_of_v<KernelSchedulePtrArrayBlockScaledGemmSm100, BuilderScheduleTag>)
|| (SfVectorSizeA == 64 && cute::is_base_of_v<KernelScheduleBlockScaledSparseGemmSm100, BuilderScheduleTag>)

View File

@ -1069,10 +1069,10 @@ struct CollectiveBuilder<
StageCountType,
KernelScheduleType,
cute::enable_if_t<
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum> or
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum> or
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum> or
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum>) and
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8Blockwise> or
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise> or
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8Blockwise> or
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise>) and
not detail::is_use_rmem_A<ElementA, GmemLayoutPairA, ElementB, GmemLayoutPairB>()
>
> {
@ -1105,7 +1105,7 @@ struct CollectiveBuilder<
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, KernelScheduleType>);
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static_assert(IsFP8Input, "Warp Specialized gemm with FP8 BlockScaled Accumulator is only compatible with FP8 Blocked Scaled version right now.");
static_assert(IsFP8Input, "Warp Specialized gemm with FP8 Blockwise (Software) Scaling is only compatible with FP8 inputs version right now.");
// For fp32 types, map to tf32 MMA value type
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
@ -1146,8 +1146,8 @@ struct CollectiveBuilder<
static constexpr int PipelineStages = detail::compute_stage_count_with_blockwise_scale<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
ElementAMma, ElementBMma, ElementBlockScale, TileShape_MNK, ScaleMsPerTile, ScaleNsPerTile>(StageCountType{});
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;

View File

@ -49,6 +49,8 @@
#include "cutlass/gemm/collective/builders/sm100_simt_builder.inl"
#include "cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl"
#include "cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl"
#include "cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl"
#include "cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl"
#include "cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl"
#include "cutlass/gemm/collective/builders/sm120_mma_builder.inl"
#include "cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl"

View File

@ -65,6 +65,8 @@
#include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp"
#include "cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp"
#include "cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp"
#include "cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp"
#include "cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp"
#include "cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp"
#include "cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp"
#include "cutlass/gemm/collective/sm120_mma_tma.hpp"

View File

@ -0,0 +1,758 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/detail/cluster.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/trace.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/arch/memory.h"
#include "cute/algorithm/functional.hpp"
#include "cute/arch/cluster_sm90.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cutlass/gemm/collective/collective_mma_decl.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one
template <
int Stages,
int SchedulerPipelineStageCount,
int AccumulatorPipelineStageCount,
class ClusterShape, // Static cluster shape or dynamic (int, int, _1)
class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK)
class ElementA_,
class StrideA_,
class ElementB_,
class StrideB_,
class TiledMma_,
class GmemTiledCopyA_,
class SmemLayoutAtomA_,
class SmemCopyAtomA_,
class TransformA_,
class GmemTiledCopyB_,
class SmemLayoutAtomB_,
class SmemCopyAtomB_,
class TransformB_>
struct CollectiveMma<
MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized<
Stages,
SchedulerPipelineStageCount,
AccumulatorPipelineStageCount,
ClusterShape>,
TileShape_,
ElementA_,
StrideA_,
ElementB_,
StrideB_,
TiledMma_,
GmemTiledCopyA_,
SmemLayoutAtomA_,
SmemCopyAtomA_,
TransformA_,
GmemTiledCopyB_,
SmemLayoutAtomB_,
SmemCopyAtomB_,
TransformB_>
{
using TiledMma = TiledMma_;
using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMma::ThrLayoutVMNK{})), _1, _1>;
// Statically asserting to ensure only 1x1x1 cluster shape & 1sm setup is received
static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment SM100 GEMM only supports 1SM MMA");
static_assert(size(ClusterShape{}) == 1, "CPASYNC does not support multicast, so the cluster shape is restricted to (1,1,1)");
static_assert(size(typename TiledMma::AtomThrID{}) == 1);
using DispatchPolicy = MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized<
Stages,
SchedulerPipelineStageCount,
AccumulatorPipelineStageCount,
ClusterShape>;
// TileShape refers to MmaTileShape to adapt for runtime cluster
using TileShape = TileShape_;
CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})),
"Static cluster shape used: TileShape should be evenly divided by TiledMma");
// Define A and B block shapes
using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{}))));
using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{}))));
// using LoadShapeA_MK = decltype(select<0,2>(TileShape{}));
using LoadShapeB_NK = decltype(select<1,2>(TileShape{}));
// CtaShape_MNK is queried from collective in all kernel layers
using CtaShape_MNK = TileShape;
using ElementA = ElementA_;
using ElementAMma = typename TiledMma::ValTypeA;
using StrideA = StrideA_;
using ElementB = ElementB_;
using ElementBMma = typename TiledMma::ValTypeB;
using StrideB = StrideB_;
static constexpr bool IsRuntimeDataTypeA = cute::is_same_v<ElementA, cutlass::type_erased_dynamic_float8_t>;
static constexpr bool IsRuntimeDataTypeB = cute::is_same_v<ElementB, cutlass::type_erased_dynamic_float8_t>;
static_assert(IsRuntimeDataTypeA == IsRuntimeDataTypeB,
"ElementA and ElementB should be both runtime or both static.");
static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
using MainloopPipelineTMA = cutlass::PipelineTmaUmmaAsync<DispatchPolicy::Stages, ClusterShape, AtomThrShapeMNK>;
using MainloopPipelineTMAState = typename MainloopPipelineTMA::PipelineState;
using MainloopPipelineCpAsync = cutlass::PipelineUmmaConsumerAsync<DispatchPolicy::Stages, AtomThrShapeMNK>;
using MainloopPipelineCpAsyncState = typename MainloopPipelineCpAsync::PipelineState;
// static_assert(size(GmemTiledCopyA{}) == size(GmemTiledCopyB{}), "A and B GmemTiledCopy should share the same thread count");
static constexpr int NumLoadThreadsCpAsync = size(GmemTiledCopyB{});
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)");
static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM100 UMMA cannot have a non-void copy atom for smem sourced instructions.");
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)");
static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM100 UMMA cannot have a non-void copy atom for smem sourced instructions.");
// Tile along K mode first before tiling over MN. PIPE mode last as usual.
// ((MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE)
using SmemLayoutA = decltype(UMMA::tile_to_mma_shape(
SmemLayoutAtomA{},
append(MmaShapeA_MK{}, Int<DispatchPolicy::Stages>{}),
conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
using MmaSmemLayoutB = decltype(UMMA::tile_to_mma_shape(
SmemLayoutAtomB{},
append(MmaShapeB_NK{}, Int<DispatchPolicy::Stages>{}),
conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
using LoadSmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{},
append(LoadShapeB_NK{}, Int<DispatchPolicy::Stages>{}),
conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
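// A needs only one smem layout because TMA writes it directly in the shape the UMMA descriptor
// expects. B keeps two views of the same buffer: LoadSmemLayoutB is the plain (BLK_N,BLK_K,PIPE)
// tiling targeted by the cp.async stores, while MmaSmemLayoutB re-tiles that buffer into the
// (MMA,MMA_N,MMA_K,PIPE) shape consumed by the MMA descriptors.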
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
static_assert(cute::is_base_of<cute::UMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
cute::is_base_of<cute::UMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
using TmaInternalElementA = cute::conditional_t<cute::is_same_v<ElementA, float>, cutlass::tfloat32_t, ElementAMma>;
using SmemAllocTypeA = cute::conditional_t<cute::sizeof_bits_v<ElementAMma> < 8, uint8_t, ElementAMma>;
using SmemAllocTypeB = cute::conditional_t<cute::sizeof_bits_v<ElementBMma> < 8, uint8_t, ElementBMma>;
using BitTypeElementA = cute::uint_bit_t<cute::sizeof_bits_v<ElementA>>;
using BitTypeElementB = cute::uint_bit_t<cute::sizeof_bits_v<ElementB>>;
using ArrayElementA = cute::conditional_t<IsRuntimeDataTypeA, BitTypeElementA, ElementA>;
using ArrayElementB = cute::conditional_t<IsRuntimeDataTypeB, BitTypeElementB, ElementB>;
using RuntimeDataTypeA = cute::conditional_t<IsRuntimeDataTypeA, cute::UMMA::MXF8F6F4Format, void*>;
using RuntimeDataTypeB = cute::conditional_t<IsRuntimeDataTypeB, cute::UMMA::MXF8F6F4Format, void*>;
struct SharedStorage {
struct TensorStorage : cute::aligned_struct<128, _0> {
cute::array_aligned<SmemAllocTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
cute::array_aligned<SmemAllocTypeB, cute::cosize_v<LoadSmemLayoutB>> smem_B;
} tensors;
using PipelineStorageTMA = typename MainloopPipelineTMA::SharedStorage;
using PipelineStorageCpAsync = typename MainloopPipelineCpAsync::SharedStorage;
struct PipelineStorage : cute::aligned_struct<16, _0> {
alignas(16) PipelineStorageTMA tma;
alignas(16) PipelineStorageCpAsync cpasync;
} pipelines;
};
// Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them.
using TensorStorage = typename SharedStorage::TensorStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
static constexpr uint32_t TmaTransactionBytes =
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v<ElementA>);
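// Only A contributes to the TMA transaction bytes: B arrives via cp.async, whose completion is
// tracked by the cp.async pipeline rather than the TMA mbarrier. As an illustrative example
// (numbers are not taken from this file), a 128x64 A stage of 16-bit elements with a
// single-CTA MMA atom gives 128 * 64 * 2 = 16384 bytes per stage.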
template <class AccTensor>
struct TmemStorage {
AccTensor accumulators;
};
// Host side kernel arguments
struct Arguments {
ArrayElementA const* ptr_A{nullptr};
StrideA dA{};
ArrayElementB const* ptr_B{nullptr};
StrideB dB{};
RuntimeDataTypeA runtime_data_type_a{};
RuntimeDataTypeB runtime_data_type_b{};
};
// Device side kernel params
struct Params {
using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(ClusterShape{}),
make_tile(typename TiledMma::AtomThrID{})));
using TMA_A = decltype(make_tma_atom_A_sm100<TmaInternalElementA>(
GmemTiledCopyA{},
make_tensor(recast_ptr<TmaInternalElementA>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
SmemLayoutA{}(_,_,_,cute::Int<0>{}),
TileShape{},
TiledMma{},
ClusterLayout_VMNK{})
);
TMA_A tma_load_a;
ArrayElementB const* ptr_B{nullptr};
StrideB dB{};
RuntimeDataTypeA runtime_data_type_a;
RuntimeDataTypeB runtime_data_type_b;
};
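// Note the asymmetry: only A carries a TMA descriptor in Params, while B remains a raw pointer
// plus stride that the cp.async producer tiles at load time. A minimal host-side sketch of
// filling the arguments (d_A, d_B, stride_A, stride_B and problem_shape are placeholders, not
// names from this file):
//   Arguments args{d_A, stride_A, d_B, stride_B};
//   Params params = to_underlying_arguments(problem_shape, args, /*workspace=*/nullptr);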
CUTLASS_DEVICE
CollectiveMma(Params const& params)
: runtime_data_type_a_(params.runtime_data_type_a)
, runtime_data_type_b_(params.runtime_data_type_b) {
observed_tma_load_a_ = &params.tma_load_a;
}
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(
ProblemShape const& problem_shape,
Arguments const& args,
[[maybe_unused]] void* workspace,
cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) {
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M,N,K,L] = problem_shape_MNKL;
auto ptr_A = recast_ptr<TmaInternalElementA>(args.ptr_A);
auto ptr_B = recast_ptr<ElementBMma>(args.ptr_B);
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
auto cluster_layout_vmnk = tiled_divide(make_layout(ClusterShape{}), make_tile(typename TiledMma::AtomThrID{}));
typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100<TmaInternalElementA>(
GmemTiledCopyA{},
tensor_a,
SmemLayoutA{}(_,_,_,cute::Int<0>{}),
TileShape{},
TiledMma{},
cluster_layout_vmnk);
return {
tma_load_a,
args.ptr_B,
args.dB,
args.runtime_data_type_a,
args.runtime_data_type_b
};
}
template <class ProblemShape>
static bool
can_implement(
ProblemShape const& problem_shape,
[[maybe_unused]] Arguments const& args) {
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M,N,K,L] = problem_shape_MNKL;
static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4<TiledMma, ElementA, ElementB>();
constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits<ElementA, IsF8F6F4>();
constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits<ElementA>::value;
bool implementable = true;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
}
implementable = implementable && cutlass::detail::check_alignment<GmemTiledCopyB::NumValSrc>(cute::make_shape(N,K,L), StrideB{});
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for CpAsync.\n");
}
return implementable;
}
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE void
prefetch_tma_descriptors() {
cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor());
}
/// Construct a single stage's accumulator shape
CUTLASS_DEVICE static
auto
partition_accumulator_shape() {
auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N)
return acc_shape;
}
template <class TmemStorage>
CUTLASS_DEVICE static
auto
slice_accumulator(TmemStorage tmem_storage, int stage) {
return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage));
}
template <class EpilogueTile, bool IsOverlappingAccum = false>
CUTLASS_DEVICE static
auto
init_tmem_tensors(EpilogueTile epi_tile) {
TiledMma tiled_mma;
auto acc_shape = partition_accumulator_shape();
// ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue.
Tensor accumulators = cutlass::detail::make_sm100_accumulator<AccumulatorPipelineStageCount, IsOverlappingAccum>(
tiled_mma, acc_shape, EpilogueTile{});
TmemStorage<decltype(accumulators)> tmem_storage;
tmem_storage.accumulators = accumulators;
return tmem_storage;
}
template <class TmemStorage>
CUTLASS_DEVICE static
void
set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) {
tmem_storage.accumulators.data() = tmem_base_addr;
}
/// Set up the data needed by this collective for load.
/// Returned tuple elements contain
/// k_tiles - number of k tiles, consumed by the tile scheduler
/// tAgA_mkl - TMA-partitioned gmem tensor for input A
/// tAsA - TMA-partitioned smem tensor for input A
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto
load_init_tma(
ProblemShape_MNKL const& problem_shape_MNKL,
TensorStorage& shared_tensors) const {
using X = Underscore;
// Separate out problem shape for convenience
auto [M,N,K,L] = problem_shape_MNKL;
// TMA
Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L));
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l)
ThrMMA cta_mma = TiledMma{}.get_slice(0);
Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l)
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE)
// Define the CTA-in-cluster Layout and Coord
Layout cta_layout_mnk = make_layout(ClusterShape{});
Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{}));
auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(0);
// Project the cta_layout for tma_a along the n-modes
auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_,
get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)),
group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl));
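// tma_partition groups the (MMA,MMA_M,MMA_K) modes into a single TMA box mode and splits the
// load across the cluster's N extent, so the same A tile can be multicast to the CTAs that
// share it.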
return cute::make_tuple(
shape<3>(gA_mkl), // for scheduler
tAgA_mkl, tAsA // for input tensor values
);
}
template <class ProblemShape_MNKL, class TileScheduler>
CUTLASS_DEVICE auto
load_init_cpasync(
ProblemShape_MNKL const& problem_shape_MNKL,
Params const& params,
TensorStorage& shared_tensors,
TileScheduler const& scheduler,
typename TileScheduler::WorkTileInfo const& work_tile_info) const {
using X = Underscore;
// Separate out problem shape for convenience
auto [M,N,K,L] = problem_shape_MNKL;
// Represent the full tensors
Tensor mB_nkl = make_tensor(make_gmem_ptr(params.ptr_B), make_shape(N,K,L), params.dB); //(n,k,l)
// Partition for cpasync
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
// Build the coordinate tensors with the same shape as input matrices
Tensor cB_nk = make_identity_tensor(make_shape(N,K));
// Slice the coordinate tensors in the same way as A/B tensor partitioning
Tensor cgB_nk = local_tile(cB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k)
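// cp.async has no hardware out-of-bounds handling like TMA, so the identity (coordinate)
// tensor is carried alongside gB and sliced identically; its coordinates drive the N and K
// predicates used by copy_if in load_cpasync.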
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), LoadSmemLayoutB{});
GmemTiledCopyB gmem_to_smem_b_tiled_copy;
int thread_idx = threadIdx.x % NumLoadThreadsCpAsync;
auto thr_copy_b = gmem_to_smem_b_tiled_copy.get_slice(thread_idx);
return cute::make_tuple(
gB_nkl, cgB_nk, sB,
gmem_to_smem_b_tiled_copy, thr_copy_b);
}
/// Set up the data needed by this collective for mma compute.
template <class TmemStorage>
CUTLASS_DEVICE auto
mma_init(
Params const& params,
[[maybe_unused]] TmemStorage tmem_storage,
TensorStorage& shared_tensors) const {
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), MmaSmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE)
// Allocate "fragments/descriptors" for A and B matrices
Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sA)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sB));
TiledMma tiled_mma;
if constexpr (IsRuntimeDataType) {
// Update instruction descriptor according to runtime argument.
// Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe.
tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111;
tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111;
}
return cute::make_tuple(tiled_mma, tCrA, tCrB);
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
template <
class KTileCount,
class GTensorPartitionedA,
class STensorA,
class TileCoordMNKL,
class KTileIterator
>
CUTLASS_DEVICE auto
load_tma(
MainloopPipelineTMA mainloop_pipeline,
MainloopPipelineTMAState mainloop_pipe_producer_state,
cute::tuple<KTileCount,
GTensorPartitionedA,
STensorA> const& load_inputs,
TileCoordMNKL const& cta_coord_mnkl,
KTileIterator k_tile_iter, int k_tile_count) {
// Unpack from load_inputs
KTileCount k_tiles = get<0>(load_inputs);
GTensorPartitionedA tAgA_mkl = get<1>(load_inputs);
STensorA tAsA = get<2>(load_inputs);
// slice out the work coord from partitioned tensors
Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl));
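// The m coordinate is divided by the MMA atom's thread extent: for a 2-SM UMMA atom both
// participating CTAs map to the same MMA tile along M, so they share one m index here.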
auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state);
// Issue the Mainloop loads
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0) {
// LOCK mainloop_pipe_producer_state for _writing_
mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token);
using BarrierType = typename MainloopPipelineTMA::ProducerBarrierType;
BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state);
int write_stage = mainloop_pipe_producer_state.index();
++mainloop_pipe_producer_state;
barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state);
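// A single elected thread issues the TMA load for this stage; completion is signalled through
// the mbarrier obtained above.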
if (cute::elect_one_sync()) {
copy(observed_tma_load_a_->with(*tma_barrier), tAgA(_,*k_tile_iter), tAsA(_,write_stage));
}
--k_tile_count;
++k_tile_iter;
}
return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter);
}
template <
class TileCoordMNKL,
class KTileIterator,
class ProblemShape_MNKL,
class... TParams
>
CUTLASS_DEVICE auto
load_cpasync(
Params const& params,
MainloopPipelineCpAsync mainloop_pipeline,
MainloopPipelineCpAsyncState mainloop_pipe_producer_state,
cute::tuple<TParams...> const& load_inputs,
TileCoordMNKL const& cta_coord_mnkl,
KTileIterator k_tile_iter, int k_tile_count,
ProblemShape_MNKL effective_shape
) {
// Unpack from load_inputs
auto [
tBgB_nkl, cgB_nk, sB,
gmem_to_smem_b_tiled_copy, thr_copy_b] = load_inputs;
auto [M,N,K,L] = effective_shape;
// Slice out the work coord from partitioned tensors
Tensor gB_in = tBgB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl));
// Slice the coordinate tensor with the same coordinates used for the input tensor
Tensor cgB_nk_in = cgB_nk(_, _, get<1>(cta_coord_mnkl), _);
auto k_residue = K - size<1>(gB_in) * size<2>(gB_in); // K - BLK_K * k_tiles; non-positive, zero only when K is a multiple of BLK_K
Tensor gB = gB_in;
Tensor cB = cgB_nk_in;
auto tBgB = thr_copy_b.partition_S(gB);
auto tBsB = thr_copy_b.partition_D(sB);
// Allocate predicate tensors for n
Tensor tBpB = make_tensor<bool>(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{});
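// The stride-0 second mode broadcasts one predicate per n element across all k sub-blocks;
// k bounds are handled separately in the residue tail below.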
Tensor tBcB_nk = thr_copy_b.partition_S(cgB_nk_in);
Tensor tBcB = thr_copy_b.partition_S(cB);
// Copy gmem to smem for *k_tile_iter, predicating for k residue
Tensor tBgBk = tBgB(_,_,_,*k_tile_iter);
// Apply the same slicing to the coordinate tensor so the predicates line up with tBgB
Tensor tBcBk = tBcB(_,_,_,*k_tile_iter);
// Set predicates for n bounds
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < size<0>(tBpB); ++n) {
tBpB(n,0) = elem_less(get<0>(tBcBk(0,n,0)), N); // blk_n coord < N
}
// The last (residue) k-tile is processed after the main loop
if (k_residue != 0) {
--k_tile_count;
}
// Issue the Mainloop loads
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0) {
mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state);
int write_stage = mainloop_pipe_producer_state.index();
copy_if(gmem_to_smem_b_tiled_copy, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive);
--k_tile_count;
++k_tile_iter;
++mainloop_pipe_producer_state;
}
// Last tile with predication on k to account for the residue.
// For performance, this predicated K-tail block is only executed when there is a k-residue.
if (k_residue != 0) {
// LOCK mainloop_pipe_producer_state for _writing_
mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state);
int write_stage = mainloop_pipe_producer_state.index();
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < size<2>(tBsB); ++k) {
if (int(get<1>(tBcBk(0,0,k))) >= 0) { // blk_k coord < K
copy_if(gmem_to_smem_b_tiled_copy, tBpB(_,k), tBgB(_,_,k,*k_tile_iter), tBsB(_,_,k,write_stage));
}
else {
clear(tBsB(_,_,k,write_stage));
}
}
++k_tile_iter;
--k_tile_count;
// UNLOCK mainloop_pipe_producer_state
mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive);
// Advance mainloop_pipe_producer_state
++mainloop_pipe_producer_state;
}
return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter);
}
/// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster
CUTLASS_DEVICE void
load_tail_tma(MainloopPipelineTMA mainloop_pipeline, MainloopPipelineTMAState mainloop_pipe_producer_state) {
// Issue the epilogue waits
// This helps avoid early exit of ctas in Cluster
// Waits for all stages to either be released (all
// Consumer UNLOCKs), or if the stage was never used
// then would just be acquired since the phase was
// still inverted from make_producer_start_state
mainloop_pipeline.producer_tail(mainloop_pipe_producer_state);
}
CUTLASS_DEVICE void
load_tail_cpasync(MainloopPipelineCpAsync mainloop_pipeline, MainloopPipelineCpAsyncState mainloop_pipe_producer_state) {
mainloop_pipeline.producer_tail(mainloop_pipe_producer_state);
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <
class AccumulatorPipeline,
class FrgEngine, class FrgLayout,
class FragmentA, class FragmentB,
class CtaTileCoord
>
CUTLASS_DEVICE auto
mma(cute::tuple<MainloopPipelineTMA,
MainloopPipelineCpAsync,
AccumulatorPipeline> pipelines,
cute::tuple<MainloopPipelineTMAState,
MainloopPipelineCpAsyncState,
typename AccumulatorPipeline::PipelineState> pipeline_states,
cute::tuple<cute::Tensor<FrgEngine, FrgLayout>> const& accumulators_pair,
cute::tuple<TiledMma, FragmentA, FragmentB> const& mma_inputs,
CtaTileCoord cta_tile_coord,
int k_tile_count
) {
static_assert(is_tmem<FrgEngine>::value, "Accumulator must be tmem resident.");
static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)");
auto accumulators = get<0>(accumulators_pair);
auto [tiled_mma, tCrA, tCrB] = mma_inputs;
auto [mainloop_pipeline_tma, mainloop_pipeline_cpasync, accumulator_pipeline] = pipelines;
auto [mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, accumulator_pipe_producer_state] = pipeline_states;
//
// PIPELINED MAIN LOOP
//
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
// Wait for tmem accumulator buffer to become empty with a flipped phase
accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0) {
mainloop_pipeline_tma.consumer_wait(mainloop_pipe_tma_consumer_state);
mainloop_pipeline_cpasync.consumer_wait(mainloop_pipe_cpasync_consumer_state);
int read_stage_tma = mainloop_pipe_tma_consumer_state.index();
int read_stage_cpasync = mainloop_pipe_cpasync_consumer_state.index();
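// A and B arrive through independent pipelines (TMA for A, cp.async for B), so each keeps its
// own consumer state and the two read stage indices are tracked separately.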
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage_tma), tCrB(_,_,k_block,read_stage_cpasync), accumulators);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
mainloop_pipeline_tma.consumer_release(mainloop_pipe_tma_consumer_state);
mainloop_pipeline_cpasync.consumer_release(mainloop_pipe_cpasync_consumer_state);
--k_tile_count;
++mainloop_pipe_tma_consumer_state;
++mainloop_pipe_cpasync_consumer_state;
}
return cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state);
}
protected:
typename Params::TMA_A const* observed_tma_load_a_{nullptr};
RuntimeDataTypeA runtime_data_type_a_{};
RuntimeDataTypeB runtime_data_type_b_{};
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -248,7 +248,7 @@ public:
using Load2MmaPipeline = cutlass::PipelineTmaUmmaAsync<
DispatchPolicy::Load2TransformPipelineStageCount,
ClusterShape,
ClusterShape,
AtomThrShapeMNK>;
using Load2MmaPipelineState = typename Load2MmaPipeline::PipelineState;
@ -316,7 +316,7 @@ public:
using SmemLayoutACompute = decltype(UMMA::tile_to_mma_shape(
SmemLayoutAtomACompute{},
append(CtaShapeA_MK{}, Int<DispatchPolicy::Load2TransformPipelineStageCount>{}),
append(CtaShapeA_MK{}, Int<DispatchPolicy::Transform2MmaPipelineStageCount>{}),
(cute::conditional_t<cutlass::gemm::detail::is_mn_major<StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})));
using SmemLayoutB = decltype(UMMA::tile_to_mma_shape(
@ -385,7 +385,7 @@ public:
struct TensorStorageUntransformed {
alignas(512) cute::ArrayEngine<ElementA, cute::cosize_v<SmemLayoutA>> smem_A;
cute::ArrayEngine<ElementB, cute::cosize_v<SmemLayoutB>> smem_B;
alignas(1024) cute::ArrayEngine<ElementB, cute::cosize_v<SmemLayoutB>> smem_B;
cute::ArrayEngine<NonVoidElementScale, scale_elements> smem_scale;
cute::ArrayEngine<NonVoidElementZero, zero_elements> smem_zero;
};

View File

@ -73,7 +73,7 @@ template <
class SmemCopyAtomB_,
class TransformB_>
struct CollectiveMma<
MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling<Stages, ClusterShape, KernelSchedule>,
MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise<Stages, ClusterShape, KernelSchedule>,
TileShape_,
ElementA_,
StridePairA_,
@ -92,7 +92,7 @@ struct CollectiveMma<
//
// Type Aliases
//
using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling<Stages, ClusterShape, KernelSchedule>;
using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise<Stages, ClusterShape, KernelSchedule>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = cute::tuple_element_t<0,StridePairA_>;
@ -382,8 +382,6 @@ struct CollectiveMma<
auto [M,N,K,L] = problem_shape_MNKL;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), InternalStrideA{});
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), InternalStrideB{});
// We expect full tiles in K
implementable = implementable && K % size<2>(TileShape{}) == 0;
}
}
@ -824,16 +822,13 @@ struct CollectiveMma<
// Prologue GMMAs
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
// fence_operand();
GmmaFP8Accumulation accumulation(accum, ScalePromotionInterval, size<2>(tCrA));
warpgroup_fence_operand(accumulation());
{
if (k_tile_count > 0) {
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
pipeline.consumer_wait(smem_pipe_read, barrier_token);
int read_stage = smem_pipe_read.index();
// Load per block scale values from shared memory to registers
@ -977,7 +972,7 @@ struct CollectiveMma<
++smem_pipe_release;
}
if (k_tile_count) {
if (k_tile_count > 0) {
pipeline.consumer_wait(smem_pipe_read, barrier_token);
//
@ -1072,9 +1067,11 @@ struct CollectiveMma<
/// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void
mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
// The pipeline is not released in the first iteration
smem_pipe_release.advance(k_tile_count - 1);
pipeline.consumer_release(smem_pipe_release);
if (k_tile_count > 0) {
// The pipeline is not released in the first iteration
smem_pipe_release.advance(k_tile_count - 1);
pipeline.consumer_release(smem_pipe_release);
}
}
//

View File

@ -73,7 +73,7 @@ template <
class SmemCopyAtomB_,
class TransformB_>
struct CollectiveMma<
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule>,
MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8<Stages, ClusterShape, KernelSchedule>,
TileShape_,
ElementA_,
StridePairA_,
@ -91,7 +91,7 @@ struct CollectiveMma<
//
// Type Aliases
//
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule>;
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8<Stages, ClusterShape, KernelSchedule>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = cute::tuple_element_t<0,StridePairA_>;
@ -391,12 +391,6 @@ struct CollectiveMma<
implementable = false;
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum alignment requirements for using TMA to load scale B.\n");
}
// We expect full tiles in K
if (K % size<2>(TileShape{}) != 0) {
implementable = false;
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size K is incompatible with tile size.\n");
}
return implementable;
}

View File

@ -127,10 +127,15 @@ struct KernelPtrArrayTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedPingpong { };
// FP8 related policies (including Blocked Scaled Accumulation)
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelTmaWarpSpecializedCooperative { };
struct KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum: KernelTmaWarpSpecializedPingpong { };
struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelPtrArrayTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum: KernelPtrArrayTmaWarpSpecializedPingpong { };
struct KernelTmaWarpSpecializedCooperativeFP8Blockwise: KernelTmaWarpSpecializedCooperative { };
struct KernelTmaWarpSpecializedPingpongFP8Blockwise: KernelTmaWarpSpecializedPingpong { };
struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise: KernelPtrArrayTmaWarpSpecializedCooperative { };
struct KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise: KernelPtrArrayTmaWarpSpecializedPingpong { };
using KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum = KernelTmaWarpSpecializedCooperativeFP8Blockwise;
using KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum = KernelTmaWarpSpecializedPingpongFP8Blockwise;
using KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum = KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise;
using KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum = KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
// Policies to opt into mixed type GEMMs
struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { };
@ -319,17 +324,17 @@ struct MainloopSm90TmaGmmaWarpSpecializedFP8
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule
// For FP8 kernels with Block Scaling
// For FP8 kernels with Blockwise (Software) Scaling
template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>,
class KernelSchedule = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum
class KernelSchedule = KernelTmaWarpSpecializedCooperativeFP8Blockwise
>
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8
struct MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
static_assert(
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum> ||
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum>,
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8Blockwise> ||
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedPingpongFP8Blockwise>,
"KernelSchedule must be one of the warp specialized FP8 block scale policies");
};
@ -411,15 +416,15 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput {
template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>,
class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum
class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise
>
struct MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling
struct MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise
: MainloopSm90ArrayTmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
static_assert(
cute::is_any_of_v<
KernelSchedule,
KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum,
KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum
KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise,
KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise
>,
"KernelSchedule must be one of the warp specialized FP8 block scale policies");
};
@ -440,6 +445,15 @@ struct KernelWarpSpecializedSm100 final {
static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_;
};
template<
int SchedulerPipelineStageCount_,
int AccumulatorPipelineStageCount_
>
struct KernelMixedTmaCpAsyncWarpSpecializedSm100 final {
static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_;
};
template<
int SchedulerPipelineStageCount_,
int AccumulatorPipelineStageCount_
@ -653,7 +667,7 @@ template<
class KernelSchedule
>
struct HasAuxiliaryLoad<
MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling<
MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise<
Stages,
ClusterShape,
KernelSchedule
@ -666,7 +680,7 @@ template<
class KernelSchedule
>
struct HasAuxiliaryLoad<
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<
MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8<
Stages,
ClusterShape,
KernelSchedule
@ -700,6 +714,7 @@ struct KernelScheduleSm100DenseGemm : KernelScheduleSm100 {}; // Base policy
struct KernelTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; // Use for 1SM Dense GEMM Kernels for Collective Mainloop Builder
struct KernelTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100DenseGemm {}; // Use for 2SM Dense GEMM Kernels for Collective Mainloop Builder
struct KernelWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; // Use for 1SM Dense GEMM Kernels for Collective Mainloop Builder Without TMA
struct KernelMixedTmaCpAsyncWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {};
///////////////////////////////////////////////////////////////////////////////////////////////////////
// SM100 Ptr-Array Dense GEMM Dispatch Policies
@ -795,6 +810,8 @@ struct KernelTmaWarpSpecialized1SmMxf4Sm100 final : KernelSchedule1
struct KernelTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, KernelScheduleMxNvf4Sm100 { };
struct KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelScheduleMxf8f6f4Sm100 { };
struct KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelScheduleMxf8f6f4Sm100 { };
struct KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100 final : KernelSchedule1Sm, KernelScheduleBlockScaledGemmSm100 {};
///////////////////////////////////////////////////////////////////////////////////////////////////////
// SM100 BlockScaled Ptr Array Dense GEMM Dispatch Policies
///////////////////////////////////////////////////////////////////////////////////////////////////////
@ -950,6 +967,34 @@ struct MainloopSm100UmmaCpAsyncWarpSpecialized {
using Schedule = KernelWarpSpecializedSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
};
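// n-buffer in smem, A loaded via TMA and B via cp.async, pipelined with Blackwell UMMA, Warp specialized dynamic schedule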
template<
int Stages_,
int SchedulerPipelineStageCount_,
int AccumulatorPipelineStageCount_,
class ClusterShape_ = Shape<_1,_1,_1>
>
struct MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized {
constexpr static int Stages = Stages_;
using ClusterShape = ClusterShape_;
using ArchTag = arch::Sm100;
using Schedule = KernelMixedTmaCpAsyncWarpSpecializedSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
constexpr static bool IsOverlappingAccum = false;
};
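// Block scaled variant of the mixed TMA + cp.async mainloop above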
template<
int Stages_,
int SchedulerPipelineStageCount_,
int AccumulatorPipelineStageCount_,
class ClusterShape_ = Shape<_1,_1,_1>
>
struct MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled {
constexpr static int Stages = Stages_;
using ClusterShape = ClusterShape_;
using ArchTag = arch::Sm100;
using Schedule = KernelMixedTmaCpAsyncWarpSpecializedSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
constexpr static bool IsOverlappingAccum = false;
};
// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule
template<
int Stages_,

View File

@ -79,6 +79,16 @@ struct GroupProblemShape {
}
};
template <class ProblemShape_, class MaxProblemShape_>
struct MoEProblemShape {
using UnderlyingProblemShape = ProblemShape_;
using MaxProblemShape = MaxProblemShape_;
UnderlyingProblemShape problem_shape;
MaxProblemShape max_problem_shape;
};
template <class ProblemShape_>
class ArrayProblemShape {
public:
@ -120,4 +130,14 @@ private:
UnderlyingProblemShape problem_shape_{};
};
namespace detail {
template<class T>
struct is_moe_problem_shape : cute::false_type {};
template<class T, class U>
struct is_moe_problem_shape<cutlass::gemm::MoEProblemShape<T,U>> : cute::true_type {};
}
} // namespace cutlass::gemm

View File

@ -73,6 +73,7 @@ struct IsCutlass3ArrayKernel<ProblemShape, cute::void_t<typename ProblemShape::U
#include "cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp"
#include "cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp"
#include "cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp"
#include "cutlass/gemm/kernel/sm100_gemm_mixed_tma_cpasync_warpspecialized.hpp"
#include "cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp"
#include "cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp"
#include "cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp"

File diff suppressed because it is too large

View File

@ -240,6 +240,27 @@ public:
void
fixup(WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t, uint32_t = 1) const { }
template <
bool IsComplex,
class TiledMma,
class AccEngine,
class AccLayout,
class AccumulatorPipeline,
class AccumulatorPipelineState,
class CopyOpT2R
>
CUTLASS_DEVICE
AccumulatorPipelineState
fixup(
TiledMma const& ,
WorkTileInfo const&,
cute::Tensor<AccEngine, AccLayout>&,
AccumulatorPipeline,
AccumulatorPipelineState acc_pipe_consumer_state,
CopyOpT2R) const {
return acc_pipe_consumer_state;
}
template <class ProblemShape, class ElementAccumulator>
static size_t
get_workspace_size(Arguments const& args, ProblemShape problem_shape, KernelHardwareInfo const& hw_info, uint32_t, uint32_t = 1, uint32_t = 1) {

View File

@ -991,7 +991,7 @@ public:
mainloop_sf_pipeline,
mainloop_sf_pipe_producer_state,
load_inputs,
cta_coord_mnkl,
cta_coord_mnk,
k_tile_iter_next, k_tile_count - k_tile_prologue,
false, /* did_batch_change - prologue loads handle tensormap acquire */
enable_prefetch ? k_tile_count - k_tile_prologue : 0

View File

@ -831,8 +831,6 @@ public:
collective_epilogue.template tensormaps_fence_acquire<IsEpiLoad>(epi_load_tensormap);
}
bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx;
epi_load_pipe_producer_state = collective_epilogue.load(
epi_load_pipeline,
epi_load_pipe_producer_state,
@ -843,8 +841,7 @@ public:
lane_idx,
shared_storage.tensors.epilogue,
epi_load_tensormap,
work_tile_info.reduction_subtile_idx(),
wait
work_tile_info.reduction_subtile_idx()
);
}

View File

@ -869,8 +869,6 @@ public:
collective_epilogue.template tensormaps_fence_acquire<IsEpiLoad>(epi_load_tensormap);
}
bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx;
epi_load_pipe_producer_state = collective_epilogue.load(
epi_load_pipeline,
epi_load_pipe_producer_state,
@ -881,8 +879,7 @@ public:
lane_idx,
shared_storage.tensors.epilogue,
epi_load_tensormap,
work_tile_info.reduction_subtile_idx(),
wait
work_tile_info.reduction_subtile_idx()
);
}