Fix a copy error in the SM70 main loop when loading data from smem to rmem (#2540)
This commit is contained in:
@ -214,14 +214,16 @@ struct CollectiveMma<
|
|||||||
// Copy Atom retiling
|
// Copy Atom retiling
|
||||||
//
|
//
|
||||||
|
|
||||||
auto thr_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma).get_thread_slice(thread_idx);
|
auto smem_tiled_copy_a = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma);
|
||||||
Tensor tCsA = thr_copy_A.partition_S(sA);
|
auto thr_copy_A = smem_tiled_copy_a.get_thread_slice(thread_idx);
|
||||||
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
|
Tensor tCsA = thr_copy_A.partition_S(sA);
|
||||||
|
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
|
||||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
|
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
|
||||||
|
|
||||||
auto thr_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma).get_thread_slice(thread_idx);
|
auto smem_tiled_copy_b = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma);
|
||||||
Tensor tCsB = thr_copy_B.partition_S(sB);
|
auto thr_copy_B = smem_tiled_copy_b.get_thread_slice(thread_idx);
|
||||||
Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
|
Tensor tCsB = thr_copy_B.partition_S(sB);
|
||||||
|
Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
|
||||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
|
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -239,8 +241,8 @@ struct CollectiveMma<
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// Load A, B smem->rmem for k=0
|
// Load A, B smem->rmem for k=0
|
||||||
copy(tCsA(_,_,0), tCrA_copy_view(_,_,0));
|
copy(smem_tiled_copy_a, tCsA(_,_,0), tCrA_copy_view(_,_,0));
|
||||||
copy(tCsB(_,_,0), tCrB_copy_view(_,_,0));
|
copy(smem_tiled_copy_b, tCsB(_,_,0), tCrB_copy_view(_,_,0));
|
||||||
//
|
//
|
||||||
// Mainloop
|
// Mainloop
|
||||||
//
|
//
|
||||||
@ -266,8 +268,8 @@ struct CollectiveMma<
|
|||||||
|
|
||||||
// Load A, B smem->rmem for k+1
|
// Load A, B smem->rmem for k+1
|
||||||
int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static
|
int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static
|
||||||
copy(tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next));
|
copy(smem_tiled_copy_a, tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next));
|
||||||
copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next));
|
copy(smem_tiled_copy_b, tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next));
|
||||||
if (k_block == 0)
|
if (k_block == 0)
|
||||||
{
|
{
|
||||||
// Copy gmem to rmem
|
// Copy gmem to rmem
|
||||||
@ -515,14 +517,16 @@ struct CollectiveMma<
|
|||||||
// Copy Atom retiling
|
// Copy Atom retiling
|
||||||
//
|
//
|
||||||
|
|
||||||
auto thr_copy_A = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma).get_thread_slice(thread_idx);
|
auto smem_tiled_copy_a = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma);
|
||||||
Tensor tCsA = thr_copy_A.partition_S(sA);
|
auto thr_copy_A = smem_tiled_copy_a.get_thread_slice(thread_idx);
|
||||||
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
|
Tensor tCsA = thr_copy_A.partition_S(sA);
|
||||||
|
Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
|
||||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
|
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
|
||||||
|
|
||||||
auto thr_copy_B = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma).get_thread_slice(thread_idx);
|
auto smem_tiled_copy_b = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma);
|
||||||
Tensor tCsB = thr_copy_B.partition_S(sB);
|
auto thr_copy_B = smem_tiled_copy_b.get_thread_slice(thread_idx);
|
||||||
Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
|
Tensor tCsB = thr_copy_B.partition_S(sB);
|
||||||
|
Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
|
||||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
|
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -536,8 +540,8 @@ struct CollectiveMma<
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// Load A, B smem->rmem for k=0
|
// Load A, B smem->rmem for k=0
|
||||||
copy(tCsA(_,_,0), tCrA_copy_view(_,_,0));
|
copy(smem_tiled_copy_a, tCsA(_,_,0), tCrA_copy_view(_,_,0));
|
||||||
copy(tCsB(_,_,0), tCrB_copy_view(_,_,0));
|
copy(smem_tiled_copy_b, tCsB(_,_,0), tCrB_copy_view(_,_,0));
|
||||||
//
|
//
|
||||||
// Mainloop
|
// Mainloop
|
||||||
//
|
//
|
||||||
@ -563,8 +567,8 @@ struct CollectiveMma<
|
|||||||
|
|
||||||
// Load A, B smem->rmem for k+1
|
// Load A, B smem->rmem for k+1
|
||||||
int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static
|
int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static
|
||||||
copy(tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next));
|
copy(smem_tiled_copy_a, tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next));
|
||||||
copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next));
|
copy(smem_tiled_copy_b, tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next));
|
||||||
if (k_block == 0)
|
if (k_block == 0)
|
||||||
{
|
{
|
||||||
if (k_tile_count <= 0) {
|
if (k_tile_count <= 0) {
|
||||||
|
|||||||
Reference in New Issue
Block a user