Fix a copy error in the SM70 main loop when loading data from smem to rmem (#2540)
This commit is contained in:
@ -214,14 +214,16 @@ struct CollectiveMma<
|
||||
// Copy Atom retiling
|
||||
//
|
||||
|
||||
auto thr_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma).get_thread_slice(thread_idx);
|
||||
Tensor tCsA = thr_copy_A.partition_S(sA);
|
||||
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
|
||||
auto smem_tiled_copy_a = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma);
|
||||
auto thr_copy_A = smem_tiled_copy_a.get_thread_slice(thread_idx);
|
||||
Tensor tCsA = thr_copy_A.partition_S(sA);
|
||||
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
|
||||
|
||||
auto thr_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma).get_thread_slice(thread_idx);
|
||||
Tensor tCsB = thr_copy_B.partition_S(sB);
|
||||
Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
|
||||
auto smem_tiled_copy_b = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma);
|
||||
auto thr_copy_B = smem_tiled_copy_b.get_thread_slice(thread_idx);
|
||||
Tensor tCsB = thr_copy_B.partition_S(sB);
|
||||
Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
|
||||
|
||||
//
|
||||
@ -239,8 +241,8 @@ struct CollectiveMma<
|
||||
__syncthreads();
|
||||
|
||||
// Load A, B smem->rmem for k=0
|
||||
copy(tCsA(_,_,0), tCrA_copy_view(_,_,0));
|
||||
copy(tCsB(_,_,0), tCrB_copy_view(_,_,0));
|
||||
copy(smem_tiled_copy_a, tCsA(_,_,0), tCrA_copy_view(_,_,0));
|
||||
copy(smem_tiled_copy_b, tCsB(_,_,0), tCrB_copy_view(_,_,0));
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
@ -266,8 +268,8 @@ struct CollectiveMma<
|
||||
|
||||
// Load A, B smem->rmem for k+1
|
||||
int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static
|
||||
copy(tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next));
|
||||
copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next));
|
||||
copy(smem_tiled_copy_a, tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next));
|
||||
copy(smem_tiled_copy_b, tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next));
|
||||
if (k_block == 0)
|
||||
{
|
||||
// Copy gmem to rmem
|
||||
@ -515,14 +517,16 @@ struct CollectiveMma<
|
||||
// Copy Atom retiling
|
||||
//
|
||||
|
||||
auto thr_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma).get_thread_slice(thread_idx);
|
||||
Tensor tCsA = thr_copy_A.partition_S(sA);
|
||||
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
|
||||
auto smem_tiled_copy_a = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma);
|
||||
auto thr_copy_A = smem_tiled_copy_a.get_thread_slice(thread_idx);
|
||||
Tensor tCsA = thr_copy_A.partition_S(sA);
|
||||
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
|
||||
|
||||
auto thr_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma).get_thread_slice(thread_idx);
|
||||
Tensor tCsB = thr_copy_B.partition_S(sB);
|
||||
Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
|
||||
auto smem_tiled_copy_b = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma);
|
||||
auto thr_copy_B = smem_tiled_copy_b.get_thread_slice(thread_idx);
|
||||
Tensor tCsB = thr_copy_B.partition_S(sB);
|
||||
Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
|
||||
|
||||
//
|
||||
@ -536,8 +540,8 @@ struct CollectiveMma<
|
||||
__syncthreads();
|
||||
|
||||
// Load A, B smem->rmem for k=0
|
||||
copy(tCsA(_,_,0), tCrA_copy_view(_,_,0));
|
||||
copy(tCsB(_,_,0), tCrB_copy_view(_,_,0));
|
||||
copy(smem_tiled_copy_a, tCsA(_,_,0), tCrA_copy_view(_,_,0));
|
||||
copy(smem_tiled_copy_b, tCsB(_,_,0), tCrB_copy_view(_,_,0));
|
||||
//
|
||||
// Mainloop
|
||||
//
|
||||
@ -563,8 +567,8 @@ struct CollectiveMma<
|
||||
|
||||
// Load A, B smem->rmem for k+1
|
||||
int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static
|
||||
copy(tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next));
|
||||
copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next));
|
||||
copy(smem_tiled_copy_a, tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next));
|
||||
copy(smem_tiled_copy_b, tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next));
|
||||
if (k_block == 0)
|
||||
{
|
||||
if (k_tile_count <= 0) {
|
||||
|
||||
Reference in New Issue
Block a user