[EVT] Add support for Row/Col broadcast PtrArray (#2033)
* Add group support to EVT row/col broadcast.

* small modifications

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
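Below is a minimal usage sketch (not part of the diff) of how the new pointer-array path could be exercised, e.g. grouped / ptr-array GEMM where every group carries its own bias row. It assumes CUTLASS 3.x with cute types in scope; the tile shape, ElementBias, GroupedRowBias, and d_bias_ptrs names are illustrative placeholders, not identifiers from this commit.

// Sketch only: passing a pointer type (ElementBias*) as ElementInput_ flips
// IsArrayOfPointers to true, so Arguments::ptr_row becomes
// ElementBias const* const* and group index l selects its own row pointer.
#include "cutlass/numeric_types.h"
#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp"

using ElementBias = cutlass::half_t;
using CtaTile     = cute::Shape<cute::_128, cute::_128, cute::_64>;  // illustrative CTA tile

using GroupedRowBias = cutlass::epilogue::fusion::Sm90RowBroadcast<
    0,             // Stages: row broadcast requires 0 (see static_assert below)
    CtaTile,
    ElementBias*,  // pointer element type selects the PtrArray path
    float>;        // ElementCompute

// d_bias_ptrs: hypothetical device array with one bias-row pointer per group.
// GroupedRowBias::Arguments args{d_bias_ptrs, ElementBias(0), {}};

A non-pointer element type (plain ElementBias) keeps the previous single-pointer behavior, since PtrRowType then falls back to ElementInput const*.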
@@ -972,14 +972,20 @@ compute_row_broadcast_stages() {
 template<
   int Stages,
   class CtaTileShapeMNK,
-  class ElementInput,
-  class ElementCompute = ElementInput,
+  class ElementInput_,
+  class ElementCompute = cute::remove_pointer_t<ElementInput_>,
   class StrideMNL_ = Stride<_0,_1,_0>,
-  int Alignment = 128 / sizeof_bits_v<ElementInput>,
+  int Alignment = 128 / sizeof_bits_v<cute::remove_pointer_t<ElementInput_>>,
   bool EnableNullptr = true // Fallback scalar broadcast for nullptr params
 >
 struct Sm90RowBroadcast {
   using StrideMNL = StrideMNL_;
+  // Get base element input type.
+  using ElementInput = cute::remove_pointer_t<ElementInput_>;
+  // Check if input is an array of pointers.
+  static constexpr bool IsArrayOfPointers = is_same_v<ElementInput*, ElementInput_>;
+  using PtrRowType = cute::conditional_t<IsArrayOfPointers, ElementInput const* const*, ElementInput const*>;

   static_assert(Stages == 0, "Row broadcast doesn't support smem pipelining");

   static constexpr bool IsDynamicBroadcast = is_same_v<remove_cvref_t<decltype(get<1>(StrideMNL{}))>, bool>; // row vector or scalar broadcast
@@ -991,7 +997,7 @@ struct Sm90RowBroadcast {
   };

   struct Arguments {
-    ElementInput const* ptr_row = nullptr;
+    PtrRowType ptr_row = nullptr;
     ElementInput null_default = ElementInput(0);
     StrideMNL dRow = {};
   };
@@ -1036,7 +1042,7 @@ struct Sm90RowBroadcast {
       is_zero_ = params.null_default == ElementCompute(0);
     }
     // Dynamic non-batched scalar broadcast
-    else if (IsDynamicBroadcast && stride_N == bool(0) && stride_L == repeat_like(stride_L, 0)) {
+    else if (IsDynamicBroadcast && stride_N == bool(0) && stride_L == repeat_like(stride_L, 0) && !IsArrayOfPointers) {
       is_zero_ = params.ptr_row[0] == ElementInput(0);
     }
   }
@@ -1183,7 +1189,13 @@ struct Sm90RowBroadcast {

     auto layout_M = make_layout(M, repeat_like(M, _0{}));
     auto layout_L = make_layout(L, get<2>(params.dRow));
-    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_layout(layout_M,layout_N,layout_L));
+    ElementInput const* ptr_row;
+    if constexpr(IsArrayOfPointers) {
+      ptr_row = params.ptr_row[l];
+    } else {
+      ptr_row = params.ptr_row;
+    }
+    Tensor mRow = make_tensor(make_gmem_ptr(ptr_row), make_layout(layout_M,layout_N,layout_L));
     Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
     Tensor sRow = make_tensor(make_smem_ptr(smem),
       make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
@@ -1220,14 +1232,20 @@ struct Sm90RowBroadcast {
 template<
   int Stages,
   class CtaTileShapeMNK,
-  class ElementInput,
-  class ElementCompute = ElementInput,
+  class ElementInput_,
+  class ElementCompute = cute::remove_pointer_t<ElementInput_>,
   class StrideMNL_ = Stride<_1,_0,_0>,
-  int Alignment = 128 / sizeof_bits_v<ElementInput>,
+  int Alignment = 128 / sizeof_bits_v<cute::remove_pointer_t<ElementInput_>>,
   bool EnableNullptr = true // Fallback scalar broadcast for nullptr params
 >
 struct Sm90ColBroadcast {
   using StrideMNL = StrideMNL_;
+  // Get base element input type.
+  using ElementInput = cute::remove_pointer_t<ElementInput_>;
+  // Check if input is an array of pointers.
+  static constexpr bool IsArrayOfPointers = is_same_v<ElementInput*, ElementInput_>;
+  using PtrColType = cute::conditional_t<IsArrayOfPointers, ElementInput const* const*, ElementInput const*>;

   static_assert(Stages == 0, "Column broadcast doesn't support smem pipelining");

   static constexpr bool IsDynamicBroadcast = is_same_v<remove_cvref_t<decltype(get<0>(StrideMNL{}))>, bool>; // Column vector or scalar broadcast
@@ -1238,13 +1256,13 @@ struct Sm90ColBroadcast {
   struct SharedStorage { };

   struct Arguments {
-    ElementInput const* ptr_col = nullptr;
+    PtrColType ptr_col = nullptr;
     ElementInput null_default = ElementInput(0);
     StrideMNL dCol = {};
   };

   struct Params {
-    ElementInput const* ptr_col = nullptr;
+    PtrColType ptr_col = nullptr;
     ElementCompute null_default = ElementCompute(0);
     StrideMNL dCol = {};
   };
@@ -1301,7 +1319,7 @@ struct Sm90ColBroadcast {
       is_zero_ = params.null_default == ElementCompute(0);
     }
     // Dynamic non-batched scalar broadcast
-    else if (IsDynamicBroadcast && stride_M == bool(0) && stride_L == repeat_like(stride_L, 0)) {
+    else if (IsDynamicBroadcast && stride_M == bool(0) && stride_L == repeat_like(stride_L, 0) && !IsArrayOfPointers) {
       is_zero_ = params.ptr_col[0] == ElementInput(0);
     }
   }
@@ -1398,6 +1416,7 @@ struct Sm90ColBroadcast {
   get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {

     auto [M, N, K, L] = args.problem_shape_mnkl;
+    auto [m, n, k, l] = args.tile_coord_mnkl;
     auto layout_M = [&] () CUTLASS_LAMBDA_FUNC_INLINE {
       auto shape_M = get<0>(args.problem_shape_mnkl);
       if constexpr (IsDynamicBroadcast) {
@@ -1416,11 +1435,17 @@ struct Sm90ColBroadcast {

     auto layout_N = make_layout(N, repeat_like(N, _0{}));
     auto layout_L = make_layout(L, get<2>(params.dCol));
-    Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_layout(layout_M,layout_N,layout_L));
+    ElementInput const* ptr_col;
+    if constexpr(IsArrayOfPointers) {
+      ptr_col = params.ptr_col[l];
+    } else {
+      ptr_col = params.ptr_col;
+    }
+    Tensor mCol = make_tensor(make_gmem_ptr(ptr_col), make_layout(layout_M,layout_N,layout_L));
     Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
       mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);

-    Tensor mCol_static = make_tensor(make_gmem_ptr(params.ptr_col), make_layout(make_layout(M),layout_N,layout_L));
+    Tensor mCol_static = make_tensor(make_gmem_ptr(ptr_col), make_layout(make_layout(M),layout_N,layout_L));
     Tensor tCgCol_static = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
       mCol_static, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
     Tensor tCrCol = make_tensor_like<ElementCompute>(tCgCol_static); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)