v4.2 tag release. (#2638)
This commit is contained in:
@ -539,7 +539,8 @@ make_cotiled_copy(Copy_Atom<Args...> const& copy_atom,
|
||||
auto layout_tv_data = composition(inv_data_layout, atom_tv_layout);
|
||||
|
||||
// Check validity
|
||||
CUTE_STATIC_ASSERT_V(coalesce(composition(data_layout, layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)),
|
||||
// Append 1:0 to data_layout so that OOB coordinates get the stride-0
|
||||
CUTE_STATIC_ASSERT_V(coalesce(composition(make_layout(data_layout, Layout<_1,_0>{}), layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)),
|
||||
"The memory pointed to by AtomTVLayout does not exist in the DataLayout.");
|
||||
//
|
||||
// Tiler -- Find the active elements in the DATA tensor and generate a tiler to extract them
|
||||
|
||||
@ -705,7 +705,7 @@ public:
|
||||
auto smem_tiled_copy_S = cute::get<0>(partitioned_transform_extra_info);
|
||||
auto&& scales = cute::get<1>(partitioned_transform_extra_info);
|
||||
using ScaleType = decltype(scales);
|
||||
auto tSrS = make_tensor(static_cast<ScaleType&&>(scales).data(), scales.layout());
|
||||
auto tSrS = make_tensor(scales.data(), scales.layout());
|
||||
auto tSsS = cute::get<2>(partitioned_transform_extra_info);
|
||||
copy(smem_tiled_copy_S, tSsS(_,_,_,_,load2transform_consumer_index), tSrS);
|
||||
|
||||
@ -714,7 +714,7 @@ public:
|
||||
} else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
|
||||
auto&& zeros = cute::get<3>(partitioned_transform_extra_info);
|
||||
using ZeroType = decltype(zeros);
|
||||
auto tZrZ = make_tensor(static_cast<ZeroType&&>(zeros).data(), zeros.layout());
|
||||
auto tZrZ = make_tensor(zeros.data(), zeros.layout());
|
||||
auto tZsZ = cute::get<4>(partitioned_transform_extra_info);
|
||||
copy(smem_tiled_copy_S, tZsZ(_,_,_,_,load2transform_consumer_index), tZrZ);
|
||||
|
||||
@ -1002,7 +1002,6 @@ public:
|
||||
auto src_arr = recast<SrcArray>(src);
|
||||
auto dst_arr = recast<DstArray>(dst);
|
||||
|
||||
Tensor src_vm = cute::group_modes<1,-1>(cute::zipped_divide(src, pack));
|
||||
Tensor dst_vm = cute::group_modes<1,-1>(cute::zipped_divide(dst, pack));
|
||||
|
||||
cute::transform(src_arr, dst_arr, Converter::convert);
|
||||
@ -1019,7 +1018,6 @@ public:
|
||||
auto scale_arr = recast<ScaleArray>(filter_zeros(scales));
|
||||
|
||||
if constexpr (is_same_v<DstType, cutlass::bfloat16_t>){
|
||||
Tensor dst_vm = cute::group_modes<1,-1>(cute::zipped_divide(dst, pack));
|
||||
Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, pack));
|
||||
|
||||
for (int i = 0; i < size<1>(dst_vm); ++i){
|
||||
@ -1194,13 +1192,7 @@ public:
|
||||
Tensor tCsS = cta_mma.partition_A(sS);
|
||||
Tensor tSsS = smem_thr_copy_S.partition_S(tCsS);
|
||||
Tensor tSrS = make_tensor<ElementScale>(tSsS(_,_,_,_,0).shape());
|
||||
#if 0
|
||||
if(cute::thread(128, 0)){
|
||||
print("sS: ");print(sS);print("\n");
|
||||
print("tSsS: ");print(tSsS);print("\n");
|
||||
print("tSrS: ");print(tSrS);print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
|
||||
return cute::make_tuple(smem_tiled_copy_S, tSrS, tSsS);
|
||||
}
|
||||
@ -1209,16 +1201,6 @@ public:
|
||||
Tensor tCsZ = cta_mma.partition_A(sZ);
|
||||
Tensor tZsZ = smem_thr_copy_S.partition_S(tCsZ);
|
||||
Tensor tZrZ = make_tensor<ElementZero>(tZsZ(_,_,_,_,0).shape());
|
||||
#if 0
|
||||
if(cute::thread(128, 0)){
|
||||
print("sS: ");print(sS);print("\n");
|
||||
print("tSsS: ");print(tSsS);print("\n");
|
||||
print("tSrS: ");print(tSrS);print("\n");
|
||||
print("sZ: ");print(sZ);print("\n");
|
||||
print("tZsZ: ");print(tZsZ);print("\n");
|
||||
print("tZrZ: ");print(tZrZ);print("\n");
|
||||
}
|
||||
#endif
|
||||
return cute::make_tuple(smem_tiled_copy_S, tSrS, tSsS, tZrZ, tZsZ);
|
||||
}
|
||||
else {
|
||||
|
||||
@ -582,7 +582,13 @@ sm100_sparse_get_tma_dispatch_policy() {
|
||||
* Selected op also maximizes the TMEM_LOAD shape in order to minimize TMEM_LOADs issued,
|
||||
* subject to the constraint of the provided per-warp tmem subpartition shape
|
||||
**/
|
||||
template<class GmemStrideTypeD, class ElementAccumulator, class ElementD, class TmemShape_MN, class FusionOp>
|
||||
template<
|
||||
class GmemStrideTypeD,
|
||||
class ElementAccumulator,
|
||||
class ElementD,
|
||||
class TmemShape_MN,
|
||||
bool IsBlockScaleSupported
|
||||
>
|
||||
constexpr auto
|
||||
sm100_get_tmem_load_op() {
|
||||
using namespace cute;
|
||||
@ -958,6 +964,172 @@ struct CallbacksBuilder<
|
||||
>;
|
||||
};
|
||||
|
||||
// Attempts to compute a reasonably performant epilogue tile or allows the user to provide one.
|
||||
template<
|
||||
class OpClass,
|
||||
class CtaTileShape_MNK,
|
||||
class EpilogueTileType,
|
||||
class TmemWarpShape_MN,
|
||||
class ElementC_,
|
||||
class GmemStrideTypeC,
|
||||
class ElementD,
|
||||
class GmemStrideTypeD,
|
||||
bool IsPerColScaleSupported
|
||||
>
|
||||
static constexpr auto
|
||||
sm100_dense_compute_tile_shape_or_override() {
|
||||
using namespace cute;
|
||||
static_assert(!cute::is_same_v<OpClass, arch::OpClassSparseTensorOp> && !cute::is_same_v<OpClass, arch::OpClassBlockScaledSparseTensorOp>);
|
||||
|
||||
constexpr bool DisableSource = cute::is_void_v<ElementC_>;
|
||||
using ElementC = cute::conditional_t<DisableSource, ElementD, ElementC_>;
|
||||
|
||||
if constexpr (is_same_v<OpClass, arch::OpClassBlockScaledTensorOp> &&
|
||||
is_same_v<EpilogueTileType, EpilogueTileAuto> &&
|
||||
size<1>(CtaTileShape_MNK{}) == 256) {
|
||||
constexpr int CtaM = size<0>(CtaTileShape_MNK{});
|
||||
constexpr int WarpM = size<0>(TmemWarpShape_MN{});
|
||||
constexpr int DpFull = 32;
|
||||
constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load
|
||||
// Note:
|
||||
// Set Epi_Tile_N to 128 support OverlappingAccum for the largest tile.
|
||||
// This is a general workable epi_tile_N which does not promise best perf.
|
||||
return make_tile(Int<M>{}, Int<128>{});
|
||||
}
|
||||
else if constexpr (is_same_v<EpilogueTileType, EpilogueTileAuto>) {
|
||||
constexpr int CtaM = size<0>(CtaTileShape_MNK{});
|
||||
constexpr int CtaN = size<1>(CtaTileShape_MNK{});
|
||||
constexpr int WarpM = size<0>(TmemWarpShape_MN{});
|
||||
constexpr int WarpN = size<1>(TmemWarpShape_MN{});
|
||||
constexpr int MaxBits = cute::max(sizeof_bits_v<ElementC>, sizeof_bits_v<ElementD>);
|
||||
|
||||
constexpr int DpFull = 32; // tmem datapaths in 1 subpartition
|
||||
constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load
|
||||
constexpr int N_perf = [&]() constexpr { // Known subtile sizes tested for perf
|
||||
// Epilogues w/o residual load are less sensitive to smem allocation
|
||||
// Target a fixed amount of compute per epilogue iteration
|
||||
if (DisableSource) {
|
||||
if (MaxBits == 4) {
|
||||
// Make epilogue tile larger to reduce the epilogue iterations.
|
||||
// 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same.
|
||||
constexpr int ComputeElts = 8192;
|
||||
return ComputeElts / M;
|
||||
}
|
||||
constexpr int ComputeElts = 4096;
|
||||
return ComputeElts / M;
|
||||
}
|
||||
// Epilogues w/ residual load are more sensitive to smem allocation
|
||||
// Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize
|
||||
else {
|
||||
if (MaxBits == 32) {
|
||||
return (CtaM > 64 && CtaN <= 128) ? 16 : 32;
|
||||
}
|
||||
// Per-column scaling is high register pressure, reduce tile to prevent spills
|
||||
else if (IsPerColScaleSupported) {
|
||||
return 32;
|
||||
}
|
||||
else if (MaxBits == 16) {
|
||||
return (CtaN <= 128) ? 32 : 64;
|
||||
}
|
||||
else {
|
||||
return 64;
|
||||
}
|
||||
}
|
||||
}();
|
||||
constexpr int N_min_C = (DisableSource || detail::is_m_major<GmemStrideTypeC>()) ? 8 * WarpN
|
||||
: (sizeof_bits_v<ElementC> == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type
|
||||
: 128 / sizeof_bits_v<ElementC> * WarpN;
|
||||
constexpr int N_min_D = (detail::is_m_major<GmemStrideTypeD>()) ? 8 * WarpN
|
||||
: (sizeof_bits_v<ElementD> == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type
|
||||
: 128 / sizeof_bits_v<ElementD> * WarpN;
|
||||
constexpr int N = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D));
|
||||
static_assert(CtaN >= N_min_C && CtaN >= N_min_D, "CTA tile too small");
|
||||
|
||||
// stride by tmem warp layout and return a by-mode tiler
|
||||
auto tile_m = Layout<Int<M>>{};
|
||||
auto tile_n = Layout<Shape <Int<N / WarpN>,Int< WarpN>>,
|
||||
Stride<Int< 1>,Int<CtaN / WarpN>>>{};
|
||||
|
||||
return make_tile(tile_m, coalesce(tile_n));
|
||||
}
|
||||
else {
|
||||
static_assert(cute::is_tuple<EpilogueTileType>::value && not is_layout<EpilogueTileType>::value,
|
||||
"EpilogueTile must be a cute::Tile or cute::Shape");
|
||||
|
||||
EpilogueTileType epi_tile;
|
||||
constexpr int M = size<0>(shape(epi_tile));
|
||||
constexpr int N = size<1>(shape(epi_tile));
|
||||
static_assert(N % 8 == 0, "Unsupported tile shape");
|
||||
|
||||
return epi_tile;
|
||||
}
|
||||
}
|
||||
|
||||
template<
|
||||
bool Is2SmMma,
|
||||
class MmaTileShape_MNK
|
||||
>
|
||||
static constexpr auto
|
||||
sm100_tmem_warps() {
|
||||
if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) {
|
||||
return Shape<_2,_2>{};
|
||||
}
|
||||
else {
|
||||
return Shape<_4,_1>{};
|
||||
}
|
||||
}
|
||||
|
||||
template<
|
||||
bool Is2SmMma,
|
||||
class MmaTileShape_MNK
|
||||
>
|
||||
static constexpr auto
|
||||
sm100_cta_tile_shape() {
|
||||
if constexpr (Is2SmMma) { // 2x1 threadblock shape
|
||||
auto [mma_tile_m, mma_tile_n, mma_tile_k] = MmaTileShape_MNK{};
|
||||
auto cta_tile_m = reverse(shape_div(reverse(mma_tile_m), _2{})); // first MmaTile_M/2 elements, preserve multimode
|
||||
return make_shape(cta_tile_m, mma_tile_n, mma_tile_k);
|
||||
}
|
||||
else { // 1x1 threadblock shape
|
||||
return MmaTileShape_MNK{};
|
||||
}
|
||||
}
|
||||
|
||||
template<
|
||||
class EpilogueScheduleType,
|
||||
class ElementC_,
|
||||
class ElementD,
|
||||
int EpiTiles,
|
||||
int FragmentSize
|
||||
>
|
||||
static constexpr auto
|
||||
sm100_dense_dispatch_policy() {
|
||||
// 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation
|
||||
constexpr bool ReuseSmem = sizeof_bits_v<ElementC_> > 8;
|
||||
// TMA store delay performs worse with residual loads
|
||||
constexpr bool DelayTmaStore = is_void_v<ElementC_>;
|
||||
|
||||
constexpr int StagesD = cute::min(EpiTiles, 2);
|
||||
constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1)
|
||||
: cute::min(EpiTiles, 4);
|
||||
|
||||
if constexpr (is_base_of_v<PtrArrayNoSmemWarpSpecialized1Sm, EpilogueScheduleType> ||
|
||||
is_base_of_v<PtrArrayNoSmemWarpSpecialized2Sm, EpilogueScheduleType>) {
|
||||
return Sm100PtrArrayNoSmemWarpSpecialized{};
|
||||
}
|
||||
else if constexpr (is_base_of_v<NoSmemWarpSpecialized1Sm, EpilogueScheduleType> || is_base_of_v<NoSmemWarpSpecialized2Sm, EpilogueScheduleType>) {
|
||||
return Sm100NoSmemWarpSpecialized{};
|
||||
}
|
||||
else if constexpr (is_same_v<EpilogueScheduleType, PtrArrayTmaWarpSpecialized1Sm> ||
|
||||
is_same_v<EpilogueScheduleType, PtrArrayTmaWarpSpecialized2Sm>) {
|
||||
constexpr bool DelayTmaStore_ = false; // TMA store delay complicates tensormap updates for Ptr-Array GEMMs
|
||||
return Sm100PtrArrayTmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore_>{};
|
||||
}
|
||||
else {
|
||||
return Sm100TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore>{};
|
||||
}
|
||||
}
|
||||
|
||||
// Helper for building TMA warp-specialized collective epilogues, specialized by
|
||||
// the fusion operation performed and the dispatch policy to use.
|
||||
template <
|
||||
@ -1017,17 +1189,7 @@ private:
|
||||
}
|
||||
}
|
||||
using CtaTileShape_MNK = decltype(cta_tile_shape());
|
||||
|
||||
static constexpr auto
|
||||
tmem_warps() {
|
||||
if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) {
|
||||
return Shape<_2,_2>{};
|
||||
}
|
||||
else {
|
||||
return Shape<_4,_1>{};
|
||||
}
|
||||
}
|
||||
using TmemWarpShape_MN = decltype(tmem_warps());
|
||||
using TmemWarpShape_MN = decltype(detail::sm100_tmem_warps<Is2SmMma, MmaTileShape_MNK>());
|
||||
|
||||
// Attempts to compute a reasonably performant epilogue tile or allows the user to provide one.
|
||||
static constexpr auto
|
||||
@ -1041,84 +1203,10 @@ private:
|
||||
ElementC_, GmemStrideTypeC, ElementD, GmemStrideTypeD, Schedule,
|
||||
FusionOp>();
|
||||
}
|
||||
else if constexpr (is_same_v<OpClass, arch::OpClassBlockScaledTensorOp> &&
|
||||
is_same_v<EpilogueTileType, EpilogueTileAuto> &&
|
||||
size<1>(CtaTileShape_MNK{}) == 256) {
|
||||
constexpr int CtaM = size<0>(CtaTileShape_MNK{});
|
||||
constexpr int WarpM = size<0>(TmemWarpShape_MN{});
|
||||
constexpr int DpFull = 32;
|
||||
constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load
|
||||
// Note:
|
||||
// Set Epi_Tile_N to 128 support OverlappingAccum for the largest tile.
|
||||
// This is a general workable epi_tile_N which does not promise best perf.
|
||||
return make_tile(Int<M>{}, Int<128>{});
|
||||
}
|
||||
else if constexpr (is_same_v<EpilogueTileType, EpilogueTileAuto>) {
|
||||
constexpr int CtaM = size<0>(CtaTileShape_MNK{});
|
||||
constexpr int CtaN = size<1>(CtaTileShape_MNK{});
|
||||
constexpr int WarpM = size<0>(TmemWarpShape_MN{});
|
||||
constexpr int WarpN = size<1>(TmemWarpShape_MN{});
|
||||
constexpr int MaxBits = cute::max(sizeof_bits_v<ElementC>, sizeof_bits_v<ElementD>);
|
||||
|
||||
constexpr int DpFull = 32; // tmem datapaths in 1 subpartition
|
||||
constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load
|
||||
constexpr int N_perf = [&]() constexpr { // Known subtile sizes tested for perf
|
||||
// Epilogues w/o residual load are less sensitive to smem allocation
|
||||
// Target a fixed amount of compute per epilogue iteration
|
||||
if (DisableSource) {
|
||||
if (MaxBits == 4) {
|
||||
// Make epilogue tile larger to reduce the epilogue iterations.
|
||||
// 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same.
|
||||
constexpr int ComputeElts = 8192;
|
||||
return ComputeElts / M;
|
||||
}
|
||||
constexpr int ComputeElts = 4096;
|
||||
return ComputeElts / M;
|
||||
}
|
||||
// Epilogues w/ residual load are more sensitive to smem allocation
|
||||
// Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize
|
||||
else {
|
||||
if (MaxBits == 32) {
|
||||
return (CtaM > 64 && CtaN <= 128) ? 16 : 32;
|
||||
}
|
||||
// Per-column scaling is high register pressure, reduce tile to prevent spills
|
||||
else if (FusionOp::IsPerColScaleSupported) {
|
||||
return 32;
|
||||
}
|
||||
else if (MaxBits == 16) {
|
||||
return (CtaN <= 128) ? 32 : 64;
|
||||
}
|
||||
else {
|
||||
return 64;
|
||||
}
|
||||
}
|
||||
}();
|
||||
constexpr int N_min_C = (DisableSource || detail::is_m_major<GmemStrideTypeC>()) ? 8 * WarpN
|
||||
: (sizeof_bits_v<ElementC> == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type
|
||||
: 128 / sizeof_bits_v<ElementC> * WarpN;
|
||||
constexpr int N_min_D = (detail::is_m_major<GmemStrideTypeD>()) ? 8 * WarpN
|
||||
: (sizeof_bits_v<ElementD> == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type
|
||||
: 128 / sizeof_bits_v<ElementD> * WarpN;
|
||||
constexpr int N = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D));
|
||||
static_assert(CtaN >= N_min_C && CtaN >= N_min_D, "CTA tile too small");
|
||||
|
||||
// stride by tmem warp layout and return a by-mode tiler
|
||||
auto tile_m = Layout<Int<M>>{};
|
||||
auto tile_n = Layout<Shape <Int<N / WarpN>,Int< WarpN>>,
|
||||
Stride<Int< 1>,Int<CtaN / WarpN>>>{};
|
||||
|
||||
return make_tile(tile_m, coalesce(tile_n));
|
||||
}
|
||||
else {
|
||||
static_assert(cute::is_tuple<EpilogueTileType>::value && not is_layout<EpilogueTileType>::value,
|
||||
"EpilogueTile must be a cute::Tile or cute::Shape");
|
||||
|
||||
EpilogueTileType epi_tile;
|
||||
constexpr int M = size<0>(shape(epi_tile));
|
||||
constexpr int N = size<1>(shape(epi_tile));
|
||||
static_assert(N % 8 == 0, "Unsupported tile shape");
|
||||
|
||||
return epi_tile;
|
||||
return sm100_dense_compute_tile_shape_or_override<
|
||||
OpClass, CtaTileShape_MNK, EpilogueTileType, TmemWarpShape_MN,
|
||||
ElementC_, GmemStrideTypeC, ElementD, GmemStrideTypeD, FusionOp::IsPerColScaleSupported>();
|
||||
}
|
||||
}
|
||||
using EpilogueTile_MN = decltype(epilogue_tile());
|
||||
@ -1129,30 +1217,18 @@ private:
|
||||
|
||||
using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTileShape_MN{}, TmemWarpShape_MN{}));
|
||||
using AccLoadOp = decltype(detail::sm100_get_tmem_load_op<
|
||||
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>());
|
||||
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN,
|
||||
FusionOp::IsBlockScaleSupported
|
||||
>());
|
||||
|
||||
static constexpr auto
|
||||
dispatch_policy() {
|
||||
// 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation
|
||||
constexpr bool ReuseSmem = sizeof_bits_v<ElementC_> > 8;
|
||||
// TMA store delay performs worse with residual loads
|
||||
constexpr bool DelayTmaStore = is_void_v<ElementC_>;
|
||||
|
||||
constexpr int StagesD = cute::min(EpiTiles, 2);
|
||||
constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1)
|
||||
: cute::min(EpiTiles, 4);
|
||||
|
||||
if constexpr (is_same_v<OpClass, arch::OpClassSparseTensorOp> ||
|
||||
is_same_v<OpClass, arch::OpClassBlockScaledSparseTensorOp>) {
|
||||
return detail::sparse::sm100_sparse_get_tma_dispatch_policy<CtaTileShape_MNK, EpilogueTile_MN, ElementC_, ElementD, Schedule>();
|
||||
}
|
||||
else if constexpr (is_same_v<Schedule, PtrArrayTmaWarpSpecialized1Sm> ||
|
||||
is_same_v<Schedule, PtrArrayTmaWarpSpecialized2Sm>) {
|
||||
constexpr bool DelayTmaStore_ = false; // TMA store delay complicates tensormap updates for Ptr-Array GEMMs
|
||||
return Sm100PtrArrayTmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore_>{};
|
||||
}
|
||||
else {
|
||||
return Sm100TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmem, DelayTmaStore>{};
|
||||
return detail::sm100_dense_dispatch_policy<Schedule, ElementC_, ElementD, EpiTiles, FragmentSize>();
|
||||
}
|
||||
}
|
||||
|
||||
@ -1228,6 +1304,87 @@ public:
|
||||
>;
|
||||
};
|
||||
|
||||
template<
|
||||
class OpClass,
|
||||
class MmaTileShape_MNK,
|
||||
class EpilogueTileType,
|
||||
class ElementAccumulator_,
|
||||
class ElementC,
|
||||
class ElementD,
|
||||
class Schedule,
|
||||
class GmemStrideTypeC,
|
||||
class GmemStrideTypeD,
|
||||
bool IsPerColScaleSupported,
|
||||
bool IsBlockScaleSupported
|
||||
>
|
||||
struct Sm100EpilogueDescriptor {
|
||||
using ElementAccumulator = ElementAccumulator_;
|
||||
|
||||
static constexpr bool Is2SmMma = is_base_of_v<TmaWarpSpecialized2Sm, Schedule> || is_base_of_v<NoSmemWarpSpecialized2Sm, Schedule>;
|
||||
using CtaTileShape_MNK = decltype(sm100_cta_tile_shape<Is2SmMma, MmaTileShape_MNK>());
|
||||
using TileShape = CtaTileShape_MNK;
|
||||
|
||||
using TmemWarpShape_MN = decltype(detail::sm100_tmem_warps<Is2SmMma, MmaTileShape_MNK>());
|
||||
|
||||
using EpilogueTile = decltype(
|
||||
sm100_dense_compute_tile_shape_or_override<OpClass, CtaTileShape_MNK, EpilogueTileType,
|
||||
TmemWarpShape_MN, ElementC, GmemStrideTypeC, ElementD, GmemStrideTypeD, IsPerColScaleSupported>()
|
||||
);
|
||||
|
||||
using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile{})));
|
||||
static constexpr int EpiTiles = size(shape_div(take<0,2>(CtaTileShape_MNK{}), EpilogueTileShape_MN{}));
|
||||
static constexpr int FragmentSize = size(EpilogueTileShape_MN{}) / NumThreadsPerWarpGroup;
|
||||
|
||||
using DispatchPolicy = decltype(sm100_dense_dispatch_policy<Schedule, ElementC, ElementD, EpiTiles, FragmentSize>());
|
||||
|
||||
constexpr static int StagesC = DispatchPolicy::StagesC;
|
||||
constexpr static int StagesD = DispatchPolicy::StagesD;
|
||||
|
||||
using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTileShape_MN{}, TmemWarpShape_MN{}));
|
||||
using AccLoadOp = decltype(detail::sm100_get_tmem_load_op<
|
||||
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN,
|
||||
IsBlockScaleSupported
|
||||
>());
|
||||
};
|
||||
|
||||
// Get Stride, SmemLayout, and CopyOpS2R for AuxLoad node
|
||||
template<
|
||||
typename EpilogueDescriptor,
|
||||
typename StrideOrLayoutTag,
|
||||
typename ElementAux
|
||||
>
|
||||
struct Sm100AuxLoadDescriptor {
|
||||
constexpr static int Stages = EpilogueDescriptor::StagesC;
|
||||
using EpilogueTile = typename EpilogueDescriptor::EpilogueTile;
|
||||
using Element = ElementAux;
|
||||
using Stride = cutlass::detail::TagToStrideC_t<StrideOrLayoutTag>;
|
||||
|
||||
using SmemLayoutAtom = decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom<
|
||||
Stride, ElementAux, EpilogueTile>());
|
||||
|
||||
using CopyOpS2R = decltype(detail::sm100_get_smem_load_op<
|
||||
Stride, ElementAux, typename EpilogueDescriptor::ElementAccumulator, typename EpilogueDescriptor::AccLoadOp>());
|
||||
};
|
||||
|
||||
// Get Stride, SmemLayout, and CopyOpR2S for AuxStore node
|
||||
template<
|
||||
typename EpilogueDescriptor,
|
||||
typename StrideOrLayoutTag,
|
||||
typename ElementAux
|
||||
>
|
||||
struct Sm100AuxStoreDescriptor {
|
||||
constexpr static int Stages = EpilogueDescriptor::StagesD;
|
||||
using EpilogueTile = typename EpilogueDescriptor::EpilogueTile;
|
||||
using Element = ElementAux;
|
||||
using Stride = cutlass::detail::TagToStrideC_t<StrideOrLayoutTag>;
|
||||
|
||||
using SmemLayoutAtom = decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom<
|
||||
Stride, ElementAux, EpilogueTile>());
|
||||
|
||||
using CopyOpR2S = decltype(detail::sm100_get_smem_store_op<
|
||||
Stride, ElementAux, typename EpilogueDescriptor::ElementAccumulator, typename EpilogueDescriptor::AccLoadOp>());
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -1304,17 +1461,7 @@ private:
|
||||
}
|
||||
}
|
||||
using CtaTileShape_MNK = decltype(cta_tile_shape());
|
||||
|
||||
static constexpr auto
|
||||
tmem_warps() {
|
||||
if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) {
|
||||
return Shape<_2,_2>{};
|
||||
}
|
||||
else {
|
||||
return Shape<_4,_1>{};
|
||||
}
|
||||
}
|
||||
using TmemWarpShape_MN = decltype(tmem_warps());
|
||||
using TmemWarpShape_MN = decltype(detail::sm100_tmem_warps<Is2SmMma, MmaTileShape_MNK>());
|
||||
|
||||
static constexpr auto
|
||||
epilogue_tile() {
|
||||
@ -1338,20 +1485,15 @@ private:
|
||||
|
||||
using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTile{}, TmemWarpShape_MN{}));
|
||||
using AccLoadOp = decltype(detail::sm100_get_tmem_load_op<
|
||||
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>());
|
||||
GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN,
|
||||
FusionOp::IsBlockScaleSupported
|
||||
>());
|
||||
static constexpr int FragmentSize = size(EpilogueTile{}) / NumThreadsPerWarpGroup;
|
||||
|
||||
static constexpr auto
|
||||
dispatch_policy() {
|
||||
if constexpr (std::is_base_of_v<PtrArrayNoSmemWarpSpecialized1Sm, EpilogueScheduleType> ||
|
||||
std::is_base_of_v<PtrArrayNoSmemWarpSpecialized2Sm, EpilogueScheduleType>) {
|
||||
return Sm100PtrArrayNoSmemWarpSpecialized{};
|
||||
}
|
||||
else {
|
||||
return Sm100NoSmemWarpSpecialized{};
|
||||
}
|
||||
}
|
||||
using DispatchPolicy = decltype(dispatch_policy());
|
||||
using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile{})));
|
||||
static constexpr int EpiTiles = size(shape_div(take<0,2>(CtaTileShape_MNK{}), EpilogueTileShape_MN{}));
|
||||
|
||||
using DispatchPolicy = decltype(detail::sm100_dense_dispatch_policy<EpilogueScheduleType, ElementC_, ElementD, EpiTiles, FragmentSize>());
|
||||
|
||||
static constexpr auto
|
||||
fusion_callbacks() {
|
||||
|
||||
@ -507,8 +507,7 @@ public:
|
||||
int thread_idx,
|
||||
TensorStorage& shared_tensors,
|
||||
TensorMapC const& load_tensormap,
|
||||
int subtile_idx=-1,
|
||||
bool wait_until_load_finishes = false) {
|
||||
int subtile_idx=-1) {
|
||||
using namespace cute;
|
||||
|
||||
// Indexing variables
|
||||
@ -595,12 +594,6 @@ public:
|
||||
// Post-loop fusion callback entry point
|
||||
pld_callbacks.end();
|
||||
|
||||
if (wait_until_load_finishes && did_load) {
|
||||
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state =
|
||||
{last_load_producer_state.index(), !last_load_producer_state.phase(), last_load_producer_state.count()};
|
||||
load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state);
|
||||
}
|
||||
|
||||
return load_pipe_producer_state;
|
||||
}
|
||||
|
||||
|
||||
@ -0,0 +1,274 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/collective/builders/sm100_common.inl"
|
||||
|
||||
#include "cutlass/gemm/collective/collective_builder_decl.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
|
||||
template <
|
||||
int CapacityBytes,
|
||||
class ElementA,
|
||||
class ElementB,
|
||||
class TileShapeMNK,
|
||||
class TileShapeSFA,
|
||||
class TileShapeSFB,
|
||||
int stages
|
||||
>
|
||||
constexpr int
|
||||
sm100_compute_stage_count_or_override_blockscaled_mixed_tma_cpasync(StageCount<stages> stage_count) {
|
||||
return stages;
|
||||
}
|
||||
|
||||
template <
|
||||
int CapacityBytes,
|
||||
class ElementA,
|
||||
class ElementB,
|
||||
class TileShapeMNK,
|
||||
class TileShapeSFA,
|
||||
class TileShapeSFB,
|
||||
int carveout_bytes
|
||||
>
|
||||
constexpr int
|
||||
sm100_compute_stage_count_or_override_blockscaled_mixed_tma_cpasync(StageCountAutoCarveout<carveout_bytes> stage_count) {
|
||||
// For MXF8F6F4 MMA, ElementA/B will be passed in as uint8_t
|
||||
// Each stage include (CollectiveMma::SharedStorage)
|
||||
// 1. smem for A and smem for B (CollectiveMma::SharedStorage::TensorStorage)
|
||||
// 2. one MainloopPipeline = PipelineTmaUmmaAsync (CollectiveMma::SharedStorage::SharedStorage)
|
||||
// 3. smem for SFB and smem for SFB (CollectiveMma::SharedStorage::TensorStorage, independent of input size b.c. sizeof(sf) is fixed)
|
||||
constexpr auto mainloop_pipeline_bytes =
|
||||
sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage) +
|
||||
sizeof(typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage);
|
||||
|
||||
constexpr auto a_bits = cute::sizeof_bits_v<ElementA>;
|
||||
constexpr auto b_bits = cute::sizeof_bits_v<ElementB>;
|
||||
constexpr auto stage_sfa_bytes = size(filter_zeros(TileShapeSFA{}));
|
||||
constexpr auto stage_sfb_bytes = size(filter_zeros(TileShapeSFB{}));
|
||||
|
||||
constexpr int stage_bytes =
|
||||
cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
|
||||
cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
|
||||
static_cast<int>(mainloop_pipeline_bytes + stage_sfa_bytes + stage_sfb_bytes);
|
||||
|
||||
return (CapacityBytes - carveout_bytes) / stage_bytes;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
class ElementPairA,
|
||||
class GmemLayoutATag,
|
||||
int AlignmentA,
|
||||
class ElementPairB,
|
||||
class GmemLayoutBTag,
|
||||
int AlignmentB,
|
||||
class ElementAccumulator,
|
||||
class TileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class StageCountType,
|
||||
class BuilderScheduleTag
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm100,
|
||||
arch::OpClassBlockScaledTensorOp,
|
||||
ElementPairA,
|
||||
GmemLayoutATag,
|
||||
AlignmentA,
|
||||
ElementPairB,
|
||||
GmemLayoutBTag,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK)
|
||||
ClusterShape_MNK, // Static cluster shape (_1, _1, _1)
|
||||
StageCountType,
|
||||
BuilderScheduleTag,
|
||||
cute::enable_if_t<cute::is_same_v<KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100, BuilderScheduleTag> >
|
||||
>
|
||||
{
|
||||
using ElementSFA = typename detail::blockscaled::blockscaled_type<BuilderScheduleTag, ElementPairA>::sf_type;
|
||||
using ElementSFB = typename detail::blockscaled::blockscaled_type<BuilderScheduleTag, ElementPairB>::sf_type;
|
||||
using ElementA = typename detail::blockscaled::blockscaled_type<BuilderScheduleTag, ElementPairA>::data_type;
|
||||
using ElementB = typename detail::blockscaled::blockscaled_type<BuilderScheduleTag, ElementPairB>::data_type;
|
||||
using ElementSF = ElementSFA;
|
||||
|
||||
static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A<GmemLayoutATag>();
|
||||
static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B<GmemLayoutBTag>();
|
||||
|
||||
static_assert(cute::is_static_v<TileShape_MNK>, "TileShape has to be static");
|
||||
static_assert(detail::blockscaled::check_input_datatypes<BuilderScheduleTag, ElementPairA, ElementPairB, UmmaMajorA, UmmaMajorB>(), "Incorrect input types");
|
||||
|
||||
static constexpr bool is_2sm = false; // detail::blockscaled::is_2sm<TileShape_MNK, ClusterShape_MNK, BuilderScheduleTag>();
|
||||
static constexpr auto Instr = detail::blockscaled::select_instr<ElementPairA, ElementPairB, ElementAccumulator, UmmaMajorA, UmmaMajorB, BuilderScheduleTag>();
|
||||
|
||||
using TiledMma = typename cutlass::gemm::collective::detail::TrivialBlockscaledMma<ElementPairA, ElementPairB, ElementAccumulator,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
UmmaMajorA, UmmaMajorB, Instr, BuilderScheduleTag, is_2sm>::type;
|
||||
|
||||
static constexpr bool UseMxf8f6f4 = Instr == detail::blockscaled::BlockScaledInstr::MXF4F6F8;
|
||||
|
||||
static_assert(UseMxf8f6f4 || (cutlass::gemm::detail::is_k_major_A<GmemLayoutATag>() && cutlass::gemm::detail::is_k_major_B<GmemLayoutBTag>()), "Only MMA.MXF8F6F4 supports non-K major inputs");
|
||||
|
||||
// Data type used by MMA instruction
|
||||
using ElementAMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element<ElementA, UseMxf8f6f4>());
|
||||
using ElementBMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element<ElementB, UseMxf8f6f4>());
|
||||
|
||||
static_assert(detail::sm1xx_gemm_check_for_f8f6f4_mix8bit_requirement<ElementAMma, ElementBMma,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
GmemLayoutATag, GmemLayoutBTag, false /*is_sparse*/, is_2sm>(),
|
||||
"TileSize and MNK Major does not met with MMA Mix 8-bit TMA load requirement" );
|
||||
|
||||
static constexpr uint32_t SFVectorSize = TiledMma::SFVecSize;
|
||||
|
||||
using ElementAMma_SmemAllocType = cute::conditional_t<UseMxf8f6f4, uint8_t, ElementAMma>;
|
||||
using ElementBMma_SmemAllocType = cute::conditional_t<UseMxf8f6f4, uint8_t, ElementBMma>;
|
||||
|
||||
// using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma<
|
||||
// ElementAMma, ElementBMma, ElementAccumulator,
|
||||
// decltype(cute::product_each(TileShape_MNK{})), ClusterShape_MNK,
|
||||
// UmmaMajorA, UmmaMajorB, BuilderScheduleTag>());
|
||||
|
||||
using AtomThrID = typename TiledMma::AtomThrID;
|
||||
using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig<SFVectorSize>;
|
||||
|
||||
// ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K)
|
||||
using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}),
|
||||
cute::size<2>(TileShape_MNK{}))));
|
||||
// ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K)
|
||||
using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}),
|
||||
cute::size<2>(TileShape_MNK{}))));
|
||||
|
||||
// Assigning 4 warps for mainloop load of B
|
||||
static constexpr int NumLoadThreadsCpAsync = 128;
|
||||
|
||||
|
||||
using SmemShapeA_M = decltype(shape_div(shape<0>(TileShape_MNK{}), shape_div(shape<0>(TileShape_MNK{}), size<0>(TileShape_MNK{}) / size(AtomThrID{}))));
|
||||
using SmemShapeA_K = decltype(cute::get<2>(TileShape_MNK{}));
|
||||
|
||||
|
||||
using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A(
|
||||
ClusterShape_MNK{}, AtomThrID{}));
|
||||
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
|
||||
UmmaMajorA, ElementAMma_SmemAllocType, SmemShapeA_M, SmemShapeA_K>());
|
||||
|
||||
using AlignmentTypeB = cute::uint_byte_t<static_cast<int>(cutlass::sizeof_bits<ElementB>::value) * AlignmentB / 8>;
|
||||
using GmemCopyAtomB = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<AlignmentTypeB>, ElementB>;
|
||||
using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy<
|
||||
GmemCopyAtomB, NumLoadThreadsCpAsync, AlignmentB, TagToStrideB_t<GmemLayoutBTag>,
|
||||
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{}));
|
||||
using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{}));
|
||||
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
|
||||
UmmaMajorB, ElementBMma_SmemAllocType, BlockTileB_N, BlockTileB_K>());
|
||||
|
||||
using GmemTiledCopySFA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A(
|
||||
ClusterShape_MNK{}, AtomThrID{}));
|
||||
using GmemTiledCopySFB = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_SFB(
|
||||
ClusterShape_MNK{}, AtomThrID{}));
|
||||
|
||||
using GmemTiledCopyPairA = decltype(cute::make_tuple(GmemTiledCopyA{}, GmemTiledCopySFA{}));
|
||||
using GmemTiledCopyPairB = decltype(cute::make_tuple(GmemTiledCopyB{}, GmemTiledCopySFB{}));
|
||||
|
||||
using SmemLayoutAtomSFA = decltype(Sm1xxBlkScaledConfig::deduce_smem_layoutSFA(TiledMma{}, TileShape_MNK{}));
|
||||
using SmemLayoutAtomSFB = decltype(Sm1xxBlkScaledConfig::deduce_smem_layoutSFB(TiledMma{}, TileShape_MNK{}));
|
||||
using SmemLayoutAtomsA = decltype(cute::make_tuple(SmemLayoutAtomA{}, SmemLayoutAtomSFA{}));
|
||||
using SmemLayoutAtomsB = decltype(cute::make_tuple(SmemLayoutAtomB{}, SmemLayoutAtomSFB{}));
|
||||
|
||||
|
||||
using StrideA = cutlass::gemm::TagToStrideA_t<GmemLayoutATag>;
|
||||
using StrideB = cutlass::gemm::TagToStrideB_t<GmemLayoutBTag>;
|
||||
using InternalStrideA = cute::remove_pointer_t<StrideA>;
|
||||
using InternalStrideB = cute::remove_pointer_t<StrideB>;
|
||||
using InternalLayoutSFA = decltype(Sm1xxBlkScaledConfig::deduce_layoutSFA());
|
||||
using InternalLayoutSFB = decltype(Sm1xxBlkScaledConfig::deduce_layoutSFB());
|
||||
using LayoutSFA = cute::conditional_t<cute::is_same_v<InternalStrideA, StrideA>, InternalLayoutSFA, InternalLayoutSFA *>;
|
||||
using LayoutSFB = cute::conditional_t<cute::is_same_v<InternalStrideB, StrideB>, InternalLayoutSFB, InternalLayoutSFB *>;
|
||||
using StridePairA = decltype(cute::make_tuple(StrideA{}, LayoutSFA{}));
|
||||
using StridePairB = decltype(cute::make_tuple(StrideB{}, LayoutSFB{}));
|
||||
|
||||
static constexpr uint32_t AccumulatorPipelineStageCount = 2;
|
||||
// Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding.
|
||||
static constexpr uint32_t SchedulerPipelineStageCount = AccumulatorPipelineStageCount + 1;
|
||||
|
||||
// AccumulatorPipeline = PipelineUmmaAsync
|
||||
static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount>::SharedStorage);
|
||||
// CLCPipeline = PipelineCLCFetchAsync
|
||||
static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, ClusterShape_MNK>::SharedStorage);
|
||||
// CLC (scheduler) response
|
||||
static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize;
|
||||
// Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage
|
||||
static constexpr auto KernelSmemCarveout = static_cast<int>( AccumulatorPipelineStorage +
|
||||
CLCPipelineStorage +
|
||||
CLCResponseStorage);
|
||||
// Reduce SMEM capacity available for buffers considering barrier allocations.
|
||||
static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout;
|
||||
|
||||
using SmemTileShape = cute::Shape<SmemShapeA_M, BlockTileB_N, SmemShapeA_K>;
|
||||
|
||||
static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled_mixed_tma_cpasync<
|
||||
Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{});
|
||||
static_assert(PipelineStages > 0, "Smem usage is too high. Can't create any SMEM buffers for A, B, SFA, and SFB.");
|
||||
|
||||
using CollectiveOp = cutlass::gemm::collective::CollectiveMma<
|
||||
cutlass::gemm::MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled<
|
||||
PipelineStages,
|
||||
SchedulerPipelineStageCount,
|
||||
AccumulatorPipelineStageCount,
|
||||
ClusterShape_MNK>,
|
||||
TileShape_MNK,
|
||||
cute::tuple<ElementA, ElementSF>,
|
||||
StridePairA,
|
||||
cute::tuple<ElementB, ElementSF>,
|
||||
StridePairB,
|
||||
TiledMma,
|
||||
GmemTiledCopyPairA,
|
||||
SmemLayoutAtomsA,
|
||||
void,
|
||||
cute::identity,
|
||||
GmemTiledCopyPairB,
|
||||
SmemLayoutAtomsB,
|
||||
void,
|
||||
cute::identity
|
||||
>;
|
||||
};
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -120,6 +120,7 @@ struct CollectiveBuilder<
|
||||
BuilderScheduleTag,
|
||||
cute::enable_if_t<
|
||||
// Blockscaled Gemm
|
||||
(not cute::is_same_v<KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100, BuilderScheduleTag>) &&
|
||||
(cute::is_base_of_v<KernelScheduleBlockScaledGemmSm100, BuilderScheduleTag> ||
|
||||
cute::is_same_v<KernelScheduleAuto, BuilderScheduleTag>)
|
||||
&&
|
||||
|
||||
@ -0,0 +1,171 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/collective/builders/sm100_common.inl"
|
||||
|
||||
#include "cutlass/gemm/collective/collective_builder_decl.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
class ElementA,
|
||||
class GmemLayoutATag,
|
||||
int AlignmentA,
|
||||
class ElementB,
|
||||
class GmemLayoutBTag,
|
||||
int AlignmentB,
|
||||
class ElementAccumulator,
|
||||
class TileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class StageCountType,
|
||||
class BuilderScheduleTag
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm100,
|
||||
arch::OpClassTensorOp,
|
||||
ElementA,
|
||||
GmemLayoutATag,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
GmemLayoutBTag,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK)
|
||||
ClusterShape_MNK, // Static cluster shape (_1, _1, _1)
|
||||
StageCountType,
|
||||
BuilderScheduleTag,
|
||||
cute::enable_if_t<cute::is_same_v<KernelMixedTmaCpAsyncWarpSpecialized1SmSm100, BuilderScheduleTag> >
|
||||
>
|
||||
{
|
||||
static_assert(cute::is_static_v<TileShape_MNK>, "TileShape has to be static");
|
||||
|
||||
static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A<GmemLayoutATag>();
|
||||
static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B<GmemLayoutBTag>();
|
||||
|
||||
// Data type used by MMA instruction
|
||||
using ElementAMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element<ElementA>());
|
||||
using ElementBMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element<ElementB>());
|
||||
|
||||
using ElementAMma_SmemAllocType = cute::conditional_t<cute::sizeof_bits_v<ElementAMma> < 8, uint8_t, ElementAMma>;
|
||||
using ElementBMma_SmemAllocType = cute::conditional_t<cute::sizeof_bits_v<ElementBMma> < 8, uint8_t, ElementBMma>;
|
||||
|
||||
using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma<
|
||||
ElementAMma, ElementBMma, ElementAccumulator,
|
||||
decltype(cute::product_each(TileShape_MNK{})), ClusterShape_MNK,
|
||||
UmmaMajorA, UmmaMajorB, BuilderScheduleTag>());
|
||||
|
||||
using AtomThrID = typename TiledMma::AtomThrID;
|
||||
|
||||
// ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K)
|
||||
using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}),
|
||||
cute::size<2>(TileShape_MNK{}))));
|
||||
// ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K)
|
||||
using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}),
|
||||
cute::size<2>(TileShape_MNK{}))));
|
||||
|
||||
// Assigning 4 warps for mainloop load of B
|
||||
static constexpr int NumLoadThreadsCpAsync = 128;
|
||||
|
||||
|
||||
using SmemShapeA_M = decltype(shape_div(shape<0>(TileShape_MNK{}), shape_div(shape<0>(TileShape_MNK{}), size<0>(TileShape_MNK{}) / size(AtomThrID{}))));
|
||||
using SmemShapeA_K = decltype(cute::get<2>(TileShape_MNK{}));
|
||||
|
||||
|
||||
using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A(
|
||||
ClusterShape_MNK{}, AtomThrID{}));
|
||||
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
|
||||
UmmaMajorA, ElementAMma_SmemAllocType, SmemShapeA_M, SmemShapeA_K>());
|
||||
|
||||
using AlignmentTypeB = cute::uint_byte_t<static_cast<int>(sizeof(ElementB)) * AlignmentB>;
|
||||
using GmemCopyAtomB = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<AlignmentTypeB>, ElementB>;
|
||||
using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy<
|
||||
GmemCopyAtomB, NumLoadThreadsCpAsync, AlignmentB, TagToStrideB_t<GmemLayoutBTag>,
|
||||
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{}));
|
||||
using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{}));
|
||||
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
|
||||
UmmaMajorB, ElementBMma_SmemAllocType, BlockTileB_N, BlockTileB_K>());
|
||||
|
||||
static constexpr uint32_t AccumulatorPipelineStageCount = 2;
|
||||
// Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding.
|
||||
static constexpr uint32_t SchedulerPipelineStageCount = AccumulatorPipelineStageCount + 1;
|
||||
|
||||
// AccumulatorPipeline = PipelineUmmaAsync
|
||||
static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount>::SharedStorage);
|
||||
// CLCPipeline = PipelineCLCFetchAsync
|
||||
static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, ClusterShape_MNK>::SharedStorage);
|
||||
// CLC (scheduler) response
|
||||
static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize;
|
||||
// Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage
|
||||
static constexpr auto KernelSmemCarveout = static_cast<int>( AccumulatorPipelineStorage +
|
||||
CLCPipelineStorage +
|
||||
CLCResponseStorage);
|
||||
// Reduce SMEM capacity available for buffers considering barrier allocations.
|
||||
static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout;
|
||||
|
||||
using SmemTileShape = cute::Shape<SmemShapeA_M, BlockTileB_N, SmemShapeA_K>;
|
||||
using MainloopPipelineStorage = typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage;
|
||||
|
||||
static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override<
|
||||
Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, MainloopPipelineStorage>(StageCountType{});
|
||||
|
||||
using CollectiveOp = cutlass::gemm::collective::CollectiveMma<
|
||||
cutlass::gemm::MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized<
|
||||
PipelineStages,
|
||||
SchedulerPipelineStageCount,
|
||||
AccumulatorPipelineStageCount,
|
||||
ClusterShape_MNK>,
|
||||
TileShape_MNK,
|
||||
ElementA,
|
||||
cutlass::gemm::TagToStrideA_t<GmemLayoutATag>,
|
||||
ElementB,
|
||||
cutlass::gemm::TagToStrideB_t<GmemLayoutBTag>,
|
||||
TiledMma,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
void,
|
||||
cute::identity,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
void,
|
||||
cute::identity
|
||||
>;
|
||||
};
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -184,6 +184,7 @@ struct CollectiveBuilder<
|
||||
not cute::is_complex_v<ElementA> && not cute::is_complex_v<ElementB> &&
|
||||
// Dense Gemm / PtrArrayDenseGemm
|
||||
(
|
||||
(not cute::is_same_v<KernelMixedTmaCpAsyncWarpSpecialized1SmSm100, BuilderScheduleTag>) &&
|
||||
(not cute::is_same_v<KernelWarpSpecialized1SmSm100, BuilderScheduleTag>) &&
|
||||
(cute::is_base_of_v<KernelScheduleSm100DenseGemm, BuilderScheduleTag> ||
|
||||
cute::is_same_v<KernelScheduleAuto, BuilderScheduleTag>)) &&
|
||||
|
||||
@ -502,6 +502,7 @@ check_input_datatypes() {
|
||||
|| (cute::is_same_v<BuilderScheduleTag, KernelScheduleBlockScaledGemmSm100>)
|
||||
|| (cute::is_same_v<BuilderScheduleTag, KernelTmaWarpSpecialized1SmBlockScaledSm100>)
|
||||
|| (cute::is_same_v<BuilderScheduleTag, KernelTmaWarpSpecialized2SmBlockScaledSm100>)
|
||||
|| (cute::is_same_v<BuilderScheduleTag, KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100>)
|
||||
// SM100 BS ptr_array
|
||||
|| (cute::is_same_v<BuilderScheduleTag, KernelSchedulePtrArrayBlockScaledGemmSm100>)
|
||||
|| (cute::is_same_v<BuilderScheduleTag, KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100>)
|
||||
@ -578,6 +579,8 @@ check_input_datatypes() {
|
||||
((SfVectorSizeA == 32 && cute::is_same_v<KernelScheduleAuto, BuilderScheduleTag>)
|
||||
|| (SfVectorSizeA == 32 && cute::is_same_v<KernelTmaWarpSpecializedPingpong, BuilderScheduleTag>)
|
||||
|| (SfVectorSizeA == 32 && cute::is_same_v<KernelTmaWarpSpecializedCooperative, BuilderScheduleTag>)
|
||||
|| (SfVectorSizeA == 32 && cute::is_same_v<KernelPtrArrayTmaWarpSpecializedPingpong, BuilderScheduleTag>)
|
||||
|| (SfVectorSizeA == 32 && cute::is_same_v<KernelPtrArrayTmaWarpSpecializedCooperative, BuilderScheduleTag>)
|
||||
|| (SfVectorSizeA == 32 && cute::is_base_of_v<KernelScheduleBlockScaledGemmSm100, BuilderScheduleTag>)
|
||||
|| (SfVectorSizeA == 32 && cute::is_base_of_v<KernelSchedulePtrArrayBlockScaledGemmSm100, BuilderScheduleTag>)
|
||||
|| (SfVectorSizeA == 64 && cute::is_base_of_v<KernelScheduleBlockScaledSparseGemmSm100, BuilderScheduleTag>)
|
||||
|
||||
@ -1069,10 +1069,10 @@ struct CollectiveBuilder<
|
||||
StageCountType,
|
||||
KernelScheduleType,
|
||||
cute::enable_if_t<
|
||||
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum> or
|
||||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum> or
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum> or
|
||||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum>) and
|
||||
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8Blockwise> or
|
||||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise> or
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8Blockwise> or
|
||||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise>) and
|
||||
not detail::is_use_rmem_A<ElementA, GmemLayoutPairA, ElementB, GmemLayoutPairB>()
|
||||
>
|
||||
> {
|
||||
@ -1105,7 +1105,7 @@ struct CollectiveBuilder<
|
||||
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, KernelScheduleType>);
|
||||
|
||||
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
|
||||
static_assert(IsFP8Input, "Warp Specialized gemm with FP8 BlockScaled Accumulator is only compatible with FP8 Blocked Scaled version right now.");
|
||||
static_assert(IsFP8Input, "Warp Specialized gemm with FP8 Blockwise (Software) Scaling is only compatible with FP8 inputs version right now.");
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type
|
||||
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
|
||||
@ -1146,8 +1146,8 @@ struct CollectiveBuilder<
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_with_blockwise_scale<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
|
||||
ElementAMma, ElementBMma, ElementBlockScale, TileShape_MNK, ScaleMsPerTile, ScaleNsPerTile>(StageCountType{});
|
||||
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
|
||||
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
|
||||
MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
@ -49,6 +49,8 @@
|
||||
#include "cutlass/gemm/collective/builders/sm100_simt_builder.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm120_mma_builder.inl"
|
||||
#include "cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl"
|
||||
|
||||
@ -65,6 +65,8 @@
|
||||
#include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp"
|
||||
#include "cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp"
|
||||
#include "cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/collective/sm120_mma_tma.hpp"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,758 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/detail/cluster.hpp"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/arch/memory.h"
|
||||
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
|
||||
#include "cutlass/gemm/collective/collective_mma_decl.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one
|
||||
template <
|
||||
int Stages,
|
||||
int SchedulerPipelineStageCount,
|
||||
int AccumulatorPipelineStageCount,
|
||||
class ClusterShape, // Static cluster shape or dynamic (int, int, _1)
|
||||
class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK)
|
||||
class ElementA_,
|
||||
class StrideA_,
|
||||
class ElementB_,
|
||||
class StrideB_,
|
||||
class TiledMma_,
|
||||
class GmemTiledCopyA_,
|
||||
class SmemLayoutAtomA_,
|
||||
class SmemCopyAtomA_,
|
||||
class TransformA_,
|
||||
class GmemTiledCopyB_,
|
||||
class SmemLayoutAtomB_,
|
||||
class SmemCopyAtomB_,
|
||||
class TransformB_>
|
||||
struct CollectiveMma<
|
||||
MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized<
|
||||
Stages,
|
||||
SchedulerPipelineStageCount,
|
||||
AccumulatorPipelineStageCount,
|
||||
ClusterShape>,
|
||||
TileShape_,
|
||||
ElementA_,
|
||||
StrideA_,
|
||||
ElementB_,
|
||||
StrideB_,
|
||||
TiledMma_,
|
||||
GmemTiledCopyA_,
|
||||
SmemLayoutAtomA_,
|
||||
SmemCopyAtomA_,
|
||||
TransformA_,
|
||||
GmemTiledCopyB_,
|
||||
SmemLayoutAtomB_,
|
||||
SmemCopyAtomB_,
|
||||
TransformB_>
|
||||
{
|
||||
using TiledMma = TiledMma_;
|
||||
using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMma::ThrLayoutVMNK{})), _1, _1>;
|
||||
|
||||
// Statically asserting to ensure only 1x1x1 cluster shape & 1sm setup is received
|
||||
static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment SM100 GEMM only supports 1SM MMA");
|
||||
static_assert(size(ClusterShape{}) == 1, "CPASYNC does not support multicast so the cluster shape is restricted to 1, 1, 1");
|
||||
|
||||
static_assert(size(typename TiledMma::AtomThrID{}) == 1);
|
||||
|
||||
using DispatchPolicy = MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized<
|
||||
Stages,
|
||||
SchedulerPipelineStageCount,
|
||||
AccumulatorPipelineStageCount,
|
||||
ClusterShape>;
|
||||
// TileShape refers to MmaTileShape to adapt for runtime cluster
|
||||
using TileShape = TileShape_;
|
||||
|
||||
CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})),
|
||||
"Static cluster shape used: TileShape should be evenly divided by TiledMma");
|
||||
|
||||
// Define A and B block shapes
|
||||
using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{}))));
|
||||
using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{}))));
|
||||
// using LoadShapeA_MK = decltype(select<0,2>(TileShape{}));
|
||||
using LoadShapeB_NK = decltype(select<1,2>(TileShape{}));
|
||||
|
||||
// CtaShape_MNK is queried from collective in all kernel layers
|
||||
using CtaShape_MNK = TileShape;
|
||||
|
||||
using ElementA = ElementA_;
|
||||
using ElementAMma = typename TiledMma::ValTypeA;
|
||||
using StrideA = StrideA_;
|
||||
using ElementB = ElementB_;
|
||||
using ElementBMma = typename TiledMma::ValTypeB;
|
||||
using StrideB = StrideB_;
|
||||
|
||||
static constexpr bool IsRuntimeDataTypeA = cute::is_same_v<ElementA, cutlass::type_erased_dynamic_float8_t>;
|
||||
static constexpr bool IsRuntimeDataTypeB = cute::is_same_v<ElementB, cutlass::type_erased_dynamic_float8_t>;
|
||||
|
||||
static_assert(IsRuntimeDataTypeA == IsRuntimeDataTypeB,
|
||||
"ElementA and ElementB should be both runtime or both static.");
|
||||
|
||||
static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB;
|
||||
|
||||
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
||||
using GmemTiledCopyA = GmemTiledCopyA_;
|
||||
using GmemTiledCopyB = GmemTiledCopyB_;
|
||||
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
||||
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
||||
using SmemCopyAtomA = SmemCopyAtomA_;
|
||||
using SmemCopyAtomB = SmemCopyAtomB_;
|
||||
using TransformA = TransformA_;
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
|
||||
using MainloopPipelineTMA = cutlass::PipelineTmaUmmaAsync<DispatchPolicy::Stages, ClusterShape, AtomThrShapeMNK>;
|
||||
using MainloopPipelineTMAState = typename MainloopPipelineTMA::PipelineState;
|
||||
|
||||
using MainloopPipelineCpAsync = cutlass::PipelineUmmaConsumerAsync<DispatchPolicy::Stages, AtomThrShapeMNK>;
|
||||
using MainloopPipelineCpAsyncState = typename MainloopPipelineCpAsync::PipelineState;
|
||||
|
||||
// static_assert(size(GmemTiledCopyA{}) == size(GmemTiledCopyB{}), "A and B GmemTiledCopy should share the same thread count");
|
||||
static constexpr int NumLoadThreadsCpAsync = size(GmemTiledCopyB{});
|
||||
|
||||
static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)");
|
||||
static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
||||
"SM100 UMMA cannot have a non-void copy atom for smem sourced instructions.");
|
||||
|
||||
static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)");
|
||||
static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
||||
"SM100 UMMA cannot have a non-void copy atom for smem sourced instructions.");
|
||||
|
||||
// Tile along K mode first before tiling over MN. PIPE mode last as usual.
|
||||
// (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE)
|
||||
using SmemLayoutA = decltype(UMMA::tile_to_mma_shape(
|
||||
SmemLayoutAtomA{},
|
||||
append(MmaShapeA_MK{}, Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
|
||||
|
||||
using MmaSmemLayoutB = decltype(UMMA::tile_to_mma_shape(
|
||||
SmemLayoutAtomB{},
|
||||
append(MmaShapeB_NK{}, Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
|
||||
using LoadSmemLayoutB = decltype(tile_to_shape(
|
||||
SmemLayoutAtomB{},
|
||||
append(LoadShapeB_NK{}, Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
|
||||
|
||||
// The mainloop needs at least two smem stages; the message now states the bound the condition enforces.
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
|
||||
static_assert(cute::is_base_of<cute::UMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
|
||||
cute::is_base_of<cute::UMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
|
||||
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
|
||||
|
||||
using TmaInternalElementA = cute::conditional_t<cute::is_same_v<ElementA, float>, cutlass::tfloat32_t, ElementAMma>;
|
||||
|
||||
using SmemAllocTypeA = cute::conditional_t<cute::sizeof_bits_v<ElementAMma> < 8, uint8_t, ElementAMma>;
|
||||
using SmemAllocTypeB = cute::conditional_t<cute::sizeof_bits_v<ElementBMma> < 8, uint8_t, ElementBMma>;
|
||||
|
||||
using BitTypeElementA = cute::uint_bit_t<cute::sizeof_bits_v<ElementA>>;
|
||||
using BitTypeElementB = cute::uint_bit_t<cute::sizeof_bits_v<ElementB>>;
|
||||
|
||||
using ArrayElementA = cute::conditional_t<IsRuntimeDataTypeA, BitTypeElementA, ElementA>;
|
||||
using ArrayElementB = cute::conditional_t<IsRuntimeDataTypeB, BitTypeElementB, ElementB>;
|
||||
|
||||
using RuntimeDataTypeA = cute::conditional_t<IsRuntimeDataTypeA, cute::UMMA::MXF8F6F4Format, void*>;
|
||||
using RuntimeDataTypeB = cute::conditional_t<IsRuntimeDataTypeB, cute::UMMA::MXF8F6F4Format, void*>;
|
||||
|
||||
struct SharedStorage {
|
||||
struct TensorStorage : cute::aligned_struct<128, _0> {
|
||||
cute::array_aligned<SmemAllocTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
|
||||
cute::array_aligned<SmemAllocTypeB, cute::cosize_v<LoadSmemLayoutB>> smem_B;
|
||||
} tensors;
|
||||
|
||||
using PipelineStorageTMA = typename MainloopPipelineTMA::SharedStorage;
|
||||
using PipelineStorageCpAsync = typename MainloopPipelineCpAsync::SharedStorage;
|
||||
|
||||
struct PipelineStorage : cute::aligned_struct<16, _0> {
|
||||
alignas(16) PipelineStorageTMA tma;
|
||||
alignas(16) PipelineStorageCpAsync cpasync;
|
||||
} pipelines;
|
||||
};
|
||||
|
||||
// Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them.
|
||||
using TensorStorage = typename SharedStorage::TensorStorage;
|
||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
||||
|
||||
static constexpr uint32_t TmaTransactionBytes =
|
||||
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v<ElementA>);
|
||||
|
||||
// Thin wrapper holding the accumulator tensor handed between the
// init_tmem_tensors / set_tmem_offsets / slice_accumulator helpers.
template <class AccTensor>
struct TmemStorage {
  AccTensor accumulators;
};
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
ArrayElementA const* ptr_A{nullptr};
|
||||
StrideA dA{};
|
||||
ArrayElementB const* ptr_B{nullptr};
|
||||
StrideB dB{};
|
||||
RuntimeDataTypeA runtime_data_type_a{};
|
||||
RuntimeDataTypeB runtime_data_type_b{};
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(ClusterShape{}),
|
||||
make_tile(typename TiledMma::AtomThrID{})));
|
||||
|
||||
using TMA_A = decltype(make_tma_atom_A_sm100<TmaInternalElementA>(
|
||||
GmemTiledCopyA{},
|
||||
make_tensor(recast_ptr<TmaInternalElementA>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
|
||||
SmemLayoutA{}(_,_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
TiledMma{},
|
||||
ClusterLayout_VMNK{})
|
||||
);
|
||||
|
||||
TMA_A tma_load_a;
|
||||
|
||||
ArrayElementB const* ptr_B{nullptr};
|
||||
StrideB dB{};
|
||||
|
||||
RuntimeDataTypeA runtime_data_type_a;
|
||||
RuntimeDataTypeB runtime_data_type_b;
|
||||
};
|
||||
|
||||
/// Device-side constructor.
/// Copies the runtime data-type selectors out of `params` and records the address
/// of the TMA load object for A. Only the pointer is stored, so `params` must
/// outlive this collective.
CUTLASS_DEVICE
CollectiveMma(Params const& params)
  : runtime_data_type_a_(params.runtime_data_type_a)
  , runtime_data_type_b_(params.runtime_data_type_b) {

  // Observe (do not copy) the TMA load object held in the kernel params.
  observed_tma_load_a_ = &params.tma_load_a;
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(
|
||||
ProblemShape const& problem_shape,
|
||||
Arguments const& args,
|
||||
[[maybe_unused]] void* workspace,
|
||||
cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) {
|
||||
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
auto ptr_A = recast_ptr<TmaInternalElementA>(args.ptr_A);
|
||||
auto ptr_B = recast_ptr<ElementBMma>(args.ptr_B);
|
||||
|
||||
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
|
||||
|
||||
auto cluster_layout_vmnk = tiled_divide(make_layout(ClusterShape{}), make_tile(typename TiledMma::AtomThrID{}));
|
||||
|
||||
typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100<TmaInternalElementA>(
|
||||
GmemTiledCopyA{},
|
||||
tensor_a,
|
||||
SmemLayoutA{}(_,_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
TiledMma{},
|
||||
cluster_layout_vmnk);
|
||||
|
||||
return {
|
||||
tma_load_a,
|
||||
args.ptr_B,
|
||||
args.dB,
|
||||
args.runtime_data_type_a,
|
||||
args.runtime_data_type_b
|
||||
};
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(
|
||||
ProblemShape const& problem_shape,
|
||||
[[maybe_unused]] Arguments const& args) {
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4<TiledMma, ElementA, ElementB>();
|
||||
constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits<ElementA, IsF8F6F4>();
|
||||
constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits<ElementA>::value;
|
||||
|
||||
bool implementable = true;
|
||||
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
}
|
||||
|
||||
implementable = implementable && cutlass::detail::check_alignment<GmemTiledCopyB::NumValSrc>(cute::make_shape(N,K,L), StrideB{});
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for CpAsync.\n");
|
||||
}
|
||||
|
||||
return implementable;
|
||||
}
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
CUTLASS_DEVICE void
|
||||
prefetch_tma_descriptors() {
|
||||
cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor());
|
||||
}
|
||||
|
||||
/// Construct A Single Stage's Accumulator Shape
|
||||
CUTLASS_DEVICE static
|
||||
auto
|
||||
partition_accumulator_shape() {
|
||||
auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N)
|
||||
|
||||
return acc_shape;
|
||||
}
|
||||
|
||||
template <class TmemStorage>
|
||||
CUTLASS_DEVICE static
|
||||
auto
|
||||
slice_accumulator(TmemStorage tmem_storage, int stage) {
|
||||
return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage));
|
||||
}
|
||||
|
||||
template <class EpilogueTile, bool IsOverlappingAccum = false>
|
||||
CUTLASS_DEVICE static
|
||||
auto
|
||||
init_tmem_tensors(EpilogueTile epi_tile) {
|
||||
TiledMma tiled_mma;
|
||||
auto acc_shape = partition_accumulator_shape();
|
||||
// ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue.
|
||||
Tensor accumulators = cutlass::detail::make_sm100_accumulator<AccumulatorPipelineStageCount, IsOverlappingAccum>(
|
||||
tiled_mma, acc_shape, EpilogueTile{});
|
||||
TmemStorage<decltype(accumulators)> tmem_storage;
|
||||
tmem_storage.accumulators = accumulators;
|
||||
return tmem_storage;
|
||||
}
|
||||
|
||||
template <class TmemStorage>
|
||||
CUTLASS_DEVICE static
|
||||
void
|
||||
set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) {
|
||||
tmem_storage.accumulators.data() = tmem_base_addr;
|
||||
}
|
||||
|
||||
/// Set up the data needed by this collective for load.
|
||||
/// Return tuple element contain
|
||||
/// gA_mkl - The tiled tensor for input A
|
||||
/// gB_nkl - The tiled tensor for input B
|
||||
/// tAsA - partitioned smem tensor for A
|
||||
/// tBsB - partitioned smem tensor for B
|
||||
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto
|
||||
load_init_tma(
|
||||
ProblemShape_MNKL const& problem_shape_MNKL,
|
||||
TensorStorage& shared_tensors) const {
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
// TMA
|
||||
Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L));
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l)
|
||||
|
||||
ThrMMA cta_mma = TiledMma{}.get_slice(0);
|
||||
Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l)
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
|
||||
// Define the CTA-in-cluster Layout and Coord
|
||||
Layout cta_layout_mnk = make_layout(ClusterShape{});
|
||||
Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{}));
|
||||
auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(0);
|
||||
|
||||
// Project the cta_layout for tma_a along the n-modes
|
||||
auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_,
|
||||
get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)),
|
||||
group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl));
|
||||
|
||||
return cute::make_tuple(
|
||||
shape<3>(gA_mkl), // for scheduler
|
||||
tAgA_mkl, tAsA // for input tensor values
|
||||
);
|
||||
}
|
||||
|
||||
template <class ProblemShape_MNKL, class TileScheduler>
|
||||
CUTLASS_DEVICE auto
|
||||
load_init_cpasync(
|
||||
ProblemShape_MNKL const& problem_shape_MNKL,
|
||||
Params const& params,
|
||||
TensorStorage& shared_tensors,
|
||||
TileScheduler const& scheduler,
|
||||
typename TileScheduler::WorkTileInfo const& work_tile_info) const {
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
// Represent the full tensors
|
||||
Tensor mB_nkl = make_tensor(make_gmem_ptr(params.ptr_B), make_shape(N,K,L), params.dB); //(n,k,l)
|
||||
// Partition for cpasync
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
// Build the coordinate tensors with the same shape as input matrices
|
||||
Tensor cB_nk = make_identity_tensor(make_shape(N,K));
|
||||
// Slice the coordinate tensors in the same way as A/B tensor partitioning
|
||||
Tensor cgB_nk = local_tile(cB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k)
|
||||
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), LoadSmemLayoutB{});
|
||||
|
||||
GmemTiledCopyB gmem_to_smem_b_tiled_copy;
|
||||
|
||||
int thread_idx = threadIdx.x % NumLoadThreadsCpAsync;
|
||||
auto thr_copy_b = gmem_to_smem_b_tiled_copy.get_slice(thread_idx);
|
||||
|
||||
return cute::make_tuple(
|
||||
gB_nkl, cgB_nk, sB,
|
||||
gmem_to_smem_b_tiled_copy, thr_copy_b);
|
||||
}
|
||||
|
||||
/// Set up the data needed by this collective for mma compute.
|
||||
template <class TmemStorage>
|
||||
CUTLASS_DEVICE auto
|
||||
mma_init(
|
||||
Params const& params,
|
||||
[[maybe_unused]] TmemStorage tmem_storage,
|
||||
// [[maybe_unused]] cute::tuple<cute::Tensor<FrgEngine, FrgLayout>, cute::Tensor<FrgEngine, FrgLayout>> const& accumulators_pair,
|
||||
TensorStorage& shared_tensors) const {
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), MmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
// Allocate "fragments/descriptors" for A and B matrices
|
||||
Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sA)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<3>(sB));
|
||||
|
||||
TiledMma tiled_mma;
|
||||
|
||||
if constexpr (IsRuntimeDataType) {
|
||||
// Update instruction descriptor according to runtime argument.
|
||||
// Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe.
|
||||
tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111;
|
||||
tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111;
|
||||
}
|
||||
|
||||
return cute::make_tuple(tiled_mma, tCrA, tCrB);
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Producer Perspective
|
||||
template <
|
||||
class KTileCount,
|
||||
class GTensorPartitionedA,
|
||||
class STensorA,
|
||||
class TileCoordMNKL,
|
||||
class KTileIterator
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
load_tma(
|
||||
MainloopPipelineTMA mainloop_pipeline,
|
||||
MainloopPipelineTMAState mainloop_pipe_producer_state,
|
||||
cute::tuple<KTileCount,
|
||||
GTensorPartitionedA,
|
||||
STensorA> const& load_inputs,
|
||||
TileCoordMNKL const& cta_coord_mnkl,
|
||||
KTileIterator k_tile_iter, int k_tile_count) {
|
||||
|
||||
// Unpack from load_inputs
|
||||
KTileCount k_tiles = get<0>(load_inputs);
|
||||
GTensorPartitionedA tAgA_mkl = get<1>(load_inputs);
|
||||
STensorA tAsA = get<2>(load_inputs);
|
||||
|
||||
// slice out the work coord from partitioned tensors
|
||||
Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl));
|
||||
|
||||
auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state);
|
||||
|
||||
// Issue the Mainloop loads
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
while (k_tile_count > 0) {
|
||||
// LOCK mainloop_pipe_producer_state for _writing_
|
||||
mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token);
|
||||
|
||||
using BarrierType = typename MainloopPipelineTMA::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state);
|
||||
|
||||
int write_stage = mainloop_pipe_producer_state.index();
|
||||
++mainloop_pipe_producer_state;
|
||||
barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state);
|
||||
|
||||
if (cute::elect_one_sync()) {
|
||||
copy(observed_tma_load_a_->with(*tma_barrier), tAgA(_,*k_tile_iter), tAsA(_,write_stage));
|
||||
}
|
||||
|
||||
--k_tile_count;
|
||||
++k_tile_iter;
|
||||
}
|
||||
|
||||
return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter);
|
||||
}
|
||||
|
||||
|
||||
template <
|
||||
// class GTensorB,
|
||||
// class CTensorB,
|
||||
// class STensorB,
|
||||
// class ProblemShape_MNKL,
|
||||
// class TiledCopyB,
|
||||
// class ThreadCopyB,
|
||||
class TileCoordMNKL,
|
||||
class KTileIterator,
|
||||
class ProblemShape_MNKL,
|
||||
class... TParams
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
load_cpasync(
|
||||
Params const& params,
|
||||
MainloopPipelineCpAsync mainloop_pipeline,
|
||||
MainloopPipelineCpAsyncState mainloop_pipe_producer_state,
|
||||
cute::tuple<TParams...> const& load_inputs,
|
||||
TileCoordMNKL const& cta_coord_mnkl,
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
ProblemShape_MNKL effective_shape
|
||||
) {
|
||||
|
||||
// Unpack from load_inputs
|
||||
// GTensorB tBgB_nkl = get<0>(load_inputs);
|
||||
// CTensorB cgB_nk = get<1>(load_inputs);
|
||||
// STensorB sB = get<2>(load_inputs);
|
||||
// ProblemShape_MNKL problem_shape_MNKL = get<3>(load_inputs);
|
||||
// TiledCopyB gmem_to_smem_b_tiled_copy = get<4>(load_inputs);
|
||||
// ThreadCopyB thr_copy_b = get<5>(load_inputs);
|
||||
|
||||
auto [
|
||||
tBgB_nkl, cgB_nk, sB,
|
||||
// problem_shape_MNKL,
|
||||
gmem_to_smem_b_tiled_copy, thr_copy_b] = load_inputs;
|
||||
|
||||
auto [M,N,K,L] = effective_shape;
|
||||
|
||||
// Slice out the work coord from partitioned tensors
|
||||
Tensor gB_in = tBgB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl));
|
||||
// Repeat slicing out coordinate tensor exactly the same as input tensor does
|
||||
Tensor cgB_nk_in = cgB_nk(_, _, get<1>(cta_coord_mnkl), _);
|
||||
|
||||
auto k_residue = K - size<1>(gB_in) * size<2>(gB_in); // K - BLK_K * k is negative
|
||||
|
||||
Tensor gB = gB_in;
|
||||
Tensor cB = cgB_nk_in;
|
||||
|
||||
auto tBgB = thr_copy_b.partition_S(gB);
|
||||
auto tBsB = thr_copy_b.partition_D(sB);
|
||||
|
||||
// Allocate predicate tensors for n
|
||||
Tensor tBpB = make_tensor<bool>(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{});
|
||||
Tensor tBcB_nk = thr_copy_b.partition_S(cgB_nk_in);
|
||||
Tensor tBcB = thr_copy_b.partition_S(cB);
|
||||
|
||||
// Copy gmem to smem for *k_tile_iter, predicating for k residue
|
||||
Tensor tBgBk = tBgB(_,_,_,*k_tile_iter);
|
||||
|
||||
// Repeating on predicators with the same operations on tBgB
|
||||
Tensor tBcBk = tBcB(_,_,_,*k_tile_iter);
|
||||
|
||||
// Set predicates for n bounds
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < size<0>(tBpB); ++n) {
|
||||
tBpB(n,0) = elem_less(get<0>(tBcBk(0,n,0)), N); // blk_n coord < N
|
||||
}
|
||||
|
||||
// we will process the last tile after the mainloop
|
||||
if (k_residue != 0) {
|
||||
--k_tile_count;
|
||||
}
|
||||
|
||||
// Issue the Mainloop loads
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
while (k_tile_count > 0) {
|
||||
|
||||
mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state);
|
||||
int write_stage = mainloop_pipe_producer_state.index();
|
||||
|
||||
copy_if(gmem_to_smem_b_tiled_copy, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
|
||||
|
||||
mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive);
|
||||
--k_tile_count;
|
||||
++k_tile_iter;
|
||||
++mainloop_pipe_producer_state;
|
||||
}
|
||||
|
||||
// last tile with predication on k to account for residue
|
||||
// For performance consideration,
|
||||
// this predicated block for K-tail is only activated when there is k-residue
|
||||
if (k_residue != 0) {
|
||||
// LOCK mainloop_pipe_producer_state for _writing_
|
||||
mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state);
|
||||
int write_stage = mainloop_pipe_producer_state.index();
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k = 0; k < size<2>(tBsB); ++k) {
|
||||
if (int(get<1>(tBcBk(0,0,k))) >= 0) { // blk_k coord < K
|
||||
copy_if(gmem_to_smem_b_tiled_copy, tBpB(_,k), tBgB(_,_,k,*k_tile_iter), tBsB(_,_,k,write_stage));
|
||||
}
|
||||
else {
|
||||
clear(tBsB(_,_,k,write_stage));
|
||||
}
|
||||
}
|
||||
++k_tile_iter;
|
||||
--k_tile_count;
|
||||
|
||||
// UNLOCK mainloop_pipe_producer_state
|
||||
mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive);
|
||||
|
||||
// Advance mainloop_pipe_producer_state
|
||||
++mainloop_pipe_producer_state;
|
||||
}
|
||||
|
||||
return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter);
|
||||
}
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster
|
||||
CUTLASS_DEVICE void
|
||||
load_tail_tma(MainloopPipelineTMA mainloop_pipeline, MainloopPipelineTMAState mainloop_pipe_producer_state) {
|
||||
// Issue the epilogue waits
|
||||
// This helps avoid early exit of ctas in Cluster
|
||||
// Waits for all stages to either be released (all
|
||||
// Consumer UNLOCKs), or if the stage was never used
|
||||
// then would just be acquired since the phase was
|
||||
// still inverted from make_producer_start_state
|
||||
mainloop_pipeline.producer_tail(mainloop_pipe_producer_state);
|
||||
}
|
||||
/// Producer epilogue for the cp.async pipeline: forwards to the pipeline's producer_tail.
CUTLASS_DEVICE void
load_tail_cpasync(MainloopPipelineCpAsync mainloop_pipeline, MainloopPipelineCpAsyncState mainloop_pipe_producer_state) {
  mainloop_pipeline.producer_tail(mainloop_pipe_producer_state);
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Consumer Perspective
|
||||
template <
|
||||
class AccumulatorPipeline,
|
||||
class FrgEngine, class FrgLayout,
|
||||
class FragmentA, class FragmentB,
|
||||
class CtaTileCoord
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
mma(cute::tuple<MainloopPipelineTMA,
|
||||
MainloopPipelineCpAsync,
|
||||
AccumulatorPipeline> pipelines,
|
||||
cute::tuple<MainloopPipelineTMAState,
|
||||
MainloopPipelineCpAsyncState,
|
||||
typename AccumulatorPipeline::PipelineState> pipeline_states,
|
||||
cute::tuple<cute::Tensor<FrgEngine, FrgLayout>> const& accumulators_pair,
|
||||
cute::tuple<TiledMma, FragmentA, FragmentB> const& mma_inputs,
|
||||
CtaTileCoord cta_tile_coord,
|
||||
int k_tile_count
|
||||
) {
|
||||
static_assert(is_tmem<FrgEngine>::value, "Accumulator must be tmem resident.");
|
||||
static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)");
|
||||
auto accumulators = get<0>(accumulators_pair);
|
||||
auto [tiled_mma, tCrA, tCrB] = mma_inputs;
|
||||
|
||||
auto [mainloop_pipeline_tma, mainloop_pipeline_cpasync, accumulator_pipeline] = pipelines;
|
||||
auto [mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, accumulator_pipe_producer_state] = pipeline_states;
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
|
||||
// Wait for tmem accumulator buffer to become empty with a flipped phase
|
||||
accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
while (k_tile_count > 0) {
|
||||
mainloop_pipeline_tma.consumer_wait(mainloop_pipe_tma_consumer_state);
|
||||
mainloop_pipeline_cpasync.consumer_wait(mainloop_pipe_cpasync_consumer_state);
|
||||
|
||||
int read_stage_tma = mainloop_pipe_tma_consumer_state.index();
|
||||
int read_stage_cpasync = mainloop_pipe_cpasync_consumer_state.index();
|
||||
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M) x (V,N) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage_tma), tCrB(_,_,k_block,read_stage_cpasync), accumulators);
|
||||
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
|
||||
}
|
||||
|
||||
mainloop_pipeline_tma.consumer_release(mainloop_pipe_tma_consumer_state);
|
||||
mainloop_pipeline_cpasync.consumer_release(mainloop_pipe_cpasync_consumer_state);
|
||||
--k_tile_count;
|
||||
++mainloop_pipe_tma_consumer_state;
|
||||
++mainloop_pipe_cpasync_consumer_state;
|
||||
}
|
||||
|
||||
return cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state);
|
||||
}
|
||||
|
||||
protected:
|
||||
|
||||
typename Params::TMA_A const* observed_tma_load_a_{nullptr};
|
||||
RuntimeDataTypeA runtime_data_type_a_{};
|
||||
RuntimeDataTypeB runtime_data_type_b_{};
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -248,7 +248,7 @@ public:
|
||||
|
||||
using Load2MmaPipeline = cutlass::PipelineTmaUmmaAsync<
|
||||
DispatchPolicy::Load2TransformPipelineStageCount,
|
||||
ClusterShape,
|
||||
ClusterShape,
|
||||
AtomThrShapeMNK>;
|
||||
using Load2MmaPipelineState = typename Load2MmaPipeline::PipelineState;
|
||||
|
||||
@ -316,7 +316,7 @@ public:
|
||||
|
||||
using SmemLayoutACompute = decltype(UMMA::tile_to_mma_shape(
|
||||
SmemLayoutAtomACompute{},
|
||||
append(CtaShapeA_MK{}, Int<DispatchPolicy::Load2TransformPipelineStageCount>{}),
|
||||
append(CtaShapeA_MK{}, Int<DispatchPolicy::Transform2MmaPipelineStageCount>{}),
|
||||
(cute::conditional_t<cutlass::gemm::detail::is_mn_major<StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})));
|
||||
|
||||
using SmemLayoutB = decltype(UMMA::tile_to_mma_shape(
|
||||
@ -385,7 +385,7 @@ public:
|
||||
|
||||
struct TensorStorageUntransformed {
|
||||
alignas(512) cute::ArrayEngine<ElementA, cute::cosize_v<SmemLayoutA>> smem_A;
|
||||
cute::ArrayEngine<ElementB, cute::cosize_v<SmemLayoutB>> smem_B;
|
||||
alignas(1024) cute::ArrayEngine<ElementB, cute::cosize_v<SmemLayoutB>> smem_B;
|
||||
cute::ArrayEngine<NonVoidElementScale, scale_elements> smem_scale;
|
||||
cute::ArrayEngine<NonVoidElementZero, zero_elements> smem_zero;
|
||||
};
|
||||
|
||||
@ -73,7 +73,7 @@ template <
|
||||
class SmemCopyAtomB_,
|
||||
class TransformB_>
|
||||
struct CollectiveMma<
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling<Stages, ClusterShape, KernelSchedule>,
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise<Stages, ClusterShape, KernelSchedule>,
|
||||
TileShape_,
|
||||
ElementA_,
|
||||
StridePairA_,
|
||||
@ -92,7 +92,7 @@ struct CollectiveMma<
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling<Stages, ClusterShape, KernelSchedule>;
|
||||
using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise<Stages, ClusterShape, KernelSchedule>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = cute::tuple_element_t<0,StridePairA_>;
|
||||
@ -382,8 +382,6 @@ struct CollectiveMma<
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), InternalStrideA{});
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), InternalStrideB{});
|
||||
// We expect full tiles in K
|
||||
implementable = implementable && K % size<2>(TileShape{}) == 0;
|
||||
}
|
||||
}
|
||||
|
||||
@ -824,16 +822,13 @@ struct CollectiveMma<
|
||||
// Prologue GMMAs
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
// fence_operand();
|
||||
GmmaFP8Accumulation accumulation(accum, ScalePromotionInterval, size<2>(tCrA));
|
||||
|
||||
warpgroup_fence_operand(accumulation());
|
||||
|
||||
{
|
||||
if (k_tile_count > 0) {
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
// Load per block scale values from shared memory to registers
|
||||
@ -977,7 +972,7 @@ struct CollectiveMma<
|
||||
++smem_pipe_release;
|
||||
}
|
||||
|
||||
if (k_tile_count) {
|
||||
if (k_tile_count > 0) {
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
//
|
||||
@ -1072,9 +1067,11 @@ struct CollectiveMma<
|
||||
/// Perform a Consumer Epilogue to release all buffers
|
||||
CUTLASS_DEVICE void
|
||||
mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
|
||||
// The pipeline is not released in the first iteration
|
||||
smem_pipe_release.advance(k_tile_count - 1);
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
if (k_tile_count > 0) {
|
||||
// The pipeline is not released in the first iteration
|
||||
smem_pipe_release.advance(k_tile_count - 1);
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@ -73,7 +73,7 @@ template <
|
||||
class SmemCopyAtomB_,
|
||||
class TransformB_>
|
||||
struct CollectiveMma<
|
||||
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule>,
|
||||
MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8<Stages, ClusterShape, KernelSchedule>,
|
||||
TileShape_,
|
||||
ElementA_,
|
||||
StridePairA_,
|
||||
@ -91,7 +91,7 @@ struct CollectiveMma<
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<Stages, ClusterShape, KernelSchedule>;
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8<Stages, ClusterShape, KernelSchedule>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = cute::tuple_element_t<0,StridePairA_>;
|
||||
@ -391,12 +391,6 @@ struct CollectiveMma<
|
||||
implementable = false;
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size doesn't meet the minimum alignment requirements for using TMA to load scale B.\n");
|
||||
}
|
||||
|
||||
// We expect full tiles in K
|
||||
if (K % size<2>(TileShape{}) != 0) {
|
||||
implementable = false;
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem size K is incompatible with tile size.\n");
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
|
||||
@ -127,10 +127,15 @@ struct KernelPtrArrayTmaWarpSpecializedCooperative { };
|
||||
struct KernelPtrArrayTmaWarpSpecializedPingpong { };
|
||||
|
||||
// FP8 related policies (including Blocked Scaled Accumulation)
|
||||
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelTmaWarpSpecializedCooperative { };
|
||||
struct KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum: KernelTmaWarpSpecializedPingpong { };
|
||||
struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelPtrArrayTmaWarpSpecializedCooperative { };
|
||||
struct KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum: KernelPtrArrayTmaWarpSpecializedPingpong { };
|
||||
// FP8 Blockwise schedule tags. Each is an empty struct that only derives from the
// correspondingly named non-FP8 schedule (cooperative / pingpong, regular / ptr-array).
struct KernelTmaWarpSpecializedCooperativeFP8Blockwise: KernelTmaWarpSpecializedCooperative { };
|
||||
struct KernelTmaWarpSpecializedPingpongFP8Blockwise: KernelTmaWarpSpecializedPingpong { };
|
||||
struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise: KernelPtrArrayTmaWarpSpecializedCooperative { };
|
||||
struct KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise: KernelPtrArrayTmaWarpSpecializedPingpong { };
|
||||
|
||||
// The *FP8BlockScaledAccum names are plain aliases of the *FP8Blockwise tags declared just above,
// so either spelling names the same type.
using KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum = KernelTmaWarpSpecializedCooperativeFP8Blockwise;
|
||||
using KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum = KernelTmaWarpSpecializedPingpongFP8Blockwise;
|
||||
using KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum = KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise;
|
||||
using KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum = KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
|
||||
|
||||
// Policies to opt into mixed type GEMMs
|
||||
struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { };
|
||||
@ -319,17 +324,17 @@ struct MainloopSm90TmaGmmaWarpSpecializedFP8
|
||||
|
||||
|
||||
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule
|
||||
// For FP8 kernels with Block Scaling
|
||||
// For FP8 kernels with Blockwise (Software) Scaling
|
||||
// Dispatch policy that derives from MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule>
// and adds only a static_assert restricting KernelSchedule to the cooperative or pingpong FP8 tag.
// NOTE(review): this listing carries no +/- markers; the duplicated default argument, struct name and
// is_same_v operands look like the BlockScaling/BlockScaledAccum spelling followed by its Blockwise
// replacement -- confirm against the upstream file.
template<
|
||||
int Stages_,
|
||||
class ClusterShape_ = Shape<_1,_1,_1>,
|
||||
class KernelSchedule = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum
|
||||
class KernelSchedule = KernelTmaWarpSpecializedCooperativeFP8Blockwise
|
||||
>
|
||||
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8
|
||||
struct MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8
|
||||
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
|
||||
static_assert(
|
||||
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum> ||
|
||||
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum>,
|
||||
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperativeFP8Blockwise> ||
|
||||
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedPingpongFP8Blockwise>,
|
||||
"KernelSchedule must be one of the warp specialized FP8 block scale policies");
|
||||
};
|
||||
|
||||
@ -411,15 +416,15 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput {
|
||||
// Dispatch policy that derives from MainloopSm90ArrayTmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule>
// and adds only a static_assert (via cute::is_any_of_v) restricting KernelSchedule to the ptr-array
// cooperative or pingpong FP8 tag.
// NOTE(review): this listing carries no +/- markers; the duplicated default argument, struct name and
// is_any_of_v operands look like the BlockScaling/BlockScaledAccum spelling followed by its Blockwise
// replacement -- confirm against the upstream file.
template<
|
||||
int Stages_,
|
||||
class ClusterShape_ = Shape<_1,_1,_1>,
|
||||
class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum
|
||||
class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise
|
||||
>
|
||||
struct MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling
|
||||
struct MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise
|
||||
: MainloopSm90ArrayTmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
|
||||
static_assert(
|
||||
cute::is_any_of_v<
|
||||
KernelSchedule,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum,
|
||||
KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum
|
||||
KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise,
|
||||
KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise
|
||||
>,
|
||||
"KernelSchedule must be one of the warp specialized FP8 block scale policies");
|
||||
};
|
||||
@ -440,6 +445,15 @@ struct KernelWarpSpecializedSm100 final {
|
||||
static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_;
|
||||
};
|
||||
|
||||
// Schedule tag that re-exports its two template arguments (scheduler and accumulator pipeline
// stage counts) as static compile-time constants. It has no other members.
template<
|
||||
int SchedulerPipelineStageCount_,
|
||||
int AccumulatorPipelineStageCount_
|
||||
>
|
||||
struct KernelMixedTmaCpAsyncWarpSpecializedSm100 final {
|
||||
static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
|
||||
static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_;
|
||||
};
|
||||
|
||||
template<
|
||||
int SchedulerPipelineStageCount_,
|
||||
int AccumulatorPipelineStageCount_
|
||||
@ -653,7 +667,7 @@ template<
|
||||
class KernelSchedule
|
||||
>
|
||||
struct HasAuxiliaryLoad<
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling<
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecializedBlockwise<
|
||||
Stages,
|
||||
ClusterShape,
|
||||
KernelSchedule
|
||||
@ -666,7 +680,7 @@ template<
|
||||
class KernelSchedule
|
||||
>
|
||||
struct HasAuxiliaryLoad<
|
||||
MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<
|
||||
MainloopSm90TmaGmmaWarpSpecializedBlockwiseFP8<
|
||||
Stages,
|
||||
ClusterShape,
|
||||
KernelSchedule
|
||||
@ -700,6 +714,7 @@ struct KernelScheduleSm100DenseGemm : KernelScheduleSm100 {}; // Base policy
|
||||
struct KernelTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; // Use for 1SM Dense GEMM Kernels for Collective Mainloop Builder
|
||||
struct KernelTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100DenseGemm {}; // Use for 2SM Dense GEMM Kernels for Collective Mainloop Builder
|
||||
struct KernelWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; // Use for 1SM Dense GEMM Kernels for Collective Mainloop Builder Without TMA
|
||||
struct KernelMixedTmaCpAsyncWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; // 1SM dense GEMM schedule tag; by its name, for the mixed TMA + CpAsync kernels (TODO confirm)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// SM100 Ptr-Array Dense GEMM Dispatch Policies
|
||||
@ -795,6 +810,8 @@ struct KernelTmaWarpSpecialized1SmMxf4Sm100 final : KernelSchedule1
|
||||
struct KernelTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, KernelScheduleMxNvf4Sm100 { };
|
||||
struct KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelScheduleMxf8f6f4Sm100 { };
|
||||
struct KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelScheduleMxf8f6f4Sm100 { };
|
||||
struct KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100 final : KernelSchedule1Sm, KernelScheduleBlockScaledGemmSm100 {}; // 1SM block-scaled GEMM schedule tag; by its name, for the mixed TMA + CpAsync kernels (TODO confirm)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// SM100 BlockScaled Ptr Array Dense GEMM Dispatch Policies
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -950,6 +967,34 @@ struct MainloopSm100UmmaCpAsyncWarpSpecialized {
|
||||
using Schedule = KernelWarpSpecializedSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
|
||||
};
|
||||
|
||||
// Mainloop dispatch policy tagged with arch::Sm100. It exposes the stage count and cluster shape it was
// instantiated with, and selects KernelMixedTmaCpAsyncWarpSpecializedSm100 (parameterized by the two
// pipeline stage counts) as its Schedule. IsOverlappingAccum is fixed to false.
template<
|
||||
int Stages_,
|
||||
int SchedulerPipelineStageCount_,
|
||||
int AccumulatorPipelineStageCount_,
|
||||
class ClusterShape_ = Shape<_1,_1,_1>
|
||||
>
|
||||
struct MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized {
|
||||
constexpr static int Stages = Stages_;
|
||||
using ClusterShape = ClusterShape_;
|
||||
using ArchTag = arch::Sm100;
|
||||
using Schedule = KernelMixedTmaCpAsyncWarpSpecializedSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
|
||||
constexpr static bool IsOverlappingAccum = false;
|
||||
};
|
||||
|
||||
// Block-scaled mainloop dispatch policy tagged with arch::Sm100. It exposes the stage count and cluster
// shape it was instantiated with, and selects KernelMixedTmaCpAsyncWarpSpecializedSm100 (parameterized by
// the two pipeline stage counts) as its Schedule. IsOverlappingAccum is fixed to false.
// NOTE(review): presumably kept as a distinct type so collectives can specialize on it -- confirm.
template<
|
||||
int Stages_,
|
||||
int SchedulerPipelineStageCount_,
|
||||
int AccumulatorPipelineStageCount_,
|
||||
class ClusterShape_ = Shape<_1,_1,_1>
|
||||
>
|
||||
struct MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled {
|
||||
constexpr static int Stages = Stages_;
|
||||
using ClusterShape = ClusterShape_;
|
||||
using ArchTag = arch::Sm100;
|
||||
using Schedule = KernelMixedTmaCpAsyncWarpSpecializedSm100<SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_>;
|
||||
constexpr static bool IsOverlappingAccum = false;
|
||||
};
|
||||
|
||||
// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule
|
||||
template<
|
||||
int Stages_,
|
||||
|
||||
@ -79,6 +79,16 @@ struct GroupProblemShape {
|
||||
}
|
||||
};
|
||||
|
||||
// Aggregate that pairs a problem shape with a second, independently typed "max" problem shape.
// NOTE(review): presumably max_problem_shape is an upper bound on problem_shape for MoE-style
// workloads -- confirm against the kernels that consume it.
template <class ProblemShape_, class MaxProblemShape_>
|
||||
struct MoEProblemShape {
|
||||
using UnderlyingProblemShape = ProblemShape_;
|
||||
using MaxProblemShape = MaxProblemShape_;
|
||||
|
||||
UnderlyingProblemShape problem_shape;
|
||||
MaxProblemShape max_problem_shape;
|
||||
};
|
||||
|
||||
|
||||
template <class ProblemShape_>
|
||||
class ArrayProblemShape {
|
||||
public:
|
||||
@ -120,4 +130,14 @@ private:
|
||||
UnderlyingProblemShape problem_shape_{};
|
||||
};
|
||||
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Type trait: derives from cute::true_type only when T is a cutlass::gemm::MoEProblemShape<...>
// specialization; every other type gets the primary template, which derives from cute::false_type.
template<class T>
|
||||
struct is_moe_problem_shape : cute::false_type {};
|
||||
template<class T, class U>
|
||||
struct is_moe_problem_shape<cutlass::gemm::MoEProblemShape<T,U>> : cute::true_type {};
|
||||
|
||||
}
|
||||
|
||||
} // namespace cutlass::gemm
|
||||
|
||||
@ -73,6 +73,7 @@ struct IsCutlass3ArrayKernel<ProblemShape, cute::void_t<typename ProblemShape::U
|
||||
#include "cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp"
|
||||
#include "cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/kernel/sm100_gemm_mixed_tma_cpasync_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp"
|
||||
#include "cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -240,6 +240,27 @@ public:
|
||||
void
|
||||
fixup(WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t, uint32_t = 1) const { }
|
||||
|
||||
// No-op fixup overload: every argument except acc_pipe_consumer_state is unnamed and ignored,
// and the accumulator pipeline consumer state is returned unchanged.
template <
|
||||
bool IsComplex,
|
||||
class TiledMma,
|
||||
class AccEngine,
|
||||
class AccLayout,
|
||||
class AccumulatorPipeline,
|
||||
class AccumulatorPipelineState,
|
||||
class CopyOpT2R
|
||||
>
|
||||
CUTLASS_DEVICE
|
||||
AccumulatorPipelineState
|
||||
fixup(
|
||||
TiledMma const& ,
|
||||
WorkTileInfo const&,
|
||||
cute::Tensor<AccEngine, AccLayout>&,
|
||||
AccumulatorPipeline,
|
||||
AccumulatorPipelineState acc_pipe_consumer_state,
|
||||
CopyOpT2R) const {
|
||||
// Nothing to fix up in this overload; hand the state straight back to the caller.
return acc_pipe_consumer_state;
|
||||
}
|
||||
|
||||
template <class ProblemShape, class ElementAccumulator>
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args, ProblemShape problem_shape, KernelHardwareInfo const& hw_info, uint32_t, uint32_t = 1, uint32_t = 1) {
|
||||
|
||||
@ -991,7 +991,7 @@ public:
|
||||
mainloop_sf_pipeline,
|
||||
mainloop_sf_pipe_producer_state,
|
||||
load_inputs,
|
||||
cta_coord_mnkl,
|
||||
cta_coord_mnk,
|
||||
k_tile_iter_next, k_tile_count - k_tile_prologue,
|
||||
false, /* did_batch_change - prologue loads handle tensormap acquire */
|
||||
enable_prefetch ? k_tile_count - k_tile_prologue : 0
|
||||
|
||||
@ -831,8 +831,6 @@ public:
|
||||
collective_epilogue.template tensormaps_fence_acquire<IsEpiLoad>(epi_load_tensormap);
|
||||
}
|
||||
|
||||
bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx;
|
||||
|
||||
epi_load_pipe_producer_state = collective_epilogue.load(
|
||||
epi_load_pipeline,
|
||||
epi_load_pipe_producer_state,
|
||||
@ -843,8 +841,7 @@ public:
|
||||
lane_idx,
|
||||
shared_storage.tensors.epilogue,
|
||||
epi_load_tensormap,
|
||||
work_tile_info.reduction_subtile_idx(),
|
||||
wait
|
||||
work_tile_info.reduction_subtile_idx()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@ -869,8 +869,6 @@ public:
|
||||
collective_epilogue.template tensormaps_fence_acquire<IsEpiLoad>(epi_load_tensormap);
|
||||
}
|
||||
|
||||
bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx;
|
||||
|
||||
epi_load_pipe_producer_state = collective_epilogue.load(
|
||||
epi_load_pipeline,
|
||||
epi_load_pipe_producer_state,
|
||||
@ -881,8 +879,7 @@ public:
|
||||
lane_idx,
|
||||
shared_storage.tensors.epilogue,
|
||||
epi_load_tensormap,
|
||||
work_tile_info.reduction_subtile_idx(),
|
||||
wait
|
||||
work_tile_info.reduction_subtile_idx()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user