3.6.0 update (#2005)

* 3.6.0 update

* doc and swap stuff

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Yujia Zhai
2024-12-24 22:34:40 -08:00
committed by GitHub
parent e1cd8c7866
commit 3d261a5974
258 changed files with 10863 additions and 3883 deletions

View File

@ -51,19 +51,14 @@ naive_cooperative_copy(uint32_t const& tid,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
auto N = size(src);
if (tid < N) {
uint32_t upper_bound = (N / NumThreads) * NumThreads;
CUTE_UNROLL
for (uint32_t i = 0; i < upper_bound; i += NumThreads) { // All in-bounds
dst[tid + i] = src[tid + i];
}
if (N % NumThreads != 0) { // Likely static condition
uint32_t final_idx = tid + upper_bound;
if (final_idx < N) { // Final in-bounds
dst[final_idx] = src[final_idx];
}
}
auto N = size(dst);
auto R = N % Int<NumThreads>{};
if (R > 0 && tid < R) { // Likely static condition && Residue in-bounds
dst[tid] = src[tid];
}
CUTE_UNROLL
for (uint32_t i = uint32_t(R); i < uint32_t(N); i += NumThreads) { // All in-bounds
dst[tid + i] = src[tid + i];
}
}
@ -117,12 +112,14 @@ heuristic_permutation(Tensor<AEngine, ALayout> const& a,
//
template <uint32_t NumThreads, uint32_t MaxVecBits,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
class DstEngine, class DstLayout,
class CopyPolicy = DefaultCopy>
CUTE_HOST_DEVICE
void
cooperative_copy(uint32_t const& tid,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
Tensor<DstEngine, DstLayout> & dst,
CopyPolicy const& cpy = {})
{
// Assumes the shapes are static, can generalize/fallback
CUTE_STATIC_ASSERT_V(is_static<decltype(shape(src))>{} && is_static<decltype(shape(dst))>{});
@ -283,23 +280,28 @@ cooperative_copy(uint32_t const& tid,
// If we're using all threads (static) or the tid is in-range (dynamic)
if (vec_thrs == NumThreads or tid < vec_thrs) {
return copy_if(TrivialPredTensor{}, recast<VecType const>(src_v), recast<VecType>(dst_v));
auto src_c = recast<VecType const>(src_v);
auto dst_c = recast<VecType>(dst_v);
return copy(cpy, src_c, dst_c);
}
}
}
// Default max-vectorization size to value_type size
template <uint32_t NumThreads,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
class DstEngine, class DstLayout,
class CopyPolicy = DefaultCopy>
CUTE_HOST_DEVICE
void
cooperative_copy(uint32_t const& tid,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
Tensor<DstEngine, DstLayout> & dst,
CopyPolicy const& cpy = {})
{
constexpr uint32_t MaxVecBits = sizeof_bits_v<typename SrcEngine::value_type>;
return cooperative_copy<NumThreads, MaxVecBits>(tid, src, dst);
return cooperative_copy<NumThreads, MaxVecBits>(tid, src, dst, cpy);
}
//
@ -308,26 +310,30 @@ cooperative_copy(uint32_t const& tid,
template <uint32_t NumThreads,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
class DstEngine, class DstLayout,
class CopyPolicy = DefaultCopy>
CUTE_HOST_DEVICE
void
cooperative_copy(uint32_t const& tid,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
Tensor<DstEngine, DstLayout> && dst,
CopyPolicy const& cpy = {})
{
return cooperative_copy<NumThreads>(tid, src, dst);
return cooperative_copy<NumThreads>(tid, src, dst, cpy);
}
template <uint32_t NumThreads, uint32_t MaxVecBits,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
class DstEngine, class DstLayout,
class CopyPolicy = DefaultCopy>
CUTE_HOST_DEVICE
void
cooperative_copy(uint32_t const& tid,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
Tensor<DstEngine, DstLayout> && dst,
CopyPolicy const& cpy = {})
{
return cooperative_copy<NumThreads, MaxVecBits>(tid, src, dst);
return cooperative_copy<NumThreads, MaxVecBits>(tid, src, dst, cpy);
}
} // end namespace cute

View File

@ -50,31 +50,115 @@ namespace cute
namespace detail {
// Predicated Cooperative GEMM
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
class ALoadTransformOp, class BLoadTransformOp,
class CLoadTransformOp, class CStoreTransformOp,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
// Slow fallback path:
template<typename ... Args,
typename Alpha, typename TRC, typename RCLayout,
typename Beta, class TSC, typename CLayout, typename SCLayout,
typename CLoadTransformOp, typename CStoreTransformOp>
CUTE_HOST_DEVICE
void
cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC,
ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C
epilogue_predication(ThrMMA<Args...> const& thr_mma,
Alpha const& alpha,
Tensor<TRC, RCLayout> & tCrC,
Beta const& beta,
Tensor<TSC, CLayout> & sC,
Tensor<TSC, SCLayout> & tCsC,
CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C
{
using TypeA = typename TA::value_type;
using TypeB = typename TB::value_type;
using TypeC = typename TC::value_type;
using InputTypeC = typename TSC::value_type;
using ComputeTypeC = typename ThrMMA<Args...>::ValTypeC;
CUTE_STATIC_ASSERT(CUTE_STL_NAMESPACE::is_same_v<ComputeTypeC, typename TRC::value_type>);
// Create coordinate tensors for the problem
Tensor cC = make_identity_tensor(shape(sC)); // (M,N) -> (m,n)
// Repeat partitioning with thr_mma
Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) -> (m,n)
const bool isBetaZero = [&] () {
if constexpr (is_complex<Beta>::value) {
return beta.real() == Int<0>{} && beta.imag() == Int<0>{};
}
else {
return beta == Int<0>{};
}
CUTE_GCC_UNREACHABLE;
} ();
// Custom axpby_if for now
CUTE_UNROLL
for (int i = 0; i < size(tCrC); ++i)
{
if (elem_less(tCcC(i), shape(sC)))
{
tCsC(i) = sC_store_op(isBetaZero ? alpha * tCrC(i)
: alpha * tCrC(i) +
beta * static_cast<ComputeTypeC>(sC_load_op(tCsC(i))));
}
}
}
template<class Alpha, class TRC, class RCLayout,
class Beta, class TSC, class SCLayout,
class CLoadTransformOp, class CStoreTransformOp,
class SmemCopyOpC>
CUTE_HOST_DEVICE
void
epilogue_no_predication(Alpha const& alpha,
Tensor<TRC, RCLayout> & tCrC,
Beta const& beta,
Tensor<TSC, SCLayout> & tCsC,
CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op, // transforms results before they are stored to C
SmemCopyOpC const& sC_copy_op)
{
using InputTypeC = typename TSC::value_type;
using ComputeTypeC = typename TRC::value_type;
const bool isBetaZero = [&] () {
if constexpr (is_complex<Beta>::value) {
return beta.real() == Int<0>{} && beta.imag() == Int<0>{};
}
else {
return beta == Int<0>{};
}
CUTE_GCC_UNREACHABLE;
} ();
Tensor tCrDi = make_fragment_like(tCsC);
Tensor tCrD = make_fragment_like(tCrC);
if(!isBetaZero) {
copy(sC_copy_op, tCsC, tCrDi);
// Transform C on/after load
cute::transform(tCrDi, tCrD, sC_load_op);
}
// C = alpha * (A * B) + beta * C
axpby(alpha, tCrC, beta, tCrD);
// Transform C before/on store
cute::transform(tCrD, tCrDi, sC_store_op);
copy(sC_copy_op, tCrDi, tCsC);
}
// Predicated Cooperative GEMM
template <class... Args,
class TA, class ALayout, class TB, class BLayout,
class TC, class RCLayout,
class ALoadTransformOp, class BLoadTransformOp>
CUTE_HOST_DEVICE
void
cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
Tensor<TA, ALayout> const& sA,
Tensor<TB, BLayout> const& sB,
Tensor<TC, RCLayout> & tCrC,
ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op) // transforms B values before use in GEMM
{
using InputTypeA = typename TA::value_type;
using InputTypeB = typename TB::value_type;
using InputTypeC = typename TC::value_type;
using ComputeTypeA = typename ThrMMA<Args...>::ValTypeA;
using ComputeTypeB = typename ThrMMA<Args...>::ValTypeB;
using ComputeTypeC = typename ThrMMA<Args...>::ValTypeC;
//
// MMA Partitioning
@ -83,22 +167,18 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
// Partition the sA, sB, and sC tiles across the threads for the MMA
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)
Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N)
// Create register tensors for the MMA to operate on
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N)
#if 0
if (thread0()) {
print(" sA: "); print( sA); print("\n");
print(" sB: "); print( sB); print("\n");
print(" sC: "); print( sC); print("\n");
print(thr_mma);
print("tCsA: "); print(tCsA); print("\n");
print("tCsB: "); print(tCsB); print("\n");
print("tCsC: "); print(tCsC); print("\n");
print("tCrA: "); print(tCrA); print("\n");
print("tCrB: "); print(tCrB); print("\n");
print("tCrC: "); print(tCrC); print("\n");
@ -154,23 +234,20 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M
CUTE_UNROLL
for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I
tCrA(i,m,0) = (tCpA(i,m) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,0)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,0)) : TypeA{};
tCrA(i,m,0) = (tCpA(i,m) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,0)), shape<1>(sA)))) ? static_cast<ComputeTypeA>(sA_load_op(tCsA(i,m,0))) : ComputeTypeA{};
}
}
CUTE_UNROLL
for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N
CUTE_UNROLL
for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I
tCrB(i,n,0) = (tCpB(i,n) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,0)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,0)) : TypeB{};
tCrB(i,n,0) = (tCpB(i,n) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,0)), shape<1>(sB)))) ? static_cast<ComputeTypeB>(sB_load_op(tCsB(i,n,0))) : ComputeTypeB{};
}
}
//
// MAINLOOP
//
// Clear accumulators
clear(tCrC);
CUTE_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
{
@ -185,138 +262,80 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M
CUTE_UNROLL
for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I
tCrA(i,m,k_next) = (tCpA(i,m) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,k_next)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{};
tCrA(i,m,k_next) = (tCpA(i,m) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,k_next)), shape<1>(sA)))) ? static_cast<ComputeTypeA>(sA_load_op(tCsA(i,m,k_next))) : ComputeTypeA{};
}
}
CUTE_UNROLL
for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N
CUTE_UNROLL
for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I
tCrB(i,n,k_next) = (tCpB(i,n) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,k_next)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{};
tCrB(i,n,k_next) = (tCpB(i,n) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,k_next)), shape<1>(sB)))) ? static_cast<ComputeTypeB>(sB_load_op(tCsB(i,n,k_next))) : ComputeTypeB{};
}
}
}
// GEMM on k_block in registers
gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
}
//
// Epilogue
//
// Create coordinate tensors for the problem
Tensor cC = make_identity_tensor(shape(sC)); // (M,N) -> (m,n)
// Repeat partitioning with thr_mma
Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) -> (m,n)
const bool isBetaZero = (beta == Beta{});
// Custom axpby_if for now
CUTE_UNROLL
for (int i = 0; i < size(tCrC); ++i)
{
if (elem_less(tCcC(i), shape(sC)))
{
tCsC(i) = sC_store_op(isBetaZero ? alpha * static_cast<TypeC>(tCrC(i))
: alpha * static_cast<TypeC>(tCrC(i)) +
beta * static_cast<TypeC>(sC_load_op(tCsC(i))));
}
}
}
// Slow fallback path
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
class ALoadTransformOp, class BLoadTransformOp,
class CLoadTransformOp, class CStoreTransformOp,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
cooperative_gemm_predication(uint32_t thread_idx,
TiledMMA<Args...> const& tiled_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC,
ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C
{
// ThrMMA
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
cooperative_gemm_predication(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op, sC_load_op, sC_store_op);
}
// Unpredicated Cooperative GEMM
template <class SmemCopyOpA, class SmemCopyOpB, class SmemCopyOpC,
class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
template <class... Args,
class TA, class ALayout, class TB, class BLayout,
class TC, class CLayout,
class ALoadTransformOp, class BLoadTransformOp,
class CLoadTransformOp, class CStoreTransformOp,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
class SmemCopyOpA, class SmemCopyOpB>
CUTE_HOST_DEVICE
void
cooperative_gemm_no_predication(uint32_t thread_idx,
TiledMMA<Args...> const& tiled_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC,
ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C
cooperative_gemm_no_predication(uint32_t thread_idx,
ThrMMA<Args...> const& thr_mma,
Tensor<TA, ALayout> const& sA,
Tensor<TB, BLayout> const& sB,
Tensor<TC, CLayout> & tCrC,
ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM
SmemCopyOpA const& sA_copy_op,
SmemCopyOpB const& sB_copy_op)
{
using TypeA = typename TA::value_type;
using TypeB = typename TB::value_type;
using TypeC = typename TC::value_type;
using InputTypeA = typename TA::value_type;
using InputTypeB = typename TB::value_type;
using InputTypeC = typename TC::value_type;
using ComputeTypeA = typename ThrMMA<Args...>::ValTypeA;
using ComputeTypeB = typename ThrMMA<Args...>::ValTypeB;
using ComputeTypeC = typename ThrMMA<Args...>::ValTypeC;
// ThrMMA
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
//
// MMA Partitioning
//
Tensor tCsC = thr_mma.partition_C(sC);
// Create register tensors for the MMA to operate on
Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_N,MMA_K)
Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N)
using CopyOpAType = SmemCopyOpA;
using CopyOpBType = SmemCopyOpB;
auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom<CopyOpAType, TypeA>{}, thr_mma);
auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom<CopyOpAType, InputTypeA>{}, thr_mma);
auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx);
Tensor tCsA = smem_thr_copy_A.partition_S(sA);
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K
Tensor tCrAi = make_fragment_like(tCsA);
Tensor tCrAi_copy_view = smem_thr_copy_A.retile_D(tCrAi);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrAi_copy_view)); // CPY_M
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrAi_copy_view)); // CPY_K
auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom<CopyOpBType, TypeB>{}, thr_mma);
auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom<CopyOpBType, InputTypeB>{}, thr_mma);
auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx);
Tensor tCsB = smem_thr_copy_B.partition_S(sB);
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K
Tensor tCrBi = make_fragment_like(tCsB);
Tensor tCrBi_copy_view = smem_thr_copy_B.retile_D(tCrBi);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrBi_copy_view)); // CPY_N
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrBi_copy_view)); // CPY_K
#if 0
if (thread0()) {
print(" sA: "); print(sA); print("\n");
print(" sB: "); print(sB); print("\n");
print(" sC: "); print(sC); print("\n");
print(thr_mma); print("\n");
print("tCsC: "); print(tCsC); print("\n");
print("tCrA: "); print(tCrA); print("\n");
print("tCrB: "); print(tCrB); print("\n");
print("tCrC: "); print(tCrC); print("\n");
@ -333,15 +352,12 @@ cooperative_gemm_no_predication(uint32_t thread_idx,
// PREFETCH
//
copy(smem_tiled_copy_A, tCsA(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{}));
copy(smem_tiled_copy_B, tCsB(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{}));
copy(smem_tiled_copy_A, tCsA(_,_,Int<0>{}), tCrAi_copy_view(_,_,Int<0>{}));
copy(smem_tiled_copy_B, tCsB(_,_,Int<0>{}), tCrBi_copy_view(_,_,Int<0>{}));
//
// MAINLOOP
//
// Clear accumulators
clear(tCrC);
constexpr int K_BLOCK_MAX = size<2>(tCrA);
CUTE_UNROLL
@ -352,132 +368,178 @@ cooperative_gemm_no_predication(uint32_t thread_idx,
{
// Load the next k_block
int k_next = k_block + 1; // statically unrolled
copy(smem_tiled_copy_A, tCsA(_,_,k_next), tCrA_copy_view(_,_,k_next));
copy(smem_tiled_copy_B, tCsB(_,_,k_next), tCrB_copy_view(_,_,k_next));
copy(smem_tiled_copy_A, tCsA(_,_,k_next), tCrAi_copy_view(_,_,k_next));
copy(smem_tiled_copy_B, tCsB(_,_,k_next), tCrBi_copy_view(_,_,k_next));
}
// Transform A and B, relying on the compiler to remove in case of identity ops
cute::transform(tCrA(_,_,k_block), sA_load_op);
cute::transform(tCrB(_,_,k_block), sB_load_op);
cute::transform(tCrAi(_,_,k_block), tCrA(_,_,k_block), sA_load_op);
cute::transform(tCrBi(_,_,k_block), tCrB(_,_,k_block), sB_load_op);
// GEMM on k_block in registers
gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
}
//
// Epilogue
//
auto isBetaZero = [&] () {
if constexpr (is_complex<Beta>::value) {
return beta.real() == Int<0>{} && beta.imag() == Int<0>{};
}
else {
return beta == Int<0>{};
}
CUTE_GCC_UNREACHABLE;
} ();
using CopyOpCType = SmemCopyOpC;
Tensor tCrD = thr_mma.make_fragment_C(tCsC);
if(!isBetaZero) {
copy(CopyOpCType{}, tCsC, tCrD);
// Transform C on/after load
cute::transform(tCrD, sC_load_op);
}
// C = alpha * (A * B) + beta * C
axpby(alpha, tCrC, beta, tCrD);
// Transform C before/on store
cute::transform(tCrD, sC_store_op);
copy(CopyOpCType{}, tCrD, tCsC);
}
} // end namespace detail
template <class SmemCopyOpA, class SmemCopyOpB, class SmemCopyOpC,
class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
cooperative_gemm(uint32_t thread_idx,
TiledMMA<Args...> const& tiled_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC,
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C
{
CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM
CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN
CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK
using TypeA = typename TA::value_type;
using TypeB = typename TB::value_type;
using TypeC = typename TC::value_type;
static_assert(is_convertible_v<decay_t<invoke_result_t<ALoadTransformOp, TypeA>>, TypeA>,
"ALoadTransformOp functor must accept value of type TA::value_type and return value convertible to type TA::value_type");
static_assert(is_convertible_v<decay_t<invoke_result_t<BLoadTransformOp, TypeB>>, TypeB>,
"BLoadTransformOp functor must accept value of type TB::value_type and return value convertible to type TB::value_type");
static_assert(is_convertible_v<decay_t<invoke_result_t<CLoadTransformOp, TypeC>>, TypeC>,
"CLoadTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type");
static_assert(is_convertible_v<decay_t<invoke_result_t<CStoreTransformOp, TypeC>>, TypeC>,
"CStoreTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type");
static constexpr bool compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)),
tile_shape(TiledMMA<Args...>{}));
if constexpr (compat) {
detail::cooperative_gemm_no_predication<SmemCopyOpA, SmemCopyOpB, SmemCopyOpC>(
thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
sA_load_op, sB_load_op, sC_load_op, sC_store_op
);
} else {
detail::cooperative_gemm_predication(
thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
sA_load_op, sB_load_op, sC_load_op, sC_store_op
);
}
}
// C passed as a shared memory tensor
// Epilogue included
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
class SmemCopyOpA = DefaultCopy, class SmemCopyOpB = DefaultCopy,
class SmemCopyOpC = DefaultCopy>
CUTE_HOST_DEVICE
void
cooperative_gemm(uint32_t thread_idx,
TiledMMA<Args...> const& tiled_mma,
Alpha const& alpha,
Tensor<TA, ALayout> const& sA,
Tensor<TB, BLayout> const& sB,
Beta const& beta,
Tensor<TC, CLayout> & sC,
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C
SmemCopyOpA const& sA_copy_op = {},
SmemCopyOpB const& sB_copy_op = {},
SmemCopyOpC const& sC_copy_op = {})
{
CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(sC) == Int<2>{});
CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM
CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN
CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK
using InputTypeA = typename TA::value_type;
using InputTypeB = typename TB::value_type;
using InputTypeC = typename TC::value_type;
using ComputeTypeA = typename TiledMMA<Args...>::ValTypeA;
using ComputeTypeB = typename TiledMMA<Args...>::ValTypeB;
using ComputeTypeC = typename TiledMMA<Args...>::ValTypeC;
auto compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)),
tile_shape(TiledMMA<Args...>{}));
// ThrMMA
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N) :: InputTypeC
Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) :: ComputeTypeC
// Clear accumulators
clear(tCrC);
#if 0
if (thread0()) {
print(" sC: "); print(sC); print("\n");
print(" tCsC: "); print(tCsC); print("\n");
}
#endif
if constexpr (is_constant<true, decltype(compat)>::value) {
detail::cooperative_gemm_no_predication(
thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op
);
detail::epilogue_no_predication(
alpha, tCrC, beta, tCsC, sC_load_op, sC_store_op, sC_copy_op
);
} else {
detail::cooperative_gemm_predication(
thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op
);
detail::epilogue_predication(
thr_mma, alpha, tCrC, beta, sC, tCsC, sC_load_op, sC_store_op
);
}
}
// C already partitioned into registers on input
// It can be passed non-empty
// Epilogue not included
template <class... Args,
class TA, class ALayout, class TB, class BLayout,
class TC, class CLayout,
class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
class SmemCopyOpA = DefaultCopy, class SmemCopyOpB = DefaultCopy>
CUTE_HOST_DEVICE
void
cooperative_gemm(uint32_t thread_idx,
TiledMMA<Args...> const& tiled_mma,
Tensor<TA, ALayout> const& sA,
Tensor<TB, BLayout> const& sB,
Tensor<TC, CLayout> & tCrC,
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
SmemCopyOpA const& sA_copy_op = {},
SmemCopyOpB const& sB_copy_op = {})
{
CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{});
CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK
using InputTypeA = typename TA::value_type;
using InputTypeB = typename TB::value_type;
using InputTypeC = typename TC::value_type;
using ComputeTypeA = typename TiledMMA<Args...>::ValTypeA;
using ComputeTypeB = typename TiledMMA<Args...>::ValTypeB;
using ComputeTypeC = typename TiledMMA<Args...>::ValTypeC;
// Check if input C fragment is compatible with thr_mma and problem size
using ref_c_frag = decltype(partition_shape_C(tiled_mma, make_shape(size<0>(sA), size<0>(sB))));
CUTE_STATIC_ASSERT_V(compatible(shape(ref_c_frag{}), shape(tCrC)));
auto compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)),
tile_shape(TiledMMA<Args...>{}));
// ThrMMA
auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
if constexpr (is_constant<true, decltype(compat)>::value) {
detail::cooperative_gemm_no_predication(
thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op
);
} else {
detail::cooperative_gemm_predication(
thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op
);
}
}
// Accept mutable temporaries
template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
class SmemCopyOpA = DefaultCopy, class SmemCopyOpB = DefaultCopy,
class SmemCopyOpC = DefaultCopy>
CUTE_HOST_DEVICE
void
cooperative_gemm(uint32_t thread_idx,
TiledMMA<Args...> const& tiled_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC,
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C
TiledMMA<Args...> const& tiled_mma,
Alpha const& alpha,
Tensor<TA, ALayout> const& sA,
Tensor<TB, BLayout> const& sB,
Beta const& beta,
Tensor<TC, CLayout> && sC,
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C
SmemCopyOpA const& sA_copy_op = {},
SmemCopyOpB const& sB_copy_op = {},
SmemCopyOpC const& sC_copy_op = {})
{
using CopyOpA = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TA::value_type>>;
using CopyOpB = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TB::value_type>>;
using CopyOpC = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TC::value_type>>;
cooperative_gemm<CopyOpA, CopyOpB, CopyOpC>(
thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
sA_load_op, sB_load_op, sC_load_op, sC_store_op
);
cooperative_gemm(thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
sA_load_op, sB_load_op, sC_load_op, sC_store_op,
sA_copy_op, sB_copy_op, sC_copy_op);
}
// Legacy overload of cute::gemm for backwards-compatibility
@ -485,27 +547,38 @@ template <class... Args,
class Alpha, class TA, class ALayout, class TB, class BLayout,
class Beta, class TC, class CLayout,
class ALoadTransformOp = cute::identity, class BLoadTransformOp = cute::identity,
class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
__CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
BLayout::rank == 2 && is_smem<TB>::value &&
CLayout::rank == 2 && is_smem<TC>::value)>
class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity>
CUTE_HOST_DEVICE
void
gemm(ThrMMA<Args...> const& thr_mma,
Alpha const& alpha,
Tensor<TA, ALayout> sA,
Tensor<TB, BLayout> sB,
Beta const& beta,
Tensor<TC, CLayout> sC,
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C
gemm(ThrMMA<Args...> const& thr_mma,
Alpha const& alpha,
Tensor<TA, ALayout> const& sA,
Tensor<TB, BLayout> const& sB,
Beta const& beta,
Tensor<TC, CLayout> & sC,
ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM
BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM
CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM
CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C
{
CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{});
CUTE_STATIC_ASSERT_V(rank(sC) == Int<2>{});
CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM
CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN
CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK
Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N)
Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N)
// Goes directly to the slow path to avoid getting thread_idx from thr_mma
detail::cooperative_gemm_predication(
thr_mma, alpha, sA, sB, beta, sC,
sA_load_op, sB_load_op, sC_load_op, sC_store_op
thr_mma, sA, sB, sC, sA_load_op, sB_load_op
);
detail::epilogue_predication(
thr_mma, alpha, tCrC, beta, sC, tCsC, sC_load_op, sC_store_op
);
}

View File

@ -38,79 +38,6 @@
namespace cute
{
//
// Accept mutable temporaries
//
template <class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return copy(src, dst);
}
template <class VecType,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_vec(Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return copy_vec<VecType>(src, dst);
}
template <class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_aligned(Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return copy_aligned(src, dst);
}
template <class PrdTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(PrdTensor const& pred,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return copy_if(pred, src, dst);
}
template <class CopyPolicy,
class PrdTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(CopyPolicy const& copy_policy,
PrdTensor const& pred,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return copy_if(copy_policy, pred, src, dst);
}
template <class CopyPolicy,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(CopyPolicy const& copy_policy,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return copy(copy_policy, src, dst);
}
//
// copy_if -- Predicated Copy
//
@ -124,12 +51,13 @@ copy_if(PrdTensor const& pred,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
auto copy_op = select_elementwise_copy(src, dst);
using SrcType = typename SrcEngine::value_type;
using DstType = typename DstEngine::value_type;
CUTE_UNROLL
for (int i = 0; i < size(src); ++i) {
for (int i = 0; i < size(dst); ++i) {
if (pred(i)) {
copy_op.copy(src(i), dst(i));
dst(i) = static_cast<DstType>(static_cast<SrcType>(src(i)));
}
}
}
@ -138,17 +66,6 @@ copy_if(PrdTensor const& pred,
// copy_if -- Predicated CopyAtom
//
namespace detail {
// Trait that detects if atom's traits has a member function with(bool)
template <class, class Enable = void>
constexpr bool has_with_bool = false;
template <class T>
constexpr bool has_with_bool<T, cute::void_t<decltype(declval<typename T::Traits>().with(declval<bool>()))>> = true;
} // end namespace detail
template <class... CopyArgs,
class PredTensor,
class SrcEngine, class SrcLayout,
@ -161,73 +78,90 @@ copy_if(Copy_Atom<CopyArgs...> const& copy_atom,
Tensor<DstEngine, DstLayout> & dst) // (V,Rest...)
{
static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch.");
auto has_with_bool = cute::is_valid([](auto t)->void_t<decltype(declval<typename decltype(t)::Traits>().with(true))>{}, copy_atom);
if constexpr (SrcLayout::rank == 1) { // Dispatch the copy
copy_atom.call(src, dst);
if constexpr (has_with_bool) {
copy_atom.with(pred()).call(src, dst);
} else {
if (pred()) { copy_atom.call(src, dst); }
}
} else { // Loop over all but the first mode
constexpr int R = SrcLayout::rank;
Tensor src_v = group_modes<1,R>(src);
Tensor dst_v = group_modes<1,R>(dst);
CUTE_UNROLL
for (int i = 0; i < size<1>(src_v); ++i) {
// If copy traits can be transformed with a predicate value, do it, otherwise branch here
if constexpr (detail::has_with_bool<Copy_Atom<CopyArgs...>>) {
for (int i = 0; i < size<1>(dst_v); ++i) {
if constexpr (has_with_bool) {
copy_atom.with(pred(i)).call(src_v(_,i), dst_v(_,i));
} else {
if (pred(i)) {
copy_atom.call(src_v(_,i), dst_v(_,i));
}
if (pred(i)) { copy_atom.call(src_v(_,i), dst_v(_,i)); }
}
}
}
}
//
// copy_vec -- attempt vectorized copy with VecType
// copy_if -- AutoCopyAsync
//
template <class VecType,
template <class PrdTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_vec(Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
copy_if(AutoCopyAsync const& cpy,
PrdTensor const& pred,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
static_assert(sizeof_bits_v<VecType> >= 8 && sizeof_bits_v<VecType> % 8 == 0,
"Expected a vectorization type of at least a byte.");
using SrcElemWithConst = remove_reference_t<typename SrcEngine::reference>;
using SrcType = typename SrcEngine::value_type;
using DstType = typename DstEngine::value_type;
if constexpr (cute::is_same<SrcType, DstType>::value &&
sizeof_bits_v<VecType> > sizeof_bits_v<DstType>)
{
// Preserve volatility of Src/Dst types.
using SrcVecType = conditional_t<is_volatile_v<typename SrcEngine::element_type>, VecType const volatile, VecType const>;
using DstVecType = conditional_t<is_volatile_v<typename DstEngine::element_type>, VecType volatile, VecType >;
Tensor src_v = recast<SrcVecType>(src);
Tensor dst_v = recast<DstVecType>(dst);
#if 0
if (thread0()) {
print("copy_vec<%db> -- vectorizing copy:\n", int(sizeof_bits_v<VecType>));
print(" "); print(src); print(" => "); print(src_v); print("\n");
print(" "); print(dst); print(" => "); print(dst_v); print("\n");
auto copy_op = []() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
if constexpr (is_gmem<SrcEngine>::value && is_smem<DstEngine>::value &&
sizeof(SrcType) == sizeof(DstType)) {
if constexpr (is_const_v<SrcElemWithConst> && sizeof(SrcType) == 16) {
return SM80_CP_ASYNC_CACHEGLOBAL<SrcType,DstType>{};
} else if constexpr (sizeof(SrcType) == 4 || sizeof(SrcType) == 8 || sizeof(SrcType) == 16) {
return SM80_CP_ASYNC_CACHEALWAYS<SrcType,DstType>{};
} else {
return UniversalCopy<SrcType,DstType>{};
}
} else {
return UniversalCopy<SrcType,DstType>{};
}
#endif
return copy_if(TrivialPredTensor{}, src_v, dst_v);
} else {
#if 0
if (thread0()) {
print("copy_vec<%db> -- NOT vectorizing copy:\n", int(sizeof_bits_v<VecType>));
print(" "); print(src); print("\n");
print(" "); print(dst); print("\n");
}
CUTE_GCC_UNREACHABLE;
#else
return UniversalCopy<SrcType,DstType>{};
#endif
}();
return copy_if(TrivialPredTensor{}, src, dst);
CUTE_UNROLL
for (int i = 0; i < size(dst); ++i) {
if (pred(i)) {
copy_op.copy(src(i), dst(i));
}
}
}
//
// copy -- AutoCopyAsync
//
template <class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(AutoCopyAsync const& cpy,
Tensor<SrcEngine, SrcLayout> const& src, // (V,Rest...)
Tensor<DstEngine, DstLayout> & dst) // (V,Rest...)
{
copy_if(cpy, TrivialPredTensor{}, src, dst);
}
//
// copy -- CopyAtom
//
@ -238,15 +172,56 @@ template <class... CopyArgs,
CUTE_HOST_DEVICE
void
copy(Copy_Atom<CopyArgs...> const& copy_atom,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
Tensor<SrcEngine, SrcLayout> const& src, // (V,Rest...)
Tensor<DstEngine, DstLayout> & dst) // (V,Rest...)
{
return copy_if(copy_atom, TrivialPredTensor{}, src, dst);
static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch.");
if constexpr (SrcLayout::rank == 1) { // Dispatch the copy
copy_atom.call(src, dst);
} else { // Loop over all but the first mode
constexpr int R = SrcLayout::rank;
Tensor src_v = group_modes<1,R>(src);
Tensor dst_v = group_modes<1,R>(dst);
if constexpr (is_static<decltype(shape(src_v))>::value && is_static<decltype(shape(dst_v))>::value) {
CUTE_STATIC_ASSERT_V(size<1>(src_v) == size<1>(dst_v));
// AutoFilter on the Rest-mode
auto dst_null = nullspace(layout<1>(dst_v));
Tensor dst_n = zipped_divide(dst_v, make_tile(shape<0>(dst_v), dst_null)); // ((V, NLL), (_1, Rest))
Tensor src_n = zipped_divide(src_v, make_tile(shape<0>(src_v), dst_null)); // ((V, NLL), (_1, Rest))
CUTE_STATIC_ASSERT_V(size<1>(src_n) == size<1>(dst_n));
CUTE_STATIC_ASSERT_V((cosize<0,1>(dst_n.layout()) == Int<1>{}), "Nullspace definition error");
CUTE_STATIC_ASSERT_V((cosize<0,1>(src_n.layout()) == Int<1>{}), "Error: Ambiguous scatter detected in copy");
CUTE_STATIC_ASSERT_V((size<1,0>(dst_n) == Int<1>{}));
CUTE_STATIC_ASSERT_V((size<1,0>(src_n) == Int<1>{}));
Tensor dst_c = dst_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest)
Tensor src_c = src_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest)
CUTE_STATIC_ASSERT_V(size<1>(src_c) == size<1>(dst_c));
CUTE_STATIC_ASSERT_V(shape<0>(dst_c) == shape<0>(dst));
CUTE_STATIC_ASSERT_V(shape<0>(src_c) == shape<0>(src));
CUTE_UNROLL
for (int i = 0; i < size<1>(dst_c); ++i) {
copy_atom.call(src_c(_,i), dst_c(_,i));
}
} else {
CUTE_UNROLL
for (int i = 0; i < size<1>(dst_v); ++i) {
copy_atom.call(src_v(_,i), dst_v(_,i));
}
}
}
}
//////////////////////////////////////////
// Special Auto-Vectorizing Overloads
//////////////////////////////////////////
////////////////////////////////////////////////////////
// Special Auto-Vectorizing, Auto-Filtering Overloads //
////////////////////////////////////////////////////////
// Specialization for AutoVectorizingCopyAssumedAlignment<MaxVecBits>
template <int MaxVecBits, class... Args,
@ -258,30 +233,67 @@ copy(AutoVectorizingCopyWithAssumedAlignment<MaxVecBits> const&,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
constexpr int vec_elem = decltype(max_common_vector(src, dst))::value;
constexpr int common_elem = CUTE_STATIC_V(max_common_vector(src, dst));
constexpr int align_bits = CUTE_STATIC_V(gcd(max_alignment(src), max_alignment(dst), Int<MaxVecBits>{}));
static_assert(is_integral<decltype(Int<common_elem>{} * sizeof_bits_v<typename SrcEngine::value_type>)>::value, "Error: Attempting a subbit copy!");
constexpr int vec_bits = gcd(common_elem * sizeof_bits_v<typename SrcEngine::value_type>, align_bits);
constexpr int max_align_src = decltype(max_alignment(src.layout()))::value;
constexpr int max_align_dst = decltype(max_alignment(dst.layout()))::value;
constexpr int max_align = gcd(vec_elem, max_align_src, max_align_dst);
if constexpr (common_elem > 1 && ((vec_bits % 8) == 0)) {
// If more than one element vectorizes to 8bits or more, then recast and copy
using VecType = uint_bit_t<vec_bits>;
// Preserve volatility
using SrcVecType = conditional_t<is_volatile_v<typename SrcEngine::element_type>, VecType const volatile, VecType const>;
using DstVecType = conditional_t<is_volatile_v<typename DstEngine::element_type>, VecType volatile, VecType >;
constexpr int src_bits = sizeof_bits<typename SrcEngine::value_type>::value;
constexpr int vec_bits = gcd(src_bits * max_align, MaxVecBits);
// Recast
Tensor src_v = recast<SrcVecType>(src);
Tensor dst_v = recast<DstVecType>(dst);
if constexpr (vec_elem > 1 && vec_bits >= 8) {
// If more than one element vectorizes to 8bits or more, then copy_vec
#if 0
if (thread0()) {
print("copy -- found max_common_vector of %d elems and vectorization to %d bits\n", vec_elem, vec_bits);
print(" "); print(src); print("\n");
print(" "); print(dst); print("\n");
print("copy -- found max_common_vector of %d elems and vectorization to %d bits\n", common_elem, vec_bits);
print(" "); print(src); print(" => "); print(src_v); print("\n");
print(" "); print(dst); print(" => "); print(dst_v); print("\n");
}
#endif
return copy_vec<uint_bit_t<vec_bits>>(src, dst);
return copy_if(TrivialPredTensor{}, src_v, dst_v);
} else {
return copy_if(TrivialPredTensor{}, src, dst);
}
}
template <class Base>
struct AutoFilter {
Base const& base;
CUTE_HOST_DEVICE AutoFilter(Base const& b) : base(b) {}
};
// Specialization for AutoFilter
template <class CopyOp,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(AutoFilter<CopyOp> const& copy_op,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
if constexpr (is_constant<true, decltype(size(src) == size(dst))>::value) {
auto dst_null = nullspace(dst.layout());
Tensor dst_n = zipped_divide(dst, dst_null);
Tensor src_n = zipped_divide(src, dst_null);
CUTE_STATIC_ASSERT_V(cosize<0>(dst_n.layout()) == Int<1>{}, "Nullspace definition error");
CUTE_STATIC_ASSERT_V(cosize<0>(src_n.layout()) == Int<1>{}, "Error: Ambiguous scatter detected in copy");
copy(copy_op.base, src_n(Int<0>{},_), dst_n(Int<0>{},_));
} else {
copy(copy_op.base, src, dst);
}
}
// Auto-vectorizing copy for static layouts
template <class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
@ -292,7 +304,11 @@ copy(Tensor<SrcEngine, SrcLayout> const& src,
{
if constexpr (is_static<SrcLayout>::value && is_static<DstLayout>::value) {
// Assume Tensors with static layouts (e.g. registers) have pointers that are 128b aligned
return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst);
return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<128>{}), src, dst);
} else
if constexpr (is_static<decltype(shape(src))>::value && is_static<decltype(shape(dst))>::value) {
// Tensors with static shapes can be filtered, but do not assume that dynamic layouts are aligned.
return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<8>{}), src, dst);
} else {
// Do not assume that dynamic layouts are aligned.
return copy(AutoVectorizingCopyWithAssumedAlignment<8>{}, src, dst);
@ -307,7 +323,12 @@ void
copy_aligned(Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst);
if constexpr (is_static<decltype(shape(src))>::value && is_static<decltype(shape(dst))>::value) {
// Tensors with static shapes can be filtered
return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<128>{}), src, dst);
} else {
return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst);
}
}
// Specializaton for Atom AutoVectorizingCopyAssumedAlignment
@ -379,4 +400,146 @@ copy(Copy_Atom<Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...>, CA_Args...> const&
}
#endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
//
// Decay TiledCopy to CopyAtom
//
template <class CopyAtom, class TV, class Tiler,
class PrdTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(TiledCopy<CopyAtom, TV, Tiler> const& tiled_copy,
PrdTensor const& pred,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
return copy_if(static_cast<CopyAtom const&>(tiled_copy), pred, src, dst);
}
template <class CopyAtom, class TV, class Tiler,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(TiledCopy<CopyAtom, TV, Tiler> const& tiled_copy,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
return copy(static_cast<CopyAtom const&>(tiled_copy), src, dst);
}
template <class TiledCopy, class ThrIdx,
class PrdTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(ThrCopy<TiledCopy, ThrIdx> const& thr_copy,
PrdTensor const& pred,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst) = delete;
template <class TiledCopy, class ThrIdx,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(ThrCopy<TiledCopy, ThrIdx> const& thr_copy,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst) = delete;
//
// Catch uncaught policies
//
template <class CopyPolicy,
class PredTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(CopyPolicy const& cpy,
PredTensor const& prd,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
static_assert(dependent_false<CopyPolicy>, "Unrecognized CopyPolicy.");
}
template <class CopyPolicy,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(CopyPolicy const& cpy,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
static_assert(dependent_false<CopyPolicy>, "Unrecognized CopyPolicy.");
}
//
// Accept mutable temporaries
//
template <class PrdTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(PrdTensor const& pred,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return copy_if(pred, src, dst);
}
template <class CopyPolicy,
class PrdTensor,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_if(CopyPolicy const& copy_policy,
PrdTensor const& pred,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return copy_if(copy_policy, pred, src, dst);
}
template <class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return copy(src, dst);
}
template <class CopyPolicy,
class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy(CopyPolicy const& copy_policy,
Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return copy(copy_policy, src, dst);
}
template <class SrcEngine, class SrcLayout,
class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void
copy_aligned(Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> && dst)
{
return copy_aligned(src, dst);
}
} // end namespace cute

View File

@ -39,7 +39,7 @@ namespace cute
{
//
// Direct Copy for any type
// Direct Copy for any specific types
//
template <class S, class D = S>
@ -48,21 +48,15 @@ struct UniversalCopy
using SRegisters = S[1];
using DRegisters = D[1];
template <class S_, class D_>
CUTE_HOST_DEVICE static constexpr void
copy(S_ const& src,
D_ & dst)
{
dst = static_cast<D>(static_cast<S>(src));
}
// Sanity
static_assert(sizeof_bits_v<S> >= 8);
static_assert(sizeof_bits_v<D> >= 8);
// Accept mutable temporaries
template <class S_, class D_>
CUTE_HOST_DEVICE static constexpr void
copy(S_ const& src,
D_ && dst)
copy(S const& src,
D & dst)
{
UniversalCopy<S,D>::copy(src, dst);
dst = src;
}
};
@ -92,6 +86,12 @@ using AutoVectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>;
using DefaultCopy = AutoVectorizingCopyWithAssumedAlignment<8>;
//
// Copy policy automatically selecting between
// UniversalCopy and cp.async , based on type and memory space.
//
struct AutoCopyAsync {};
//
// Global memory prefetch into L2
//

View File

@ -2040,6 +2040,103 @@ struct SM80_16x8x64_S32U4U4S32_TN_SATURATE
////////////////////////////////////////////////////////////////////////////////////////////////////
// MMA 8x8x128 TN
struct SM80_8x8x128_S32U1U1S32_TN_ANDPOPC
{
using DRegisters = uint32_t[2];
using ARegisters = uint32_t[1];
using BRegisters = uint32_t[1];
using CRegisters = uint32_t[2];
CUTE_HOST_DEVICE static void
fma(uint32_t & d0, uint32_t & d1,
uint32_t const& a0,
uint32_t const& b0,
uint32_t const& c0, uint32_t const& c1)
{
#if defined(CUTE_ARCH_MMA_SM80_ENABLED)
asm volatile(
"mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc "
"{%0, %1},"
"{%2},"
"{%3},"
"{%4, %5};\n"
: "=r"(d0), "=r"(d1)
: "r"(a0),
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// MMA 16x8x128 TN
struct SM80_16x8x128_S32U1U1S32_TN_ANDPOPC
{
using DRegisters = uint32_t[4];
using ARegisters = uint32_t[2];
using BRegisters = uint32_t[1];
using CRegisters = uint32_t[4];
CUTE_HOST_DEVICE static void
fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
uint32_t const& a0, uint32_t const& a1,
uint32_t const& b0,
uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
{
#if defined(CUTE_ARCH_MMA_SM80_ENABLED)
asm volatile(
"mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc "
"{%0, %1, %2, %3},"
"{%4, %5},"
"{%6},"
"{%7, %8, %9, %10};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1),
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// MMA 16x8x256 TN
struct SM80_16x8x256_S32U1U1S32_TN_ANDPOPC
{
using DRegisters = uint32_t[4];
using ARegisters = uint32_t[4];
using BRegisters = uint32_t[2];
using CRegisters = uint32_t[4];
CUTE_HOST_DEVICE static void
fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
uint32_t const& b0, uint32_t const& b1,
uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
{
#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED)
asm volatile(
"mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// MMA 8x8x128 TN
@ -2141,103 +2238,4 @@ struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC
////////////////////////////////////////////////////////////////////////////////////////////////////
// MMA 8x8x128 TN
struct SM80_8x8x128_S32U1U1S32_TN_ANDPOPC
{
using DRegisters = uint32_t[2];
using ARegisters = uint32_t[1];
using BRegisters = uint32_t[1];
using CRegisters = uint32_t[2];
CUTE_HOST_DEVICE static void
fma(uint32_t & d0, uint32_t & d1,
uint32_t const& a0,
uint32_t const& b0,
uint32_t const& c0, uint32_t const& c1)
{
#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED)
asm volatile(
"mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc "
"{%0, %1},"
"{%2},"
"{%3},"
"{%4, %5};\n"
: "=r"(d0), "=r"(d1)
: "r"(a0),
"r"(b0),
"r"(c0), "r"(c1));
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// MMA 16x8x128 TN
struct SM80_16x8x128_S32U1U1S32_TN_ANDPOPC
{
using DRegisters = uint32_t[4];
using ARegisters = uint32_t[2];
using BRegisters = uint32_t[1];
using CRegisters = uint32_t[4];
CUTE_HOST_DEVICE static void
fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
uint32_t const& a0, uint32_t const& a1,
uint32_t const& b0,
uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
{
#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED)
asm volatile(
"mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc "
"{%0, %1, %2, %3},"
"{%4, %5},"
"{%6},"
"{%7, %8, %9, %10};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1),
"r"(b0),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// MMA 16x8x256 TN
struct SM80_16x8x256_S32U1U1S32_TN_ANDPOPC
{
using DRegisters = uint32_t[4];
using ARegisters = uint32_t[4];
using BRegisters = uint32_t[2];
using CRegisters = uint32_t[4];
CUTE_HOST_DEVICE static void
fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
uint32_t const& b0, uint32_t const& b1,
uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3)
{
#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED)
asm volatile(
"mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
: "r"(a0), "r"(a1), "r"(a2), "r"(a3),
"r"(b0), "r"(b1),
"r"(c0), "r"(c1), "r"(c2), "r"(c3));
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED");
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cute

View File

@ -100,16 +100,16 @@ struct Copy_Atom<Copy_Traits<Args...>, CopyInternalType>
if constexpr (is_constant<NumValSrc, decltype(size(src))>::value ||
is_constant<NumValDst, decltype(size(dst))>::value) {
// Dispatch to unpack to execute instruction
return copy_unpack(*this, src, dst);
} else
if constexpr (is_tuple<decltype(shape(src))>::value &&
is_tuple<decltype(shape(dst))>::value) {
return copy_unpack(static_cast<Traits const&>(*this), src, dst);
} else if constexpr (is_tuple<decltype(shape(src))>::value &&
is_tuple<decltype(shape(dst))>::value) {
// If the size of the src/dst doesn't match the instruction,
// recurse this rank-1 layout by peeling off the mode
// ((A,B,C,...)) -> (A,B,C,...)
return copy(*this, tensor<0>(src), tensor<0>(dst));
} else {
static_assert(dependent_false<SEngine>, "No instruction match and no recursion possible.");
static_assert(dependent_false<SEngine>,
"CopyAtom: Src/Dst partitioning does not match the instruction requirement.");
}
}

View File

@ -92,23 +92,29 @@ struct Copy_Traits<AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>>
using RefLayout = SrcLayout;
};
// Extract a CPY_Op from a CPY_Traits
template <class CPY_Traits>
struct CPY_Op {};
template <class CPY_Op_Arg, class... Args>
struct CPY_Op<Copy_Traits<CPY_Op_Arg, Args...>> {
using type = CPY_Op_Arg;
};
//
// Generic copy_unpack for common argument-based Copy_Traits
//
template <class CopyOp, class... Args,
template <class AnyCPYTraits,
class SEngine, class SLayout,
class DEngine, class DLayout>
CUTE_HOST_DEVICE constexpr
void
copy_unpack(Copy_Traits<CopyOp,Args...> const&,
Tensor<SEngine,SLayout> const& src,
Tensor<DEngine,DLayout> & dst)
copy_unpack(AnyCPYTraits const&,
Tensor<SEngine,SLayout> const& src,
Tensor<DEngine,DLayout> & dst)
{
// Specializations can generalize on these checks
//static_assert(is_smem<TS>::value, "Expected smem for this Copy_Traits<CopyOp>");
//static_assert(is_rmem<TD>::value, "Expected rmem for this Copy_Traits<CopyOp>");
using CopyOp = typename CPY_Op<AnyCPYTraits>::type;
using RegistersSrc = typename CopyOp::SRegisters;
using RegistersDst = typename CopyOp::DRegisters;
using RegTypeSrc = typename remove_extent<RegistersSrc>::type;
@ -129,18 +135,15 @@ copy_unpack(Copy_Traits<CopyOp,Args...> const&,
rD, make_int_sequence<RegNumDst>{});
}
//
// Accept mutable temporaries
//
template <class CopyOp, class... Args,
template <class AnyCPYTraits,
class SEngine, class SLayout,
class DEngine, class DLayout>
CUTE_HOST_DEVICE constexpr
void
copy_unpack(Copy_Traits<CopyOp,Args...> const& traits,
Tensor<SEngine,SLayout> const& src,
Tensor<DEngine,DLayout> && dst)
copy_unpack(AnyCPYTraits const& traits,
Tensor<SEngine,SLayout> const& src,
Tensor<DEngine,DLayout> && dst)
{
copy_unpack(traits, src, dst);
}

View File

@ -51,13 +51,6 @@ struct Copy_Traits<SM80_CP_ASYNC_CACHEALWAYS<S,D>>
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// Construct a zfill variant with a given predicate value
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM80_CP_ASYNC_CACHEALWAYS_ZFILL<S,D>>
with(bool pred) const {
return {pred};
}
};
template <class S, class D>
@ -73,13 +66,6 @@ struct Copy_Traits<SM80_CP_ASYNC_CACHEGLOBAL<S,D>>
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// Construct a zfill variant with a given predicate value
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<S,D>>
with(bool pred) const {
return {pred};
}
};
template <class S, class D>
@ -96,8 +82,15 @@ struct Copy_Traits<SM80_CP_ASYNC_CACHEALWAYS_ZFILL<S,D>>
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// Predicate value that determines whether to load or zfill
bool pred = false;
// Predicate value: true = load, false = zfill
bool pred = true;
// Construct a zfill variant with a given predicate value
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM80_CP_ASYNC_CACHEALWAYS_ZFILL<S,D>>
with(bool pred) const {
return {pred};
}
// Overload copy_unpack for zfill variant to pass the predicate into the op
template <class TS, class SLayout,
@ -137,8 +130,15 @@ struct Copy_Traits<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<S,D>>
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;
// Predicate value that determines whether to load or zfill
bool pred = false;
// Predicate value: true = load, false = zfill
bool pred = true;
// Construct a zfill variant with a given predicate value
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<S,D>>
with(bool pred) const {
return {pred};
}
// Overload copy_unpack for zfill variant to pass the predicate into the op
template <class TS, class SLayout,
@ -164,31 +164,4 @@ struct Copy_Traits<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<S,D>>
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Element copy selector
template <class SrcTensor, class DstTensor>
CUTE_HOST_DEVICE constexpr
auto
select_elementwise_copy(SrcTensor const&, DstTensor const&)
{
using SrcType = typename SrcTensor::value_type;
using DstType = typename DstTensor::value_type;
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
if constexpr (is_gmem<SrcTensor>::value && is_smem<DstTensor>::value &&
sizeof(SrcType) == sizeof(DstType) &&
(sizeof(SrcType) == 4 || sizeof(SrcType) == 8 || sizeof(SrcType) == 16))
{
return SM80_CP_ASYNC_CACHEALWAYS<SrcType,DstType>{};
} else {
return UniversalCopy<SrcType,DstType>{};
}
CUTE_GCC_UNREACHABLE;
#else
return UniversalCopy<SrcType,DstType>{};
#endif
}
}
} // end namespace cute

View File

@ -58,37 +58,31 @@ struct AuxTmaParams {
};
// Utility for unpacking TMA_LOAD arguments into a CopyOp
template <class CopyOp>
template <class CopyOp, class... Args>
struct TMA_LOAD_Unpack
{
template <class... Args,
class TS, class SLayout,
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Copy_Traits<CopyOp, Args...> const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst)
{
static_assert(is_smem<TD>::value, "SM90_TMA_LOAD requires the destination be shared memory.");
auto src_coord = src.data().coord_;
if constexpr (detail::is_prefetch<CopyOp>) {
return detail::explode_tuple(detail::CallCOPY<CopyOp>{},
traits.opargs_, tuple_seq<decltype(traits.opargs_)>{},
src_coord, tuple_seq<decltype(src_coord)>{});
} else {
static_assert(is_smem<TD>::value, "SM90_TMA_LOAD requires the destination be shared memory.");
void* dst_ptr = cute::raw_pointer_cast(dst.data());
void* dst_ptr = cute::raw_pointer_cast(dst.data());
#if 0
auto [c0,c1,c2,c3,c4] = append<5>(src_coord, 0);
printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n",
threadIdx.x, threadIdx.y, threadIdx.z,
blockIdx.x, blockIdx.y, blockIdx.z,
int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr);
auto [c0,c1,c2,c3,c4] = append<5>(src_coord, 0);
printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n",
threadIdx.x, threadIdx.y, threadIdx.z,
blockIdx.x, blockIdx.y, blockIdx.z,
int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr);
#endif
return detail::explode_tuple(detail::CallCOPY<CopyOp>{},
traits.opargs_, tuple_seq<decltype(traits.opargs_)>{},
make_tuple(dst_ptr), seq<0>{},
src_coord, tuple_seq<decltype(src_coord)>{});
}
return detail::explode_tuple(detail::CallCOPY<CopyOp>{},
traits.opargs_, tuple_seq<decltype(traits.opargs_)>{},
make_tuple(dst_ptr), seq<0>{},
src_coord, tuple_seq<decltype(src_coord)>{});
}
};
@ -131,7 +125,7 @@ struct Copy_Traits<SM90_TMA_LOAD, NumBitsPerTMA, AuxParams_>
[[maybe_unused]] uint16_t const& multicast_mask = 0,
TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
// We accept multicast_mask here to keep the API for both atoms consistent
return {{}, {&tma_desc_, &tma_mbar, static_cast<uint64_t>(cache_hint)}};
return {&tma_desc_, &tma_mbar, static_cast<uint64_t>(cache_hint)};
}
// Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
@ -143,7 +137,7 @@ struct Copy_Traits<SM90_TMA_LOAD, NumBitsPerTMA, AuxParams_>
[[maybe_unused]] uint16_t const& multicast_mask = 0,
TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
// We accept multicast_mask here to keep the API for both atoms consistent
return {{}, {new_tma_desc, &tma_mbar, static_cast<uint64_t>(cache_hint)}};
return {new_tma_desc, &tma_mbar, static_cast<uint64_t>(cache_hint)};
}
// Generate the TMA coord tensor
@ -167,7 +161,7 @@ struct Copy_Traits<SM90_TMA_LOAD, NumBitsPerTMA, AuxParams_>
// The executable SM90_TMA_LOAD with tma_desc and tma_mbar
template <class NumBitsPerTMA>
struct Copy_Traits<SM90_TMA_LOAD_OP, NumBitsPerTMA>
: TMA_LOAD_Unpack<SM90_TMA_LOAD_OP>
: TMA_LOAD_Unpack<SM90_TMA_LOAD_OP, NumBitsPerTMA>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
@ -183,12 +177,15 @@ struct Copy_Traits<SM90_TMA_LOAD_OP, NumBitsPerTMA>
uint64_t*, // smem mbarrier
uint64_t // cache hint
> const opargs_;
CUTE_HOST_DEVICE
Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint64_t cache)
: opargs_(desc, mbar, cache) {}
};
// The prefetch for SM90_TMA_LOAD with tma_desc
template <class NumBitsPerTMA, class... Args>
struct Copy_Traits<SM90_TMA_LOAD::PREFETCH, NumBitsPerTMA, Args...>
: TMA_LOAD_Unpack<SM90_TMA_LOAD::PREFETCH>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
@ -206,6 +203,19 @@ struct Copy_Traits<SM90_TMA_LOAD::PREFETCH, NumBitsPerTMA, Args...>
CUTE_HOST_DEVICE
Copy_Traits(Copy_Traits<CopyArgs...> const& traits)
: opargs_({&traits.tma_desc_}) {}
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst)
{
auto src_coord = src.data().coord_;
return detail::explode_tuple(detail::CallCOPY<SM90_TMA_LOAD::PREFETCH>{},
traits.opargs_, tuple_seq<decltype(traits.opargs_)>{},
src_coord, tuple_seq<decltype(src_coord)>{});
}
};
//////////////////////////////////////////////////////////////////////////////
@ -246,7 +256,7 @@ struct Copy_Traits<SM90_TMA_LOAD_MULTICAST, NumBitsPerTMA, AuxParams_>
uint64_t& tma_load_mbar,
uint16_t const& multicast_mask,
TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, static_cast<uint64_t>(cache_hint)}};
return {&tma_desc_, &tma_load_mbar, multicast_mask, static_cast<uint64_t>(cache_hint)};
}
// Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
@ -257,7 +267,7 @@ struct Copy_Traits<SM90_TMA_LOAD_MULTICAST, NumBitsPerTMA, AuxParams_>
uint64_t& tma_load_mbar,
uint16_t const& multicast_mask,
TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, static_cast<uint64_t>(cache_hint)}};
return {new_tma_desc, &tma_load_mbar, multicast_mask, static_cast<uint64_t>(cache_hint)};
}
// Generate the TMA coord tensor
@ -281,7 +291,7 @@ struct Copy_Traits<SM90_TMA_LOAD_MULTICAST, NumBitsPerTMA, AuxParams_>
// The executable SM90_TMA_LOAD_MULTICAST with tma_desc and tma_mbar and multicast_mask
template <class NumBitsPerTMA>
struct Copy_Traits<SM90_TMA_LOAD_MULTICAST_OP, NumBitsPerTMA>
: TMA_LOAD_Unpack<SM90_TMA_LOAD_MULTICAST_OP>
: TMA_LOAD_Unpack<SM90_TMA_LOAD_MULTICAST_OP, NumBitsPerTMA>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
@ -298,43 +308,17 @@ struct Copy_Traits<SM90_TMA_LOAD_MULTICAST_OP, NumBitsPerTMA>
uint16_t, // multicast mask
uint64_t // cache hint
> const opargs_;
CUTE_HOST_DEVICE
Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint16_t mask, uint64_t hint)
: opargs_(desc, mbar, mask, hint) {}
};
//////////////////////////////////////////////////////////////////////////////
///////////////////////////// TMA_STORE //////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////
// Utility for unpacking TMA_STORE arguments into a CopyOp
template <class CopyOp>
struct TMA_STORE_Unpack
{
template <class... Args,
class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Copy_Traits<CopyOp, Args...> const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst)
{
static_assert(is_smem<TS>::value, "Expected smem src for SM90_TMA_STORE");
void const* const desc_ptr = traits.tma_desc_;
void const* const src_ptr = cute::raw_pointer_cast(src.data());
auto dst_coord = dst.data().coord_;
#if 0
auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0);
printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n",
threadIdx.x, threadIdx.y, threadIdx.z,
blockIdx.x, blockIdx.y, blockIdx.z,
int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr);
#endif
return detail::explode_tuple(detail::CallCOPY<SM90_TMA_STORE>{},
make_tuple(desc_ptr, src_ptr), seq<0,1>{},
dst_coord, tuple_seq<decltype(dst_coord)>{});
}
};
struct SM90_TMA_STORE_OP : SM90_TMA_STORE {};
struct SM90_TMA_STORE_PTR : SM90_TMA_STORE {};
// The executable SM90_TMA_STORE with tma_desc
template <class NumBitsPerTMA, class AuxParams_>
@ -369,6 +353,13 @@ struct Copy_Traits<SM90_TMA_STORE, NumBitsPerTMA, AuxParams_>
return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_));
}
// Construct new TMA_STORE with (unsafe) swapped out TMA descriptor ptr (for grouped gemm/ptr array gemm)
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_STORE_PTR, NumBitsPerTMA>
with(TmaDescriptor const* new_tma_desc) const {
return {new_tma_desc};
}
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
@ -393,19 +384,11 @@ struct Copy_Traits<SM90_TMA_STORE, NumBitsPerTMA, AuxParams_>
make_tuple(desc_ptr, src_ptr), seq<0,1>{},
dst_coord, tuple_seq<decltype(dst_coord)>{});
}
// Construct Copy_Traits executable (w/ swapped out TMA descriptor) for SM90_TMA_STORE (for grouped gemm/ptr array gemm)
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_STORE_OP, NumBitsPerTMA>
with(TmaDescriptor const* new_tma_desc) const {
return {{}, new_tma_desc};
}
};
// The executable SM90_TMA_STORE with tma_desc
// Same as SM90_TMA_STORE, but with an unsafe TMA Desc PTR instead
template <class NumBitsPerTMA>
struct Copy_Traits<SM90_TMA_STORE_OP, NumBitsPerTMA>
: TMA_STORE_Unpack<SM90_TMA_STORE_OP>
struct Copy_Traits<SM90_TMA_STORE_PTR, NumBitsPerTMA>
{
using ThrID = Layout<_1>;
// Map from (src-thr,src-val) to bit
@ -417,6 +400,31 @@ struct Copy_Traits<SM90_TMA_STORE_OP, NumBitsPerTMA>
// SM90_TMA_STORE arguments
TmaDescriptor const* tma_desc_;
template <class TS, class SLayout,
class TD, class DLayout>
CUTE_HOST_DEVICE friend constexpr void
copy_unpack(Copy_Traits const& traits,
Tensor<TS,SLayout> const& src,
Tensor<TD,DLayout> & dst)
{
static_assert(is_smem<TS>::value, "Expected smem src for SM90_TMA_STORE");
//static_assert(is_gmem<TD>::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor
void const* const desc_ptr = traits.tma_desc_;
void const* const src_ptr = cute::raw_pointer_cast(src.data());
auto dst_coord = dst.data().coord_;
#if 0
auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0);
printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n",
threadIdx.x, threadIdx.y, threadIdx.z,
blockIdx.x, blockIdx.y, blockIdx.z,
int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr);
#endif
return detail::explode_tuple(detail::CallCOPY<SM90_TMA_STORE_PTR>{},
make_tuple(desc_ptr, src_ptr), seq<0,1>{},
dst_coord, tuple_seq<decltype(dst_coord)>{});
}
};
//////////////////////////////////////////////////////////////////////////////
@ -520,7 +528,7 @@ struct Copy_Traits<SM90_BULK_COPY_G2S, NumBitsPerTMA, OpArgs...>
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_BULK_COPY_G2S, NumBitsPerTMA, uint64_t*>
with(uint64_t& bulk_mbar) const {
return {{&bulk_mbar}};
return {&bulk_mbar};
}
template <class TS, class SLayout,
@ -613,7 +621,7 @@ struct Copy_Traits<SM90_BULK_COPY_AUTO, OpArgs...>
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_BULK_COPY_AUTO, uint64_t*>
with(uint64_t& bulk_mbar) const {
return {{&bulk_mbar}};
return {&bulk_mbar};
}
};
@ -1391,19 +1399,46 @@ tma_partition(Copy_Atom<Args...> const& copy_atom,
return cute::make_tuple(gresult, sresult);
}
// Explicit defaults for cta_coord and cta_layout
template <class... Args,
class SEngine, class SLayout,
class GEngine, class GLayout>
CUTE_DEVICE
auto
tma_partition(Copy_Atom<Args...> const& copy_atom,
Tensor<SEngine,SLayout> const& stensor, // SMEM Tensor (TMATile, Rest...)
Tensor<GEngine,GLayout> const& gtensor) // GMEM Tensor (TMATile, Rest...)
{
return tma_partition(copy_atom, Int<0>{}, Layout<_1,_0>{}, stensor, gtensor);
}
// TMA Multicast Masks Calculation
template <int Mode, class CtaLayout, class CtaCoord>
CUTE_HOST_DEVICE constexpr
auto
uint16_t
create_tma_multicast_mask(CtaLayout const& cta_layout_vmnk,
CtaCoord const& cta_coord_vmnk)
{
auto cta_coord_slicer = replace<Mode>(cta_coord_vmnk, _);
auto [cta_layout, elected_cta] = slice_and_offset(cta_coord_slicer, cta_layout_vmnk);
// Get the instruction code
uint16_t mcast_mask = 0;
for (int i = 0; i < size(cta_layout); ++i) {
mcast_mask |= uint16_t(1) << cta_layout(i);
if constexpr (rank_v<decltype(cta_layout)> == 1 and depth_v<decltype(cta_layout)> <= 1 and
not is_static<decltype(cta_layout)>::value) {
// Get the instruction code -- optimized for dynamic flat-rank-1 cta_layout
mcast_mask = uint16_t(1);
// Smear by stride<0> (may want to predicate on stride<0> mag?)
mcast_mask |= mcast_mask << (1*stride<0>(cta_layout));
mcast_mask |= mcast_mask << (2*stride<0>(cta_layout));
mcast_mask |= mcast_mask << (4*stride<0>(cta_layout));
mcast_mask |= mcast_mask << (8*stride<0>(cta_layout));
// Select shape<0>
mcast_mask &= (uint16_t(-1) >> (16 - shape<0>(cta_layout) * stride<0>(cta_layout)));
} else {
// Get the instruction code -- generic path
for (int i = 0; i < size(cta_layout); ++i) {
mcast_mask |= uint16_t(1) << cta_layout(i);
}
}
// Shift by the instruction's elected block rank (dynamic)
mcast_mask <<= elected_cta;

View File

@ -250,12 +250,12 @@ struct TiledMMA : MMA_Atom
auto t_tensor = logical_divide(ctensor, t_tile); // (PermM,PermN)
// Tile the tensor for the Atom
auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})),
auto c_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})),
make_layout(size<1>(AtomShape_MNK{})));
auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomN),(RestM,RestN))
auto c_tensor = zipped_divide(t_tensor, c_tile); // ((AtomM,AtomN),(RestM,RestN))
// Transform the Atom mode from (M,K) to (Thr,Val)
auto tv_tensor = a_tensor.compose(AtomLayoutC_TV{},_); // ((ThrV,FrgV),(RestM,RestN))
auto tv_tensor = c_tensor.compose(AtomLayoutC_TV{},_); // ((ThrV,FrgV),(RestM,RestN))
// Tile the tensor for the C-threads
auto thr_tile = make_tile(_,
@ -604,16 +604,15 @@ CUTE_HOST_DEVICE constexpr
auto
partition_shape_C(TiledMMA<Args...> const& mma, Shape_MN const& shape_MN)
{
constexpr int R = rank_v<Shape_MN>;
static_assert(R >= 2, "Must have at least rank-2");
auto atomMNK = typename TiledMMA<Args...>::AtomShape_MNK{};
auto thrVMNK = typename TiledMMA<Args...>::ThrLayoutVMNK{};
auto V = shape<1>(typename TiledMMA<Args...>::AtomLayoutC_TV{});
auto M = shape_div(size<0>(shape_MN), size<0>(atomMNK) * size<1>(thrVMNK));
auto N = shape_div(size<1>(shape_MN), size<1>(atomMNK) * size<2>(thrVMNK));
return cute::tuple_cat(make_shape(V,M,N), take<2,R>(shape_MN));
auto dummy = make_layout(shape(shape_MN));
auto dummy_tv = mma.thrfrg_C(dummy);
// Slice+rearrange like partition_C
auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat<rank(dummy)>(_)));
return shape(dummy_v);
}
template <class... Args, class Shape_MN>
CUTE_HOST_DEVICE constexpr
auto
@ -632,14 +631,12 @@ CUTE_HOST_DEVICE constexpr
auto
partition_shape_A(TiledMMA<Args...> const& mma, Shape_MK const& shape_MK)
{
constexpr int R = rank_v<Shape_MK>;
static_assert(R >= 2, "Must have at least rank-2");
auto atomMNK = typename TiledMMA<Args...>::AtomShape_MNK{};
auto thrVMNK = typename TiledMMA<Args...>::ThrLayoutVMNK{};
auto V = shape<1>(typename TiledMMA<Args...>::AtomLayoutA_TV{});
auto M = shape_div(size<0>(shape_MK), size<0>(atomMNK) * size<1>(thrVMNK));
auto K = shape_div(size<1>(shape_MK), size<2>(atomMNK) * size<3>(thrVMNK));
return cute::tuple_cat(make_shape(V,M,K), take<2,R>(shape_MK));
auto dummy = make_layout(shape(shape_MK));
auto dummy_tv = mma.thrfrg_A(dummy);
// Slice+rearrange like partition_A
auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat<rank(dummy)>(_)));
return shape(dummy_v);
}
template <class... Args, class Shape_NK>
@ -647,14 +644,12 @@ CUTE_HOST_DEVICE constexpr
auto
partition_shape_B(TiledMMA<Args...> const& mma, Shape_NK const& shape_NK)
{
constexpr int R = rank_v<Shape_NK>;
static_assert(R >= 2, "Must have at least rank-2");
auto atomMNK = typename TiledMMA<Args...>::AtomShape_MNK{};
auto thrVMNK = typename TiledMMA<Args...>::ThrLayoutVMNK{};
auto V = shape<1>(typename TiledMMA<Args...>::AtomLayoutB_TV{});
auto N = shape_div(size<0>(shape_NK), size<1>(atomMNK) * size<2>(thrVMNK));
auto K = shape_div(size<1>(shape_NK), size<2>(atomMNK) * size<3>(thrVMNK));
return cute::tuple_cat(make_shape(V,N,K), take<2,R>(shape_NK));
auto dummy = make_layout(shape(shape_NK));
auto dummy_tv = mma.thrfrg_B(dummy);
// Slice+rearrange like partition_B
auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat<rank(dummy)>(_)));
return shape(dummy_v);
}
//

View File

@ -419,6 +419,203 @@ template <>
struct MMA_Traits<SM80_16x8x32_S32U8U8S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x32_S32U8U8S32_TN> {};
///////////////////////////////////////////////////////////////////////////////
/////////////////////////// s32 = s4 * s4 + s32 ///////////////////////////////
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM80_8x8x32_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = int4b_t;
using ValTypeB = int4b_t;
using ValTypeC = int32_t;
using Shape_MNK = Shape<_8, _8, _32>;
using ThrID = Layout<_32>;
// (T32,V8) -> (M8,N32)
using ALayout = Layout<Shape <Shape < _4, _8>, Shape <_8>>,
Stride<Stride<_64, _1>, Stride<_8>>>;
using BLayout = Layout<Shape <Shape < _4, _8>, Shape <_8>>,
Stride<Stride<_64, _1>, Stride<_8>>>;
using CLayout = SM80_8x8_Row;
};
template <>
struct MMA_Traits<SM80_8x8x32_S32S4S4S32_TN_SATURATE>
: MMA_Traits<SM80_8x8x32_S32S4S4S32_TN> {};
template <>
struct MMA_Traits<SM80_16x8x32_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = int4b_t;
using ValTypeB = int4b_t;
using ValTypeC = int32_t;
using Shape_MNK = Shape<_16, _8, _32>;
using ThrID = Layout<_32>;
// (T32,V16) -> (M16,N32)
using ALayout = Layout<Shape <Shape < _4, _8>, Shape < _8, _2>>,
Stride<Stride<_128, _1>, Stride<_16, _8>>>;
// (T32,V8) -> (M8,N32)
using BLayout = Layout<Shape <Shape < _4, _8>, Shape <_8>>,
Stride<Stride<_32, _1>, Stride<_8>>>;
using CLayout = SM80_16x8_Row;
};
template <>
struct MMA_Traits<SM80_16x8x32_S32S4S4S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x32_S32S4S4S32_TN> {};
template <>
struct MMA_Traits<SM80_16x8x64_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = int4b_t;
using ValTypeB = int4b_t;
using ValTypeC = int32_t;
using Shape_MNK = Shape<_16, _8, _64>;
using ThrID = Layout<_32>;
// (T32,V32) -> (M16,N64)
using ALayout = Layout<Shape <Shape < _4, _8>, Shape < _8, _2, _2>>,
Stride<Stride<_128, _1>, Stride<_16, _8, _512>>>;
// (T32,V16) -> (M8,N64)
using BLayout = Layout<Shape <Shape < _4, _8>, Shape <_8, _2>>,
Stride<Stride<_64, _1>, Stride<_8, _256>>>;
using CLayout = SM80_16x8_Row;
};
template <>
struct MMA_Traits<SM80_16x8x64_S32S4S4S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x64_S32S4S4S32_TN> {};
///////////////////////////////////////////////////////////////////////////////
/////////////////////////// s32 = s4 * u4 + s32 ///////////////////////////////
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM80_8x8x32_S32S4U4S32_TN>
: MMA_Traits<SM80_8x8x32_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = int4b_t;
using ValTypeB = uint4b_t;
using ValTypeC = int32_t;
};
template <>
struct MMA_Traits<SM80_8x8x32_S32S4U4S32_TN_SATURATE>
: MMA_Traits<SM80_8x8x32_S32S4U4S32_TN> {};
template <>
struct MMA_Traits<SM80_16x8x32_S32S4U4S32_TN>
: MMA_Traits<SM80_16x8x32_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = int4b_t;
using ValTypeB = uint4b_t;
using ValTypeC = int32_t;
};
template <>
struct MMA_Traits<SM80_16x8x32_S32S4U4S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x32_S32S4U4S32_TN> {};
template <>
struct MMA_Traits<SM80_16x8x64_S32S4U4S32_TN>
: MMA_Traits<SM80_16x8x64_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = int4b_t;
using ValTypeB = uint4b_t;
using ValTypeC = int32_t;
};
template <>
struct MMA_Traits<SM80_16x8x64_S32S4U4S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x64_S32S4U4S32_TN> {};
///////////////////////////////////////////////////////////////////////////////
/////////////////////////// s32 = u4 * s4 + s32 ///////////////////////////////
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM80_8x8x32_S32U4S4S32_TN>
: MMA_Traits<SM80_8x8x32_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = uint4b_t;
using ValTypeB = int4b_t;
using ValTypeC = int32_t;
};
template <>
struct MMA_Traits<SM80_8x8x32_S32U4S4S32_TN_SATURATE>
: MMA_Traits<SM80_8x8x32_S32U4S4S32_TN> {};
template <>
struct MMA_Traits<SM80_16x8x32_S32U4S4S32_TN>
: MMA_Traits<SM80_16x8x32_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = uint4b_t;
using ValTypeB = int4b_t;
using ValTypeC = int32_t;
};
template <>
struct MMA_Traits<SM80_16x8x32_S32U4S4S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x32_S32U4S4S32_TN> {};
template <>
struct MMA_Traits<SM80_16x8x64_S32U4S4S32_TN>
: MMA_Traits<SM80_16x8x64_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = uint4b_t;
using ValTypeB = int4b_t;
using ValTypeC = int32_t;
};
template <>
struct MMA_Traits<SM80_16x8x64_S32U4S4S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x64_S32U4S4S32_TN> {};
///////////////////////////////////////////////////////////////////////////////
/////////////////////////// s32 = u4 * u4 + s32 ///////////////////////////////
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM80_8x8x32_S32U4U4S32_TN>
: MMA_Traits<SM80_8x8x32_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = uint4b_t;
using ValTypeB = uint4b_t;
using ValTypeC = int32_t;
};
template <>
struct MMA_Traits<SM80_8x8x32_S32U4U4S32_TN_SATURATE>
: MMA_Traits<SM80_8x8x32_S32U4U4S32_TN> {};
template <>
struct MMA_Traits<SM80_16x8x32_S32U4U4S32_TN>
: MMA_Traits<SM80_16x8x32_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = uint4b_t;
using ValTypeB = uint4b_t;
using ValTypeC = int32_t;
};
template <>
struct MMA_Traits<SM80_16x8x32_S32U4U4S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x32_S32U4U4S32_TN> {};
template <>
struct MMA_Traits<SM80_16x8x64_S32U4U4S32_TN>
: MMA_Traits<SM80_16x8x64_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = uint4b_t;
using ValTypeB = uint4b_t;
using ValTypeC = int32_t;
};
template <>
struct MMA_Traits<SM80_16x8x64_S32U4U4S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x64_S32U4U4S32_TN> {};
///////////////////////////////////////////////////////////////////////////////
/////////////////////////// s32 = b1 ^ b1 + s32 ///////////////////////////////
///////////////////////////////////////////////////////////////////////////////
@ -440,9 +637,13 @@ struct MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC>
using CLayout = SM80_16x8_Row;
};
///////////////////////////////////////////////////////////////////////////////
/////////////////////////// s32 = b1 & b1 + s32 ///////////////////////////////
///////////////////////////////////////////////////////////////////////////////
template <>
struct MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_ANDPOPC>
:MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC> {};
: MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC> {};
template<>
struct MMA_Traits<SM80_8x8x128_S32U1U1S32_TN_XORPOPC>
@ -455,7 +656,7 @@ struct MMA_Traits<SM80_8x8x128_S32U1U1S32_TN_XORPOPC>
using Shape_MNK = Shape<_8,_8,_128>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape<Shape<_4,_8>,_32>,
Stride<Stride<_256,_1>,_8>>;
Stride<Stride<_256,_1>,_8>>;
using BLayout = Layout<Shape<Shape<_4,_8>,_32>,
Stride<Stride<_256,_1>,_8>>;
using CLayout = SM80_8x8_Row;
@ -472,7 +673,7 @@ struct MMA_Traits<SM80_16x8x128_S32U1U1S32_TN_XORPOPC>
using ValTypeA = cute::uint1b_t;
using ValTypeB = cute::uint1b_t;
using ValTypeC = int32_t;
using Shape_MNK = Shape<_16,_8,_128>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape<Shape<_4,_8>,Shape<_32,_2>>,

View File

@ -1128,7 +1128,6 @@ struct MMA_Traits<SM90_64x32x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
GMMA::Major tnspA,
GMMA::Major tnspB,

View File

@ -7735,4 +7735,4 @@ struct MMA_Traits<SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_RS_TN<scaleA, s
#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
#include "mma_traits_sm90_gmma_sparse_ext.hpp"
#endif
#endif

View File

@ -100,7 +100,7 @@
#if defined(_MSC_VER)
// Provides support for alternative operators 'and', 'or', and 'not'
# include <iso646.h>
# include <ciso646>
#endif // _MSC_VER
#if defined(__CUDACC_RTC__)

View File

@ -100,20 +100,30 @@ public:
// Copy Ctor
CUTE_HOST_DEVICE constexpr
subbyte_reference(subbyte_reference const& other) {
*this = element_type(other);
subbyte_reference(subbyte_reference<value_type> const& other) {
*this = other.get();
}
CUTE_HOST_DEVICE constexpr
subbyte_reference(subbyte_reference<value_type const> const& other) {
*this = other.get();
}
// Copy Assignment
CUTE_HOST_DEVICE constexpr
subbyte_reference& operator=(subbyte_reference const& other) {
return *this = element_type(other);
subbyte_reference& operator=(subbyte_reference<value_type> const& other) {
return *this = other.get();
}
CUTE_HOST_DEVICE constexpr
subbyte_reference& operator=(subbyte_reference<value_type const> const& other) {
return *this = other.get();
}
// Assignment
template <class T_ = element_type>
CUTE_HOST_DEVICE constexpr
enable_if_t<!is_const_v<T_>, subbyte_reference&> operator=(element_type x)
enable_if_t<!is_const_v<T_>, subbyte_reference&> operator=(value_type x)
{
static_assert(is_same_v<T_, element_type>, "Do not specify template arguments!");
storage_type item = (reinterpret_cast<storage_type const&>(x) & BitMask);
@ -149,11 +159,11 @@ public:
// Value
CUTE_HOST_DEVICE
element_type get() const
value_type get() const
{
if constexpr (is_same_v<bool, value_type>) { // Extract to bool -- potentially faster impl
return bool((*ptr_) & (BitMask << idx_));
} else { // Extract to element_type
} else { // Extract to value_type
// Extract from the current storage element
auto item = storage_type((ptr_[0] >> idx_) & BitMask);
@ -165,13 +175,13 @@ public:
item |= storage_type((ptr_[1] & bit_mask_1) << straddle_bits);
}
return reinterpret_cast<element_type&>(item);
return reinterpret_cast<value_type&>(item);
}
}
// Extract to type element_type
// Extract to type value_type
CUTE_HOST_DEVICE constexpr
operator element_type() const {
operator value_type() const {
return get();
}
@ -341,6 +351,14 @@ recast_ptr(subbyte_iterator<T> const& x) {
CUTE_GCC_UNREACHABLE;
}
// Dynamic pointers have unknown static alignment
template <class T>
CUTE_HOST_DEVICE constexpr
Int<0>
max_alignment(subbyte_iterator<T> const& x) {
return {};
}
template <class T>
CUTE_HOST_DEVICE void
print(subbyte_iterator<T> const& x) {
@ -352,6 +370,7 @@ CUTE_HOST_DEVICE void
print(subbyte_reference<T> const& x) {
print(x.get());
}
//
// array_subbyte
// Statically sized array for non-byte-aligned data types

View File

@ -1830,7 +1830,7 @@ recast_layout(Layout<Shape,Stride> const& layout)
return upcast<scale::num>(layout);
}
else {
static_assert(dependent_false<scale>, "Recast not supported.");
return downcast<scale::den>(upcast<scale::num>(layout));
}
CUTE_GCC_UNREACHABLE;

View File

@ -616,7 +616,7 @@ recast_layout(ComposedLayout<A,O,B> const& layout)
return upcast<scale::num>(layout);
}
else {
static_assert(dependent_false<scale>, "Recast not supported.");
return downcast<scale::den>(upcast<scale::num>(layout));
}
CUTE_GCC_UNREACHABLE;
}
@ -631,6 +631,15 @@ max_alignment(ComposedLayout<A,O,B> const& layout)
return Int<1>{};
}
template <class A, class O, class B>
CUTE_HOST_DEVICE constexpr
auto
nullspace(ComposedLayout<A,O,B> const& layout)
{
// Do not attempt for general ComposedLayouts
return Layout<_1,_0>{};
}
//
// Display utilities
//

View File

@ -154,13 +154,6 @@ operator*(C<c>, R<a,b>) {
return {};
}
template <auto c, auto a, auto b>
CUTE_HOST_DEVICE constexpr
typename R<c*b,a>::type
operator/(C<c>, R<a,b>) {
return {};
}
// Product with dynamic type needs to produce an integer...
template <class C, auto a, auto b,
__CUTE_REQUIRES(cute::is_std_integral<C>::value)>
@ -179,6 +172,13 @@ operator*(R<a,b>, C const& c) {
return c * R<a,b>::num / R<a,b>::den;
}
template <class C, auto a, auto b>
CUTE_HOST_DEVICE constexpr
auto
operator/(C const& c, R<a,b>) {
return c * R<b,a>{};
}
template <auto a, auto b, auto x, auto y>
CUTE_HOST_DEVICE constexpr
typename R<a*y+b*x, b*y>::type
@ -200,6 +200,10 @@ operator+(C<c>, R<a,b>) {
return {};
}
/////////////////
// Comparisons //
/////////////////
template <auto a, auto b, auto x, auto y>
CUTE_HOST_DEVICE constexpr
bool_constant<R<a,b>::num == R<x,y>::num && R<a,b>::den == R<x,y>::den>
@ -221,6 +225,31 @@ operator==(C<c>, R<a,b>) {
return {};
}
///////////////////////
// Special functions //
///////////////////////
template <auto a, auto b, auto x, auto y>
CUTE_HOST_DEVICE constexpr
typename R<gcd(a*y,b*x),b*x>::type
gcd(R<a,b>, R<x,y>) {
return {};
}
template <auto a, auto b, auto c>
CUTE_HOST_DEVICE constexpr
typename R<gcd(a,b*c),b*c>::type
gcd(R<a,b>, C<c>) {
return {};
}
template <auto c, auto a, auto b>
CUTE_HOST_DEVICE constexpr
typename R<gcd(a,b*c),b*c>::type
gcd(C<c>, R<a,b>) {
return {};
}
template <auto a, auto b>
CUTE_HOST_DEVICE constexpr
typename R<abs(a),abs(b)>::type

View File

@ -46,6 +46,7 @@ template <class T>
static constexpr auto sizeof_bits_v = sizeof_bits<T>::value;
using cutlass::bits_to_bytes;
using cutlass::bytes_to_bits;
using cutlass::is_subbyte;

View File

@ -214,6 +214,14 @@ make_smem_ptr(void const* ptr) {
return make_smem_ptr(recast_ptr<T const>(ptr));
}
// nullptr_t overload for make_smem_ptr<float>(nullptr) disambiguation
template <class T>
CUTE_HOST_DEVICE constexpr
auto
make_smem_ptr(decltype(nullptr)) { // nullptr_t
return make_smem_ptr(recast_ptr<T>(nullptr));
}
// The smem tag is invariant over type-recast
template <class NewT, class P>
CUTE_HOST_DEVICE constexpr

View File

@ -30,9 +30,10 @@
**************************************************************************************************/
#pragma once
#include <cute/config.hpp> // CUTE_HOST_DEVICE
#include <cute/numeric/numeric_types.hpp> // cute::sizeof_bits
#include <cute/util/type_traits.hpp> // cute::declval, cute::void_t, etc
#include <cute/config.hpp> // CUTE_HOST_DEVICE
#include <cute/numeric/numeric_types.hpp> // cute::sizeof_bits
#include <cute/numeric/integral_constant.hpp> // Int<0>
#include <cute/util/type_traits.hpp> // cute::declval, cute::void_t, etc
namespace cute
{
@ -115,6 +116,14 @@ raw_pointer_cast(T* ptr) {
return ptr;
}
// The statically-known alignment of a dynamic pointer is unknown
template <class T>
CUTE_HOST_DEVICE constexpr
Int<0>
max_alignment(T*) {
return {};
}
//
// A very simplified iterator adaptor.
// Derived classed may override methods, but be careful to reproduce interfaces exactly.
@ -169,6 +178,13 @@ raw_pointer_cast(iter_adaptor<I,D> const& x) {
return raw_pointer_cast(x.ptr_);
}
template <class I, class D>
CUTE_HOST_DEVICE constexpr
auto
max_alignment(iter_adaptor<I,D> const& x) {
return max_alignment(x.ptr_);
}
//
// counting iterator -- quick and dirty
//

View File

@ -147,6 +147,14 @@ recast_ptr(swizzle_ptr<SwizzleFn,P> const& ptr) {
return make_swizzle_ptr(recast_ptr<NewT>(ptr.get()), SwizzleFn{});
}
// The statically-known alignment of a swizzle pointer is the alignment of the swizzle function converted to bits
template <class SwizzleFn, class P>
CUTE_HOST_DEVICE constexpr
auto
max_alignment(swizzle_ptr<SwizzleFn,P> const&) {
return Int<8>{} * max_alignment(SwizzleFn{});
}
//
// Display utilities
//

View File

@ -447,7 +447,7 @@ recast_layout(Swizzle<B,M,S> const& swizzle)
return upcast<scale::num>(swizzle);
}
else {
static_assert(dependent_false<scale>, "Recast not supported.");
return downcast<scale::den>(upcast<scale::num>(layout));
}
CUTE_GCC_UNREACHABLE;
}
@ -457,7 +457,7 @@ CUTE_HOST_DEVICE constexpr
auto
max_alignment(Swizzle<B,M,S> const&)
{
return Int<1 << M>{};
return Int<(1 << M)>{};
}
template <int B, int M, int S, class Offset, class LayoutB>

View File

@ -84,6 +84,8 @@ struct ArrayEngine
};
// Specialization for sparse_elem<S,T> tensor allocation/iteration
// NOTE: This can and should be used for allocation of SMEM as well!
// Fuse these two ArrayEngines?
template <int S, class T, size_t N>
struct ArrayEngine<sparse_elem<S,T>, N>
{
@ -858,6 +860,17 @@ max_common_layout(Tensor<SrcEngine,SrcLayout> const& a,
CUTE_GCC_UNREACHABLE;
}
/* Return the maximum (statically known) alignment of a Tensor in the number of bits
*/
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
max_alignment(Tensor<Engine,Layout> const& t)
{
return gcd(max_alignment(t.data()),
max_alignment(t.layout()) * static_value<sizeof_bits<typename Engine::value_type>>());
}
//
// Key algebraic operations -- Composition, Divide, and Product
//

View File

@ -123,7 +123,7 @@ bool
block([[maybe_unused]] int bid)
{
#if defined(__CUDA_ARCH__)
return blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y == bid;
return blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y == static_cast<unsigned int>(bid);
#else
return true;
#endif
@ -134,7 +134,7 @@ bool
thread([[maybe_unused]] int tid, [[maybe_unused]] int bid)
{
#if defined(__CUDA_ARCH__)
return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == tid) && block(bid);
return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == static_cast<unsigned int>(tid)) && block(bid);
#else
return true;
#endif

View File

@ -141,9 +141,15 @@ using CUTE_STL_NAMESPACE::common_type_t;
using CUTE_STL_NAMESPACE::remove_pointer;
using CUTE_STL_NAMESPACE::remove_pointer_t;
using CUTE_STL_NAMESPACE::add_pointer;
using CUTE_STL_NAMESPACE::add_pointer_t;
using CUTE_STL_NAMESPACE::alignment_of;
using CUTE_STL_NAMESPACE::alignment_of_v;
using CUTE_STL_NAMESPACE::is_pointer;
using CUTE_STL_NAMESPACE::is_pointer_v;
// <utility>
using CUTE_STL_NAMESPACE::declval;