@ -171,7 +171,7 @@ copy_vec(Tensor<SrcEngine, SrcLayout> const& src,
|
||||
{
|
||||
using SrcType = typename SrcEngine::value_type;
|
||||
using DstType = typename DstEngine::value_type;
|
||||
if constexpr (sizeof(SrcType) == sizeof(DstType) && sizeof(VecType) > sizeof(DstType))
|
||||
if constexpr (sizeof(SrcType) == sizeof(DstType) && sizeof(VecType) > sizeof(DstType))
|
||||
{
|
||||
/* @pre is_aligned<N>(src.data()) &&
|
||||
* is_aligned<N>(dst.data())
|
||||
@ -259,4 +259,51 @@ copy(Copy_Atom<DefaultCopy, CopyArgs...> const&,
|
||||
return copy(src, dst);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////
|
||||
// Special Auto-Vectorizing Overloads
|
||||
//////////////////////////////////////////
|
||||
|
||||
#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
|
||||
template <class... CT_Args, class... CA_Args,
|
||||
class SrcEngine, class SrcLayout,
|
||||
class DstEngine, class DstLayout>
|
||||
CUTE_HOST_DEVICE
|
||||
void
|
||||
copy(Copy_Atom<Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...>, CA_Args...> const& atom,
|
||||
Tensor<SrcEngine, SrcLayout> const& src,
|
||||
Tensor<DstEngine, DstLayout> & dst)
|
||||
{
|
||||
using SrcType = typename SrcEngine::value_type;
|
||||
using DstType = typename DstEngine::value_type;
|
||||
static_assert(sizeof_bits<SrcType>::value == sizeof_bits<DstType>::value);
|
||||
static_assert((is_gmem<SrcEngine>::value && is_smem<DstEngine>::value) ||
|
||||
(is_smem<SrcEngine>::value && is_gmem<DstEngine>::value),
|
||||
"Bulk Copy only supports gmem -> smem or smem -> gmem movement.");
|
||||
// Do BulkCopy dispatch
|
||||
using BULK_COPY_OP = conditional_t<is_gmem<SrcEngine>::value,
|
||||
SM90_BULK_COPY_G2S,
|
||||
SM90_BULK_COPY_S2G>;
|
||||
|
||||
constexpr int N = decltype(max_common_vector(src, dst))::value;
|
||||
|
||||
// Construct a new concrete Atom of the vector size
|
||||
using N_BITS = Int<N*sizeof_bits<SrcType>::value>;
|
||||
using COPY_ATOM = Copy_Atom<Copy_Traits<BULK_COPY_OP, N_BITS, CT_Args...>, SrcType>;
|
||||
auto bulk_atom = apply(atom.opargs_, [&](auto const&... args) { return COPY_ATOM{args...}; });
|
||||
|
||||
// Tile the src and dst to the Atom
|
||||
auto tiler = right_inverse(dst.layout()).compose(Int<N>{});
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print("copy -- found a max_common_vector of %d\n", N);
|
||||
print(" "); print(src.data()); print(" o "); print(layout(src)); print("\n");
|
||||
print(" "); print(dst.data()); print(" o "); print(layout(dst)); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
return copy(bulk_atom, logical_divide(src, tiler), logical_divide(dst, tiler));
|
||||
}
|
||||
#endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
|
||||
|
||||
} // end namespace cute
|
||||
|
||||
Reference in New Issue
Block a user