v3.9 update (#2213)

Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
Yujia Zhai
2025-04-02 23:10:16 -07:00
committed by GitHub
parent 6f4921858b
commit 79fc51f4b8
72 changed files with 19875 additions and 459 deletions

View File

@@ -340,7 +340,7 @@ public:
base_args.epilogue.thread,
reinterpret_cast<const ElementC*>(tensor_c_iter.data()),
tensor_c_iter.stride(),
reinterpret_cast<const ElementD*>(tensor_d_iter.data()),
reinterpret_cast<ElementD*>(tensor_d_iter.data()),
tensor_d_iter.stride()
};

View File

@@ -82,7 +82,7 @@ struct DistributedGemmKernelWrapper<
using BaseArguments = typename BaseKernel::Arguments;
using BaseParams = typename BaseKernel::Params;
static_assert(BaseKernel::ArchTag::kMinComputeCapability == 90, "DistGEMM only supports Hopper GEMMs for now.");
//static_assert(BaseKernel::ArchTag::kMinComputeCapability == 90, "DistGEMM only supports Hopper GEMMs for now.");
static_assert(not cute::is_same_v<typename BaseKernel::ElementC, void>, "DistributedGEMM epilogues must have a source.");
using ElementFlag = uint32_t;

View File

@@ -189,6 +189,100 @@ template<
bool Is2sm = false
>
constexpr bool sm1xx_gemm_check_for_f8f6f4_mix8bit_requirement(){
// * 1SM Dense
// * A_K(t) : TileShape_K % 128 == 0
// * A_M(n) : TileShape_M % 128 == 0
// * B_N(t) : TileSize_N % 128 == 0
// * B_K(n) : TileSize_K % 128 == 0
//
// * 2SM Dense
// * A_K(t) : TileShape_K % 128 == 0
// * A_M(n) : TileShape_M % 128 == 0
// * B_N(t) : TileSize_N % 256 == 0
// each sm load half the data along tile_n (split vertically), each sm needs to be 128 elts aligned.
// full tile_n needs to be 256 elts aligned
// * B_K(n) : TileShape_K % 128 == 0
//
// * 1SM Sparse
// * A_K(t) : TileShape_K % 256 == 0
// num of physical elems needs to be 128 elts aligned
// num of logical elems needs to be 256 elts aligned
// * A_M(n) : TileShape_M % 128 == 0
// * B_N(t) : TileSize_N % 128 == 0
// * B_K(n) : TileSize_K % 128 == 0
//
// * 2SM Sparse
// * A_K(t) : TileShape_K % 256 == 0
// num of physical elems needs to be 128 elts aligned
// num of logical elems needs to be 256 elts aligned
// * A_M(n) : TileShape_M % 128 == 0
// * B_N(t) : TileSize_N % 256 == 0
// each sm load half the data along tile_n (split vertically), each sm needs to be 128 elts aligned.
// full tile_n needs to be 256 elts aligned
// * B_K(n) : TileShape_K % 128 == 0
//
// * Valid TileShape_MNK Dense
// * Notation:
// mma_instruction_tile_shape-cta_tile_shape
// * s128x128x64
// s128x128x32_128x128x128_nn YES
// s128x128x32_128x128x128_nt YES
// s128x128x32_128x128x128_tn YES
// s128x128x32_128x128x128_tt YES
// * s128x256x64
// s128x256x32_128x256x128_nn YES
// s128x256x32_128x256x128_nt YES
// s128x256x32_128x256x128_tn YES
// s128x256x32_128x256x128_tt YES
// * s256x128x64
// s256x128x32_256x128x128_nn YES
// s256x128x32_256x128x128_nt NO (2SM B_N TileSize_N % 256 != 0)
// s256x128x32_256x128x128_tn YES
// s256x128x32_256x128x128_tt NO (2SM B_N TileSize_N % 256 != 0)
// * s256x256x64
// s256x256x32_256x256x128_nn YES
// s256x256x32_256x256x128_nt YES
// s256x256x32_256x256x128_tn YES
// s256x256x32_256x256x128_tt YES
//
// * Valid TileShape_MNK Sparse
// * s128x128x64
// s128x128x64_128x128x128_nn YES
// s128x128x64_128x128x128_nt YES
// s128x128x64_128x128x128_tn NO (A_K TileShape_K % 256 != 0)
// s128x128x64_128x128x128_tt NO (A_K TileShape_K % 256 != 0)
// s128x128x64_128x128x256_nn YES
// s128x128x64_128x128x256_nt YES
// s128x128x64_128x128x256_tn YES
// s128x128x64_128x128x256_tt YES
// * s128x256x64
// s128x256x64_128x256x128_nn YES
// s128x256x64_128x256x128_nt YES
// s128x256x64_128x256x128_tn NO (A_K TileShape_K % 256 != 0)
// s128x256x64_128x256x128_tt NO (A_K TileShape_K % 256 != 0)
// s128x256x64_128x256x256_nn YES
// s128x256x64_128x256x256_nt YES
// s128x256x64_128x256x256_tn YES
// s128x256x64_128x256x256_tt YES
// * s256x128x64
// s256x128x64_128x128x128_nn YES
// s256x128x64_128x128x128_nt NO (2SM B_N TileSize_N % 256 != 0)
// s256x128x64_128x128x128_tn NO (A_K TileShape_K % 256 != 0)
// s256x128x64_128x128x128_tt NO (A_K TileShape_K % 256 != 0)
// s256x128x64_128x128x256_nn YES
// s256x128x64_128x128x256_nt NO (2SM B_N TileSize_N % 256 != 0)
// s256x128x64_128x128x256_tn YES
// s256x128x64_128x128x256_tt NO (2SM B_N TileSize_N % 256 != 0)
// * s256x256x64
// s256x256x64_128x256x128_nn YES
// s256x256x64_128x256x128_nt YES
// s256x256x64_128x256x128_tn NO (A_K TileShape_K % 256 != 0)
// s256x256x64_128x256x128_tt NO (A_K TileShape_K % 256 != 0)
// s256x256x64_128x256x256_nn YES
// s256x256x64_128x256x256_nt YES
// s256x256x64_128x256x256_tn YES
// s256x256x64_128x256x256_tt YES
[[maybe_unused]] constexpr int TileShape_M = Is2sm ? size<0>(TileShape_MNK{}) / 2 : size<0>(TileShape_MNK{});
[[maybe_unused]] constexpr int TileShape_N = size<1>(TileShape_MNK{});
[[maybe_unused]] constexpr int TileShape_K = size<2>(TileShape_MNK{});

View File

@@ -432,6 +432,10 @@ public:
init_M = get<0>(problem_shape_MNK);
init_N = get<1>(problem_shape_MNK);
init_K = get<2>(problem_shape_MNK);
if constexpr (SwapAB) {
init_M = get<1>(problem_shape_MNK);
init_N = get<0>(problem_shape_MNK);
}
if constexpr (not SwapAB) {
dA = args.dA;
@@ -491,7 +495,7 @@ public:
: args_setup(args.ptr_A, args.ptr_B);
}
else if constexpr (ModeHasScales) {
auto scale_k = 1;
auto scale_k = ceil_div(init_K, args.chunk_size);
ElementScale const* ptr_S = reinterpret_cast<ElementScale const*>(args.ptr_S);
StrideScale dS{};
Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_layout(make_shape(init_M,scale_k,mock_L), dS));
@@ -595,7 +599,7 @@ public:
}
else if constexpr (ModeHasScales) {
const int scale_mn = SwapAB ? N : M;
const int scale_k = (K + args.chunk_size - 1) / args.chunk_size;
const int scale_k = ceil_div(K, args.chunk_size);
constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits<ElementScale>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
implementable = implementable && (args.chunk_size == K || ((args.chunk_size % size<2>(TileShape{})) == 0));
@@ -659,14 +663,15 @@ public:
return cute::make_tuple(gA_mkl, gB_nkl);
}
else if constexpr (ModeHasScales) {
const int scale_mn = SwapAB ? N : M;
auto scale_k = mainloop_params.scale_k;
Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l)
Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(scale_mn,scale_k,L));
Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l)
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l)
Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(scale_mn,scale_k,L));
Tensor gZ_mkl = local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl, gZ_mkl);
}
@@ -1217,8 +1222,8 @@ public:
Params const& mainloop_params,
int32_t next_group,
ProblemShape_MNKL problem_shape_mnkl) {
const uint32_t M = get<0>(problem_shape_mnkl);
const uint32_t N = get<1>(problem_shape_mnkl);
const uint32_t M = (SwapAB? get<1>(problem_shape_mnkl) : get<0>(problem_shape_mnkl));
const uint32_t N = (SwapAB? get<0>(problem_shape_mnkl) : get<1>(problem_shape_mnkl));
const uint32_t K = get<2>(problem_shape_mnkl);
// Replace all dims for consistency
@@ -1245,14 +1250,14 @@ public:
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
NonVoidElementScale const* ptr_S = nullptr;
auto scale_k = 1;
auto scale_k = ceil_div(K, mainloop_params.chunk_size);
Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]);
cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_scale, tensor_scale,
prob_shape_scale, prob_stride_scale);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
ElementZero const* ptr_Z = nullptr;
auto scale_k = 1;
auto scale_k = ceil_div(K, mainloop_params.chunk_size);
Tensor tensor_zero = make_tensor(detail::get_logical_ptr(ptr_Z), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]);
cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_zero, tensor_zero,
prob_shape_zero, prob_stride_zero);

View File

@@ -426,7 +426,7 @@ public:
return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, tma_transaction_bytes, 1, dA, dB };
}
else if constexpr (ModeHasScales) {
auto scale_k = (K + args.group_size - 1) / args.group_size;
auto scale_k = ceil_div(K, args.group_size);
ElementScale const* ptr_S = args.ptr_S;
StrideScale dS = args.dS;
Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_layout(make_shape(M,scale_k,L), dS));
@@ -483,7 +483,7 @@ }
}
else if constexpr (ModeHasScales) {
const int scale_mn = SwapAB ? N : M;
const int scale_k = (K + args.group_size - 1) / args.group_size;
const int scale_k = ceil_div(K, args.group_size);
constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits<ElementScale>::value;
check_aligned_S = cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), args.dS);
check_mode_args = check_mode_args && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0));

View File

@@ -622,6 +622,11 @@ public:
impl_.producer_acquire(state, barrier_token);
}
CUTLASS_DEVICE
void producer_expect_transaction(PipelineState state, uint32_t transaction_bytes) {
impl_.producer_expect_transaction(state, transaction_bytes);
}
// NOP for TMA based mainloop
CUTLASS_DEVICE
void producer_commit(PipelineState state, uint32_t bytes) {

View File

@@ -452,6 +452,11 @@ public:
return producer_get_barrier(state.index());
}
CUTLASS_DEVICE
void producer_expect_transaction(PipelineState state, uint32_t transaction_bytes) {
producer_expect_transaction(state.index(), transaction_bytes);
}
////////////////////
// Consumer APIs
////////////////////
@@ -519,6 +524,14 @@ private:
#endif
}
CUTLASS_DEVICE
void producer_expect_transaction(uint32_t stage, uint32_t transaction_bytes) {
detail::pipeline_check_is_producer(params_.role);
if (params_.is_leader) {
full_barrier_ptr_[stage].expect_transaction(transaction_bytes);
}
}
// NOP for TMA based mainloop
CUTLASS_DEVICE
void producer_commit(uint32_t stage, uint32_t bytes) {