CUTLASS 3.3.0 (#1167)
* Release 3.3.0: adds support for mixed-precision GEMMs on Hopper and Ampere; adds support for < 16B aligned GEMMs on Hopper; enhancements to EVT, to the Python interface, and to sub-byte type handling in CuTe; several other bug fixes and performance improvements. * Minor doc update
This commit is contained in:
@ -48,11 +48,21 @@ struct UniversalCopy
|
||||
using SRegisters = S[1];
|
||||
using DRegisters = D[1];
|
||||
|
||||
template<class S_, class D_>
|
||||
CUTE_HOST_DEVICE static constexpr void
|
||||
copy(S const& src,
|
||||
D & dst)
|
||||
copy(S_ const& src,
|
||||
D_ & dst)
|
||||
{
|
||||
dst = src;
|
||||
dst = static_cast<D>(static_cast<S>(src));
|
||||
}
|
||||
|
||||
// Accept mutable temporaries
|
||||
template<class S_, class D_>
|
||||
CUTE_HOST_DEVICE static constexpr void
|
||||
copy(S_ const& src,
|
||||
D_ && dst)
|
||||
{
|
||||
copy(src, dst);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -96,6 +96,66 @@ struct SM80_CP_ASYNC_CACHEGLOBAL
|
||||
}
|
||||
};
|
||||
|
||||
/// Copy via cp.async with caching at all levels
|
||||
template <class TS, class TD = TS>
|
||||
struct SM80_CP_ASYNC_CACHEALWAYS_ZFILL
|
||||
{
|
||||
using SRegisters = TS[1];
|
||||
using DRegisters = TD[1];
|
||||
|
||||
static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)");
|
||||
static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported");
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(TS const& gmem_src,
|
||||
TD & smem_dst,
|
||||
bool pred)
|
||||
{
|
||||
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
|
||||
TS const* gmem_ptr = &gmem_src;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
|
||||
int src_size = pred ? sizeof(TS) : 0;
|
||||
asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n"
|
||||
:: "r"(smem_int_ptr),
|
||||
"l"(gmem_ptr),
|
||||
"n"(sizeof(TS)),
|
||||
"r"(src_size));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/// Copy via cp.async with caching at global level
|
||||
template <class TS, class TD = TS>
|
||||
struct SM80_CP_ASYNC_CACHEGLOBAL_ZFILL
|
||||
{
|
||||
using SRegisters = TS[1];
|
||||
using DRegisters = TD[1];
|
||||
|
||||
static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)");
|
||||
static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported");
|
||||
|
||||
CUTE_HOST_DEVICE static void
|
||||
copy(TS const& gmem_src,
|
||||
TD & smem_dst,
|
||||
bool pred)
|
||||
{
|
||||
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
|
||||
TS const* gmem_ptr = &gmem_src;
|
||||
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst);
|
||||
int src_size = pred ? sizeof(TS) : 0;
|
||||
asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n"
|
||||
:: "r"(smem_int_ptr),
|
||||
"l"(gmem_ptr),
|
||||
"n"(sizeof(TS)),
|
||||
"r"(src_size));
|
||||
#else
|
||||
CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block.
|
||||
|
||||
@ -785,7 +785,7 @@ tma_store_arrive() {
|
||||
#endif
|
||||
}
|
||||
|
||||
// Wait on prior N (Count) TMA_STORE instructions to complete
|
||||
// Wait until at most Count committed TMA_STOREs are pending and all prior commits are complete
|
||||
template <int Count>
|
||||
CUTE_HOST_DEVICE static void
|
||||
tma_store_wait() {
|
||||
|
||||
Reference in New Issue
Block a user