v3.8.0 update (#2082)
* 3.8 update * fix Markus' name --------- Co-authored-by: yuzhai <yuzhai@nvidia.com>
This commit is contained in:
@ -171,20 +171,6 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
|
||||
// Create register tensors for the MMA to operate on
|
||||
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print(" sA: "); print( sA); print("\n");
|
||||
print(" sB: "); print( sB); print("\n");
|
||||
print(thr_mma);
|
||||
print("tCsA: "); print(tCsA); print("\n");
|
||||
print("tCsB: "); print(tCsB); print("\n");
|
||||
print("tCrA: "); print(tCrA); print("\n");
|
||||
print("tCrB: "); print(tCrB); print("\n");
|
||||
print("tCrC: "); print(tCrC); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
//
|
||||
// PREDICATION
|
||||
//
|
||||
@ -200,7 +186,6 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
|
||||
// Allocate the preds for MMA- and MMA_MN-modes
|
||||
Tensor tCpA = make_tensor<bool>(make_shape(size<0>(tCsA), size<1>(tCsA)));
|
||||
Tensor tCpB = make_tensor<bool>(make_shape(size<0>(tCsB), size<1>(tCsB)));
|
||||
|
||||
// Populate the predicates on M and N
|
||||
CUTE_UNROLL
|
||||
for (int i = 0; i < size(tCpA); ++i) {
|
||||
@ -210,18 +195,6 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
|
||||
for (int i = 0; i < size(tCpB); ++i) {
|
||||
tCpB(i) = elem_less(get<0>(tCcB(_,_,Int<0>{})(i)), shape<0>(sB));
|
||||
}
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print(" cA: "); print( cA); print("\n");
|
||||
print(" cB: "); print( cB); print("\n");
|
||||
print("tCcA: "); print(tCcA); print("\n");
|
||||
print("tCcB: "); print(tCcB); print("\n");
|
||||
print_tensor(tCpA);
|
||||
print_tensor(tCpB);
|
||||
}
|
||||
#endif
|
||||
|
||||
//
|
||||
// PREFETCH k_block = 0
|
||||
// Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block
|
||||
@ -330,24 +303,6 @@ cooperative_gemm_no_predication(uint32_t thread_idx,
|
||||
Tensor tCrBi_copy_view = smem_thr_copy_B.retile_D(tCrBi);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrBi_copy_view)); // CPY_N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrBi_copy_view)); // CPY_K
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print(" sA: "); print(sA); print("\n");
|
||||
print(" sB: "); print(sB); print("\n");
|
||||
print(thr_mma); print("\n");
|
||||
print("tCrA: "); print(tCrA); print("\n");
|
||||
print("tCrB: "); print(tCrB); print("\n");
|
||||
print("tCrC: "); print(tCrC); print("\n");
|
||||
print(smem_thr_copy_A); print("\n");
|
||||
print("tCsA: "); print(tCsA); print("\n");
|
||||
print("tCrA_copy_view: "); print(tCrA_copy_view); print("\n");
|
||||
print(smem_thr_copy_B); print("\n");
|
||||
print("tCsB: "); print(tCsB); print("\n");
|
||||
print("tCrB_copy_view: "); print(tCrB_copy_view); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
//
|
||||
// PREFETCH
|
||||
//
|
||||
@ -434,14 +389,6 @@ cooperative_gemm(uint32_t thread_idx,
|
||||
|
||||
// Clear accumulators
|
||||
clear(tCrC);
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print(" sC: "); print(sC); print("\n");
|
||||
print(" tCsC: "); print(tCsC); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
if constexpr (is_constant<true, decltype(compat)>::value) {
|
||||
detail::cooperative_gemm_no_predication(
|
||||
thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op
|
||||
|
||||
@ -248,15 +248,6 @@ copy(AutoVectorizingCopyWithAssumedAlignment<MaxVecBits> const&,
|
||||
// Recast
|
||||
Tensor src_v = recast<SrcVecType>(src);
|
||||
Tensor dst_v = recast<DstVecType>(dst);
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print("copy -- found max_common_vector of %d elems and vectorization to %d bits\n", common_elem, vec_bits);
|
||||
print(" "); print(src); print(" => "); print(src_v); print("\n");
|
||||
print(" "); print(dst); print(" => "); print(dst_v); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
return copy_if(TrivialPredTensor{}, src_v, dst_v);
|
||||
} else {
|
||||
return copy_if(TrivialPredTensor{}, src, dst);
|
||||
@ -374,15 +365,6 @@ copy(Copy_Traits<SM90_BULK_COPY_AUTO, CT_Args...> const& atom, // Copy_Traits m
|
||||
// Construct a new concrete Atom of the vector size
|
||||
using BulkAtom = Copy_Atom<Copy_Traits<BULK_COPY_OP, Int<vec_bits>, CT_Args...>, SrcType>;
|
||||
auto bulk_atom = apply(atom.opargs_, [](auto const&... args) { return BulkAtom{args...}; });
|
||||
|
||||
#if 0
|
||||
if (thread0()) {
|
||||
print("copy blkcp -- found a max_common_layout of "); print(tiler); print("\n");
|
||||
print(" "); print(src); print("\n");
|
||||
print(" "); print(dst); print("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
return copy(bulk_atom, logical_divide(src, tiler), logical_divide(dst, tiler));
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user