CUTLASS 3.1 (#915)

Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
This commit is contained in:
ANIKET SHIVAM
2023-04-14 20:19:34 -07:00
committed by GitHub
parent 9b8166e3f0
commit d572cc1aab
482 changed files with 37184 additions and 16419 deletions

View File

@@ -171,7 +171,7 @@ copy_vec(Tensor<SrcEngine, SrcLayout> const& src,
{
using SrcType = typename SrcEngine::value_type;
using DstType = typename DstEngine::value_type;
if constexpr (sizeof(SrcType) == sizeof(DstType) && sizeof(VecType) > sizeof(DstType))
if constexpr (sizeof(SrcType) == sizeof(DstType) && sizeof(VecType) > sizeof(DstType))
{
/* @pre is_aligned<N>(src.data()) &&
* is_aligned<N>(dst.data())
@@ -259,4 +259,51 @@ copy(Copy_Atom<DefaultCopy, CopyArgs...> const&,
return copy(src, dst);
}
//////////////////////////////////////////
// Special Auto-Vectorizing Overloads
//////////////////////////////////////////
#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
/// copy() overload chosen when the atom's traits are tagged SM90_BULK_COPY_AUTO.
///
/// Steps performed, all visible in this body:
///   1. Statically require equal element bit-widths for src and dst, and that the
///      pair is gmem -> smem or smem -> gmem.
///   2. Pick SM90_BULK_COPY_G2S when src is gmem, SM90_BULK_COPY_S2G otherwise.
///   3. Take the element count reported by max_common_vector(src, dst) and build a
///      concrete Copy_Atom whose traits carry that count expressed in bits.
///   4. Divide src and dst by a tiler derived from dst's layout and forward to the
///      copy() overload for the concrete atom.
///
/// @param atom  Auto bulk-copy atom; its `opargs_` are forwarded to the concrete atom.
/// @param src   Source tensor.
/// @param dst   Destination tensor.
template <class... CT_Args, class... CA_Args,
          class SrcEngine, class SrcLayout,
          class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(Copy_Atom<Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...>, CA_Args...> const& atom,
     Tensor<SrcEngine, SrcLayout>                                        const& src,
     Tensor<DstEngine, DstLayout>                                             & dst)
{
  using SrcElem = typename SrcEngine::value_type;
  using DstElem = typename DstEngine::value_type;

  // Element widths must agree between the two tensors.
  constexpr auto kElemBits = sizeof_bits<SrcElem>::value;
  static_assert(kElemBits == sizeof_bits<DstElem>::value);

  // Only the two gmem <-> smem directions are accepted.
  constexpr bool kGmemToSmem = is_gmem<SrcEngine>::value && is_smem<DstEngine>::value;
  constexpr bool kSmemToGmem = is_smem<SrcEngine>::value && is_gmem<DstEngine>::value;
  static_assert(kGmemToSmem || kSmemToGmem,
                "Bulk Copy only supports gmem -> smem or smem -> gmem movement.");

  // Direction-specific operation, keyed on whether the source is gmem.
  using BulkOp = conditional_t<is_gmem<SrcEngine>::value,
                               SM90_BULK_COPY_G2S,
                               SM90_BULK_COPY_S2G>;

  // Number of elements in the common vector of src and dst.
  constexpr int kVecElems = decltype(max_common_vector(src, dst))::value;

  // Concrete atom type parameterised by the vector size in bits.
  using VecBits      = Int<kVecElems * kElemBits>;
  using ConcreteAtom = Copy_Atom<Copy_Traits<BulkOp, VecBits, CT_Args...>, SrcElem>;

  // Rebuild the atom, passing through the operation arguments held by `atom`.
  auto concrete_atom = apply(atom.opargs_,
                             [&](auto const&... op_args) { return ConcreteAtom{op_args...}; });

  // Tiler applied to both tensors so each tile matches the atom's vector size.
  auto vec_tiler = right_inverse(dst.layout()).compose(Int<kVecElems>{});

#if 0
  if (thread0()) {
    print("copy -- found a max_common_vector of %d\n", kVecElems);
    print("   "); print(src.data()); print(" o "); print(layout(src)); print("\n");
    print("   "); print(dst.data()); print(" o "); print(layout(dst)); print("\n");
  }
#endif

  // The divided tensors are passed as temporaries, exactly as the nested call expects.
  copy(concrete_atom, logical_divide(src, vec_tiler), logical_divide(dst, vec_tiler));
}
#endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
} // end namespace cute