Fix a copy error in the SM70 main loop when loading data from smem to rmem (#2540)

Author:    starwang1024
Date:      2025-08-11 10:42:01 +08:00
Committer: GitHub
Parent:    d0eada85a3
Commit:    9e6ab77d27


@@ -214,14 +214,16 @@ struct CollectiveMma<
     // Copy Atom retiling
     //
-    auto thr_copy_A       = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma).get_thread_slice(thread_idx);
-    Tensor tCsA           = thr_copy_A.partition_S(sA);
-    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
+    auto smem_tiled_copy_a = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma);
+    auto thr_copy_A        = smem_tiled_copy_a.get_thread_slice(thread_idx);
+    Tensor tCsA            = thr_copy_A.partition_S(sA);
+    Tensor tCrA_copy_view  = thr_copy_A.retile_D(tCrA);
     CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));            // M
-    auto thr_copy_B       = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma).get_thread_slice(thread_idx);
-    Tensor tCsB           = thr_copy_B.partition_S(sB);
-    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
+    auto smem_tiled_copy_b = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma);
+    auto thr_copy_B        = smem_tiled_copy_b.get_thread_slice(thread_idx);
+    Tensor tCsB            = thr_copy_B.partition_S(sB);
+    Tensor tCrB_copy_view  = thr_copy_B.retile_D(tCrB);
     CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // N
     //

@@ -239,8 +241,8 @@ struct CollectiveMma<
     __syncthreads();
     // Load A, B smem->rmem for k=0
-    copy(tCsA(_,_,0), tCrA_copy_view(_,_,0));
-    copy(tCsB(_,_,0), tCrB_copy_view(_,_,0));
+    copy(smem_tiled_copy_a, tCsA(_,_,0), tCrA_copy_view(_,_,0));
+    copy(smem_tiled_copy_b, tCsB(_,_,0), tCrB_copy_view(_,_,0));
     //
     // Mainloop
     //

@@ -266,8 +268,8 @@ struct CollectiveMma<
       // Load A, B smem->rmem for k+1
       int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX;      // static
-      copy(tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next));
-      copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next));
+      copy(smem_tiled_copy_a, tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next));
+      copy(smem_tiled_copy_b, tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next));
      if (k_block == 0)
      {
        // Copy gmem to rmem
@@ -515,14 +517,16 @@ struct CollectiveMma<
     // Copy Atom retiling
     //
-    auto thr_copy_A       = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma).get_thread_slice(thread_idx);
-    Tensor tCsA           = thr_copy_A.partition_S(sA);
-    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
+    auto smem_tiled_copy_a = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma);
+    auto thr_copy_A        = smem_tiled_copy_a.get_thread_slice(thread_idx);
+    Tensor tCsA            = thr_copy_A.partition_S(sA);
+    Tensor tCrA_copy_view  = thr_copy_A.retile_D(tCrA);
     CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));            // M
-    auto thr_copy_B       = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma).get_thread_slice(thread_idx);
-    Tensor tCsB           = thr_copy_B.partition_S(sB);
-    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
+    auto smem_tiled_copy_b = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma);
+    auto thr_copy_B        = smem_tiled_copy_b.get_thread_slice(thread_idx);
+    Tensor tCsB            = thr_copy_B.partition_S(sB);
+    Tensor tCrB_copy_view  = thr_copy_B.retile_D(tCrB);
     CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // N
     //

@@ -536,8 +540,8 @@ struct CollectiveMma<
     __syncthreads();
     // Load A, B smem->rmem for k=0
-    copy(tCsA(_,_,0), tCrA_copy_view(_,_,0));
-    copy(tCsB(_,_,0), tCrB_copy_view(_,_,0));
+    copy(smem_tiled_copy_a, tCsA(_,_,0), tCrA_copy_view(_,_,0));
+    copy(smem_tiled_copy_b, tCsB(_,_,0), tCrB_copy_view(_,_,0));
     //
     // Mainloop
     //

@@ -563,8 +567,8 @@ struct CollectiveMma<
       // Load A, B smem->rmem for k+1
      int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX;      // static
-      copy(tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next));
-      copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next));
+      copy(smem_tiled_copy_a, tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next));
+      copy(smem_tiled_copy_b, tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next));
      if (k_block == 0)
      {
        if (k_tile_count <= 0) {
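
Why the extra first argument matters: the register views tCrA_copy_view and tCrB_copy_view are produced by retile_D on a TiledCopy built from SmemCopyAtomA / SmemCopyAtomB, so the smem->rmem transfer has to be dispatched through that same TiledCopy. The two-argument cute::copy(src, dst) overload falls back to a default copy atom and may move the data with a different vectorization and value layout than the retiled views and the MMA fragments expect. Below is a minimal sketch of the corrected pattern for the A operand, assuming the surrounding mainloop context from the diff (sA, tCrA, tiled_mma, thread_idx, and k_block are as defined there); it is an illustration of the pattern, not a drop-in excerpt of the file.

    // Build the TiledCopy once and keep a handle to it so copy() can use its atom.
    auto smem_tiled_copy_a = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma);
    auto smem_thr_copy_a   = smem_tiled_copy_a.get_thread_slice(thread_idx);

    Tensor tCsA            = smem_thr_copy_a.partition_S(sA);    // smem source, partitioned per the copy atom
    Tensor tCrA_copy_view  = smem_thr_copy_a.retile_D(tCrA);     // rmem destination, retiled to the atom's value layout

    // Fixed form: dispatch through the TiledCopy so SmemCopyAtomA governs the transfer.
    copy(smem_tiled_copy_a, tCsA(_,_,k_block), tCrA_copy_view(_,_,k_block));

    // The bug being fixed (for contrast): the default-atom overload ignores SmemCopyAtomA.
    // copy(tCsA(_,_,k_block), tCrA_copy_view(_,_,k_block));

The commit applies the same change symmetrically to the B operand and to both mainloop variants in the file.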