v3.8.0 update (#2082)

* 3.8 update

* fix Markus' name

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
Yujia Zhai
2025-02-06 18:33:40 -08:00
committed by GitHub
parent affd1b693d
commit 833f6990e0
168 changed files with 24945 additions and 3436 deletions

View File

@ -171,20 +171,6 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
// Create register tensors for the MMA to operate on
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
#if 0
if (thread0()) {
print(" sA: "); print( sA); print("\n");
print(" sB: "); print( sB); print("\n");
print(thr_mma);
print("tCsA: "); print(tCsA); print("\n");
print("tCsB: "); print(tCsB); print("\n");
print("tCrA: "); print(tCrA); print("\n");
print("tCrB: "); print(tCrB); print("\n");
print("tCrC: "); print(tCrC); print("\n");
}
#endif
//
// PREDICATION
//
@ -200,7 +186,6 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
// Allocate the preds for MMA- and MMA_MN-modes
Tensor tCpA = make_tensor<bool>(make_shape(size<0>(tCsA), size<1>(tCsA)));
Tensor tCpB = make_tensor<bool>(make_shape(size<0>(tCsB), size<1>(tCsB)));
// Populate the predicates on M and N
CUTE_UNROLL
for (int i = 0; i < size(tCpA); ++i) {
@ -210,18 +195,6 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
for (int i = 0; i < size(tCpB); ++i) {
tCpB(i) = elem_less(get<0>(tCcB(_,_,Int<0>{})(i)), shape<0>(sB));
}
#if 0
if (thread0()) {
print(" cA: "); print( cA); print("\n");
print(" cB: "); print( cB); print("\n");
print("tCcA: "); print(tCcA); print("\n");
print("tCcB: "); print(tCcB); print("\n");
print_tensor(tCpA);
print_tensor(tCpB);
}
#endif
//
// PREFETCH k_block = 0
// Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block
@ -330,24 +303,6 @@ cooperative_gemm_no_predication(uint32_t thread_idx,
Tensor tCrBi_copy_view = smem_thr_copy_B.retile_D(tCrBi);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrBi_copy_view)); // CPY_N
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrBi_copy_view)); // CPY_K
#if 0
if (thread0()) {
print(" sA: "); print(sA); print("\n");
print(" sB: "); print(sB); print("\n");
print(thr_mma); print("\n");
print("tCrA: "); print(tCrA); print("\n");
print("tCrB: "); print(tCrB); print("\n");
print("tCrC: "); print(tCrC); print("\n");
print(smem_thr_copy_A); print("\n");
print("tCsA: "); print(tCsA); print("\n");
print("tCrA_copy_view: "); print(tCrA_copy_view); print("\n");
print(smem_thr_copy_B); print("\n");
print("tCsB: "); print(tCsB); print("\n");
print("tCrB_copy_view: "); print(tCrB_copy_view); print("\n");
}
#endif
//
// PREFETCH
//
@ -434,14 +389,6 @@ cooperative_gemm(uint32_t thread_idx,
// Clear accumulators
clear(tCrC);
#if 0
if (thread0()) {
print(" sC: "); print(sC); print("\n");
print(" tCsC: "); print(tCsC); print("\n");
}
#endif
if constexpr (is_constant<true, decltype(compat)>::value) {
detail::cooperative_gemm_no_predication(
thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op

View File

@ -248,15 +248,6 @@ copy(AutoVectorizingCopyWithAssumedAlignment<MaxVecBits> const&,
// Recast
Tensor src_v = recast<SrcVecType>(src);
Tensor dst_v = recast<DstVecType>(dst);
#if 0
if (thread0()) {
print("copy -- found max_common_vector of %d elems and vectorization to %d bits\n", common_elem, vec_bits);
print(" "); print(src); print(" => "); print(src_v); print("\n");
print(" "); print(dst); print(" => "); print(dst_v); print("\n");
}
#endif
return copy_if(TrivialPredTensor{}, src_v, dst_v);
} else {
return copy_if(TrivialPredTensor{}, src, dst);
@ -374,15 +365,6 @@ copy(Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...> const& atom, // Copy_Traits m
// Construct a new concrete Atom of the vector size
using BulkAtom = Copy_Atom<Copy_Traits<BULK_COPY_OP, Int<vec_bits>, CT_Args...>, SrcType>;
auto bulk_atom = apply(atom.opargs_, [](auto const&... args) { return BulkAtom{args...}; });
#if 0
if (thread0()) {
print("copy blkcp -- found a max_common_layout of "); print(tiler); print("\n");
print(" "); print(src); print("\n");
print(" "); print(dst); print("\n");
}
#endif
return copy(bulk_atom, logical_divide(src, tiler), logical_divide(dst, tiler));
}