@@ -182,6 +182,7 @@ public:
  using InternalSmemLayoutAtomB = cute::conditional_t<!SwapAB, SmemLayoutAtomB, SmemLayoutAtomA>;
  using InternalSmemCopyAtomA = cute::conditional_t<!SwapAB, SmemCopyAtomA, SmemCopyAtomB>;
  using InternalSmemCopyAtomB = cute::conditional_t<!SwapAB, SmemCopyAtomB, SmemCopyAtomA>;

  // TMA converts f32 input to tf32 when copying from GMEM to SMEM
  // For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
  static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
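The two comments above state the rule for the element type that is handed to TMA; a minimal sketch of how that rule could be expressed, assuming the alias name `TmaElementA` used later in this diff and CuTe's `uint_bit_t` helper (the exact definition is not shown in these hunks):

```cpp
// Sketch only -- not taken verbatim from this patch.
// f32 operands are described to TMA as tf32; every other type is described as a
// same-width unsigned integer so the TMA copy is bit-exact and never rounds.
using TmaElementA = cute::conditional_t<
    ConvertF32toTF32A,                                 // float in GMEM
    cutlass::tfloat32_t,                               // copied as tf32
    cute::uint_bit_t<cute::sizeof_bits_v<ElementA>>>;  // otherwise: raw bits
```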
@@ -228,14 +229,25 @@ public:
  static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must evenly divide tile k shape.");

  // Tile along modes in a way that maximizes the TMA box size.
  using SmemLayoutA = decltype(tile_to_shape(
      InternalSmemLayoutAtomA{},
      make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
      cute::conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
  using SmemLayoutB = decltype(tile_to_shape(
      InternalSmemLayoutAtomB{},
      make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
      cute::conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));

  template<class LayoutAtom, class TileShape, class Stride>
  static constexpr
  CUTLASS_HOST_DEVICE
  auto get_smem_layout(LayoutAtom layout_atom, TileShape const& tile_shape, Stride const& stride) {
    if constexpr (not cute::is_layout<Stride>::value) {
      return tile_to_shape(
          layout_atom,
          append(tile_shape, Int<DispatchPolicy::Stages>{}),
          cute::conditional_t< ::cutlass::gemm::detail::is_major<0,Stride>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{});
    }
    else {
      auto gmem_tile = composition(stride, tile_shape);
      return make_layout_like(append(gmem_tile, make_layout(Int<DispatchPolicy::Stages>{}, 0)));
    }
  }

  using SmemLayoutA = decltype(get_smem_layout(InternalSmemLayoutAtomA{}, select<0,2>(TileShape{}), InternalStrideA{}));
  using SmemLayoutB = decltype(get_smem_layout(InternalSmemLayoutAtomB{}, select<1,2>(TileShape{}), InternalStrideB{}));
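A hedged, standalone illustration of what the new `get_smem_layout` helper produces in each branch; the tile size (128x64), stage count (4), atom, and gmem layout below are invented for the example, and only the calls mirror the helper:

```cpp
#include <cute/tensor.hpp>

// Illustrative sketch under assumed sizes; not part of the patch.
inline auto smem_layout_sketch() {
  using namespace cute;
  auto tile_shape = make_shape(Int<128>{}, Int<64>{});   // (BLK_M, BLK_K)
  auto stages     = Int<4>{};                            // DispatchPolicy::Stages stand-in

  // Branch 1: Stride is an ordinary stride type -> the familiar staged layout.
  auto atom   = make_layout(make_shape(Int<64>{}, Int<16>{}));  // stand-in for the smem atom
  auto staged = tile_to_shape(atom, append(tile_shape, stages), Step<_1,_2,_3>{});

  // Branch 2: Stride is already a cute::Layout describing gmem. It is cut down to
  // one CTA tile, a size-Stages mode (stride 0) is appended, and a compact smem
  // layout with the same shape is built from the result.
  auto gmem      = make_layout(make_shape(Int<1024>{}, Int<512>{}));  // hypothetical gmem layout
  auto gmem_tile = composition(gmem, tile_shape);                     // one (BLK_M, BLK_K) tile
  auto smem_like = make_layout_like(append(gmem_tile, make_layout(stages, Int<0>{})));

  return make_tuple(staged, smem_like);
}
```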
  // It is assumed that the scales and zero-points share the same smem layout
  using SmemLayoutScale = decltype(tile_to_shape(
@@ -381,6 +393,18 @@ public:
    uint32_t mma_promotion_interval = 4;
  };

  template<class Shape, class Stride>
  static constexpr
  CUTLASS_HOST_DEVICE
  auto get_gmem_layout(Shape const& shape, Stride const& stride) {
    if constexpr (not cute::is_layout<Stride>::value) {
      return make_layout(shape, stride);
    }
    else {
      return stride;
    }
  }
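For context, a small hedged sketch of how `get_gmem_layout` behaves at a call site; the shapes and strides are made-up values, and branch 1 is exactly the old `make_layout(shape, stride)` behavior, so existing call sites are unaffected:

```cpp
#include <cute/tensor.hpp>

// Illustrative only; assumes an (M,K,L) = (1024, 512, 1) problem.
inline void gmem_layout_sketch() {
  using namespace cute;
  auto shape_mkl  = make_shape(1024, 512, 1);

  // Branch 1: an ordinary stride -> identical to the previous make_layout path.
  auto stride_mkl = make_stride(int64_t(512), Int<1>{}, int64_t(0));  // K-major A, single batch
  auto layout_a   = make_layout(shape_mkl, stride_mkl);

  // Branch 2: the caller already built a cute::Layout -> it is returned unchanged
  // and the shape argument is ignored.
  auto custom     = make_layout(shape_mkl);   // hypothetical caller-provided layout
  auto layout_b   = custom;

  (void)layout_a; (void)layout_b;
}
```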
  // Device side kernel params
  struct Params {
  private:
@@ -394,10 +418,14 @@ public:
        TransformB_>;

  public:
    // Assumption: StrideA is congruent with Problem_MK
    using LayoutA = decltype(get_gmem_layout(repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}));
    using LayoutB = decltype(get_gmem_layout(repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}));

    using TMA_A = decltype(make_tma_copy_A_sm90<TmaElementA>(
        GmemTiledCopyA{},
        make_tensor(Outer::get_logical_ptr(static_cast<InternalElementA const*>(nullptr)), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}),
        make_tensor(Outer::get_logical_ptr(static_cast<InternalElementA const*>(nullptr)), LayoutA{}),
        SmemLayoutA{}(_,_,cute::Int<0>{}),
        TileShape{},
        ClusterShape{})); // mcast along N mode for this M load, if any
@@ -419,7 +447,7 @@ public:
    // Assumption: StrideB is congruent with Problem_NK
    using TMA_B = decltype(make_tma_copy_B_sm90(
        GmemTiledCopyB{},
        make_tensor(Outer::get_logical_ptr(static_cast<InternalElementB const*>(nullptr)), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}),
        make_tensor(Outer::get_logical_ptr(static_cast<InternalElementB const*>(nullptr)), LayoutB{}),
        SmemLayoutB{}(_,_,cute::Int<0>{}),
        TileShape{},
        ClusterShape{})); // mcast along M mode for this N load, if any
@@ -431,6 +459,8 @@ public:
    int group_size;
    uint32_t tma_transaction_bytes = TmaTransactionBytes;
    int reload_factor = (group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{});
    InternalStrideA dA;
    InternalStrideB dB;
  };
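`reload_factor` is a ceiling division of the quantization group size by the CTA tile's K extent; a short worked example with assumed numbers (not this kernel's defaults):

```cpp
// Illustrative arithmetic only. With a hypothetical group_size = 128 and tile K = 64:
//   reload_factor = (128 + 64 - 1) / 64 = 2
// i.e. one group of scales/zeros covers two consecutive k-tiles before a reload.
// With group_size == K (per-channel scaling), reload_factor covers the whole K loop.
static_assert((128 + 64 - 1) / 64 == 2, "ceil_div(128, 64) == 2");
```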
  //
@@ -469,8 +499,8 @@ public:
      dB = args.dA;
    }

    Tensor tensor_a = make_tensor(get_logical_ptr(ptr_A), make_layout(make_shape(M,K,L), dA));
    Tensor tensor_b = make_tensor(get_logical_ptr(ptr_B), make_layout(make_shape(N,K,L), dB));
    Tensor tensor_a = make_tensor(get_logical_ptr(ptr_A), get_gmem_layout(make_shape(M,K,L), dA));
    Tensor tensor_b = make_tensor(get_logical_ptr(ptr_B), get_gmem_layout(make_shape(N,K,L), dB));
    typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90<TmaElementA>(
        GmemTiledCopyA{},
        tensor_a,
@@ -490,7 +520,7 @@ public:

    uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, tma_transaction_bytes, 1 };
      return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, tma_transaction_bytes, 1, dA, dB };
    }
    else if constexpr (ModeHasScales) {
      auto scale_k = (K + args.group_size - 1) / args.group_size;
@@ -505,7 +535,7 @@ public:
          _1{}); // mcast along N mode for this M load, if any

      if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale) {
        return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}) };
        return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}), dA, dB };
      }
      else if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
        Tensor tensor_zero = make_tensor(get_logical_ptr(args.ptr_Z), make_layout(make_shape(M,scale_k,L), dS));
@@ -515,7 +545,7 @@ public:
            SmemLayoutScale{}(_,_,cute::Int<0>{}),
            ScaleTileShape{},
            _1{}); // mcast along N mode for this M load, if any
        return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}) };
        return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, tma_transaction_bytes + TmaTransactionBytesExtra, (args.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}), dA, dB };
      } else {
        static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in to_underlying_arguments.");
      }
@@ -533,33 +563,37 @@ public:
    constexpr int tma_alignment_bits = 128;
    auto problem_shape_MNKL = append<4>(problem_shape, 1);
    auto [M,N,K,L] = problem_shape_MNKL;

    bool implementable = true;

    constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
    bool check_aligned_A = cutlass::detail::check_alignment<min_tma_aligned_elements_A>(get_gmem_layout(cute::make_shape(M,K,L), args.dA));

    constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
    bool check_aligned_B = cutlass::detail::check_alignment<min_tma_aligned_elements_B>(get_gmem_layout(cute::make_shape(N,K,L), args.dB));

    bool check_aligned_S = true;
    bool check_aligned_Z = true;
    bool check_mode_args = true;

    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      implementable = implementable && (args.ptr_S == nullptr);
      implementable = implementable && (args.ptr_Z == nullptr);
      check_mode_args = check_mode_args && (args.ptr_S == nullptr);
      check_mode_args = check_mode_args && (args.ptr_Z == nullptr);
    }
    else if constexpr (ModeHasScales) {
      const int scale_mn = SwapAB ? N : M;
      const int scale_k = (K + args.group_size - 1) / args.group_size;
      constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits<ElementScale>::value;
      implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
      implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0));
      implementable = implementable && args.group_size != 0;
      implementable = implementable && (args.ptr_S != nullptr);
      check_aligned_S = cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), args.dS);
      check_mode_args = check_mode_args && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0));
      check_mode_args = check_mode_args && args.group_size != 0;
      check_mode_args = check_mode_args && (args.ptr_S != nullptr);
      if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
        implementable = implementable && (args.ptr_Z == nullptr);
        check_mode_args = check_mode_args && (args.ptr_Z == nullptr);
      }
      else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
        constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits<ElementZero>::value;
        implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_zero>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
        implementable = implementable && (args.ptr_Z != nullptr);
        check_aligned_Z = cutlass::detail::check_alignment<min_tma_aligned_elements_zero>(cute::make_shape(scale_mn,scale_k,L), args.dS);
        check_mode_args = check_mode_args && (args.ptr_Z != nullptr);
      }
      else {
        static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
@@ -569,10 +603,23 @@ public:
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
    }

    if (!implementable) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
    if (!check_mode_args) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Invalid arguments for the selected conversion mode.\n");
    }
    return implementable;
    if (!check_aligned_A) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor A doesn't meet the minimum alignment requirements for TMA.\n");
    }
    if (!check_aligned_B) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor B doesn't meet the minimum alignment requirements for TMA.\n");
    }
    if (!check_aligned_S) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor S (scale) doesn't meet the minimum alignment requirements for TMA.\n");
    }
    if (!check_aligned_Z) {
      CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor Z (zeros) doesn't meet the minimum alignment requirements for TMA.\n");
    }

    return check_mode_args && check_aligned_A && check_aligned_B && check_aligned_S && check_aligned_Z;
  }
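The alignment thresholds above follow from TMA's 128-bit requirement divided by the element width; a few worked values under the usual `cutlass::sizeof_bits` definitions (illustrative element types, not this kernel's configuration):

```cpp
static_assert(128 / cutlass::sizeof_bits<cutlass::half_t>::value  == 8,  "fp16 needs 8-element alignment");
static_assert(128 / cutlass::sizeof_bits<int8_t>::value           == 16, "int8 needs 16-element alignment");
static_assert(128 / cutlass::sizeof_bits<cutlass::int4b_t>::value == 32, "int4 needs 32-element alignment");
```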
  static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
@@ -618,8 +665,8 @@ public:

    // TMA requires special handling of strides to deal with coord codomain mapping
    // Represent the full tensors -- get these from TMA
    Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l)
    Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
    Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(get_gmem_layout(make_shape(M,K,L), mainloop_params.dA))); // (m,k,l)
    Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(shape(get_gmem_layout(make_shape(N,K,L), mainloop_params.dB))); // (n,k,l)

    // Make tiled views, defer the slice
    Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
@@ -680,8 +727,6 @@ public:
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in TMA load.");
    }

    int lane_predicate = cute::elect_one_sync();

    Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
    Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
    Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE)
@@ -748,8 +793,10 @@ public:
      BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);

      int write_stage = smem_pipe_write.index();
      if (cute::elect_one_sync()) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
      if (cute::elect_one_sync()) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
      if (cute::elect_one_sync()) {
        copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
        copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
      }

      if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
        // Nothing extra to do.
@@ -920,6 +967,12 @@ public:
    // Unroll the K mode manually to set scale D to 1
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
      warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
      warpgroup_commit_batch();

      if (k_block < K_BLOCK_MAX - 2) { // prefetch next block
        copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view,
          partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage);
@@ -927,11 +980,6 @@ public:
      if (k_block < K_BLOCK_MAX - 1) {
        transform_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1);
      }
      warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
      warpgroup_commit_batch();
    }

    --k_tile_count;
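Taken together, the two hunks above appear to reorder the inner K loop so the MMA for the current `k_block` is issued before the smem-to-register copy for `k_block + 2` and the A-operand conversion for `k_block + 1`. A hedged sketch of the resulting per-iteration schedule, assuming `K_BLOCK_MAX == 4` and that data for `k_block` 0 and 1 was staged before the loop (outside these hunks):

```cpp
// Illustrative schedule only (K_BLOCK_MAX == 4 is an assumed value):
//
//   k_block 0: mma(0) | copy smem->rmem for 2 | convert A for 1
//   k_block 1: mma(1) | copy smem->rmem for 3 | convert A for 2
//   k_block 2: mma(2) |                         convert A for 3
//   k_block 3: mma(3) |
//
// The register-side copy runs two k_blocks ahead of the math (k_block + 2) and
// the dequantize/convert of A one k_block ahead (k_block + 1), so both can
// overlap with the asynchronously executing wgmma batch that was just committed.
```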