Updates for CUTLASS 3.5.0 (#1468)

@@ -71,85 +71,103 @@ cooperative_copy(uint32_t const& tid,
  // Precondition on tid in DEBUG
  assert(tid < NumThreads);

  // Precondition on pointer alignment in DEBUG
  assert(is_byte_aligned<max(MaxVecBits/8, 1u)>(raw_pointer_cast(src.data())));
  assert(is_byte_aligned<max(MaxVecBits/8, 1u)>(raw_pointer_cast(dst.data())));

  //
  // Determine val+thr vectorization based on src/dst size and number of threads
  // NOTE: This heuristic promotes parallelization over vectorization
  //

  constexpr int elem_bits = sizeof_bits_v<typename SrcEngine::value_type>;

  // Fallback - slow path, naive copy, vectorization disabled
  if constexpr(size(SrcLayout{}) % NumThreads != 0) {
    int index = static_cast<int>(tid);
    CUTE_UNROLL
    for(int i = 0; i < ceil_div(size(SrcLayout{}), NumThreads); i++) {
      if(index < size(SrcLayout{})) {
        dst[index] = src[index];
      }
      index += NumThreads;
    }
  } else {
    // Fast path with vectorization

    // The number of elements that can be vectorized in values
    constexpr int common_elem = decltype(max_common_vector(src, dst))::value;
    constexpr int common_bits = common_elem * elem_bits;
    constexpr int total_elem  = decltype(size(src))::value;
    constexpr int total_bits  = total_elem * elem_bits;
    static_assert(total_bits % NumThreads == 0);
    constexpr int total_bits_per_thr = total_bits / NumThreads;
    // If there are too many threads to allow a full elem copy, trunc the thrs and use elem_bits
    constexpr int max_vec_bits_by_thr = cute::max(elem_bits, total_bits_per_thr);

    // Cap the vectorization to the common bits, the max_vec_bits_by_thr, and the MaxVecBits
    constexpr int vec_bits = cute::min(common_bits, max_vec_bits_by_thr, static_cast<int>(MaxVecBits));

    // Convert back to number of elements, safe_div
    static_assert((vec_bits % elem_bits) == 0);
    constexpr int vec_elem = vec_bits / elem_bits;

    // Use only part of threads if there's not enough work for all threads
    constexpr int vec_thrs = (total_elem % (vec_elem * NumThreads) == 0)
                             ? NumThreads
                             : (total_elem / vec_elem);
    static_assert(vec_thrs <= NumThreads);
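
    // Illustrative walk-through (editorial addition, not part of the original source):
    // assume NumThreads = 128, a 16-bit value type, 256 contiguous elements, and MaxVecBits = 128. Then
    //   elem_bits           = 16
    //   total_bits          = 256 * 16   = 4096
    //   total_bits_per_thr  = 4096 / 128 = 32
    //   max_vec_bits_by_thr = max(16, 32) = 32
    //   common_bits         = 256 * 16 = 4096     (fully contiguous src and dst)
    //   vec_bits            = min(4096, 32, 128) = 32
    //   vec_elem            = 32 / 16 = 2
    //   vec_thrs            = 128                 (256 % (2 * 128) == 0)
    // Even though 128-bit vectors would be possible here, each thread issues only a 32-bit copy
    // so that all 128 threads stay busy; the heuristic favors parallelization over vectorization.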

    // The common layout of the two tensors that can be vectorized over threads
    // vidx -> coord
    auto common_layout = max_common_layout(get_nonswizzle_portion(src.layout()),
                                           get_nonswizzle_portion(dst.layout()));

    // Scale up the common_layout to cover the entire tensors
    // vidx -> coord
    auto full_perm = tile_to_shape(make_layout(common_layout), size(src));

    // Create the Tiler
    // ((vid,tid),iter)
    auto layout_vt = logical_divide(full_perm, Layout<Shape<Int<vec_elem>, Int<vec_thrs>>>{});

    // Apply and slice
    Tensor src_v = src.compose(layout_vt)(make_coord(_,tid),_);
    Tensor dst_v = dst.compose(layout_vt)(make_coord(_,tid),_);
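
    // Continuing the illustrative numbers above (editorial addition): with vec_elem = 2 and
    // vec_thrs = 128, layout_vt has shape ((2,128),1), i.e. ((vid,tid),iter), so the slice
    // taken at this thread's tid is a (2,1)-shaped view; the recast below then turns each
    // two-element value slice into a single 32-bit vector access.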

    // Should account for vec_bits < 8 and/or vec_elem <= 1
    // And also account for subbyte types, which could cause race conditions
    // Want to ENFORCE sufficient vectorization in those cases
    static_assert((vec_bits >= 8), "No support for subbyte copying");
    using VecType = uint_bit_t<vec_bits>;

#if 0
    if (thread0()) {
      print(" "); print("cooperative_copy -- vec\n");
      print(" "); print("NumThreads: "); print(NumThreads); print("\n");
      print(" "); print("MaxVecBits: "); print(MaxVecBits); print("\n");
      print(" "); print("src: "); print(src); print("\n");
      print(" "); print("dst: "); print(dst); print("\n");
      print(" "); print("common_layout: "); print(common_layout); print("\n");
      print(" "); print("full_perm: "); print(full_perm); print("\n");
      print(" "); print("Used vector: "); print(vec_elem); print("\n");
      print(" "); print("Used threads: "); print(vec_thrs); print("\n");
      print(" "); print("layout_vt: "); print(layout_vt); print("\n");
      print(" "); print("src.compose(layout_vt): "); print(src.compose(layout_vt)); print("\n");
      print(" "); print("dst.compose(layout_vt): "); print(dst.compose(layout_vt)); print("\n");
      print(" "); print("src_v: "); print(src_v); print("\n");
      print(" "); print("dst_v: "); print(dst_v); print("\n");
      print(" "); print("recast<VecType const>(src_v): "); print(recast<VecType const>(src_v)); print("\n");
      print(" "); print("recast<VecType const>(dst_v): "); print(recast<VecType const>(dst_v)); print("\n");
    }
#ifdef __CUDA_ARCH__
    __syncthreads();
#endif
#endif

    // If we're using all threads (static) or the tid is in-range (dynamic)
    if (vec_thrs >= NumThreads or tid < vec_thrs) {
      return copy_if(TrivialPredTensor{}, recast<VecType const>(src_v), recast<VecType>(dst_v));
    }
  }
}
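
For context, here is one possible way to invoke this routine: a thread block cooperatively stages a tile from global memory into shared memory. The kernel name, buffer names, tile size, and launch configuration below are illustrative assumptions rather than part of this commit; the sketch presumes a launch with 128 threads per block and the CuTe headers on the include path.

// usage_sketch.cu (illustrative sketch; names and sizes are assumptions, not from this commit)
#include <cute/tensor.hpp>
#include <cute/algorithm/cooperative_copy.hpp>

using namespace cute;

// Launched as, e.g., stage_tile_kernel<<<num_tiles, 128>>>(in, out):
// each block copies one 32x32 float tile (1024 elements) from gmem to smem.
__global__ void stage_tile_kernel(float const* gmem_in, float* gmem_out)
{
  __shared__ alignas(16) float smem_buf[32 * 32];

  auto tile_layout = make_layout(make_shape(Int<32>{}, Int<32>{}));  // compact column-major 32x32

  Tensor gA = make_tensor(make_gmem_ptr(gmem_in + blockIdx.x * 32 * 32), tile_layout);
  Tensor sA = make_tensor(make_smem_ptr(smem_buf), tile_layout);

  // 128 threads cooperate; vector width is capped at 128 bits (16 bytes), hence the alignas above.
  cooperative_copy<128, 128>(threadIdx.x, gA, sA);
  __syncthreads();

  // Trivial consumer so the staged data is observable.
  for (int i = threadIdx.x; i < size(sA); i += 128) {
    gmem_out[blockIdx.x * 32 * 32 + i] = sA(i);
  }
}

Since 1024 elements divide evenly across 128 threads and both layouts are identical and contiguous, this takes the vectorized fast path above; a tile size that is not a multiple of the thread count would instead fall back to the naive per-element loop.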