diff --git a/include/cutlass/gemm/collective/sm70_mma_twostage.hpp b/include/cutlass/gemm/collective/sm70_mma_twostage.hpp
index a0c8f2a8..a1b6f858 100644
--- a/include/cutlass/gemm/collective/sm70_mma_twostage.hpp
+++ b/include/cutlass/gemm/collective/sm70_mma_twostage.hpp
@@ -214,14 +214,16 @@ struct CollectiveMma<
     // Copy Atom retiling
     //
 
-    auto thr_copy_A       = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma).get_thread_slice(thread_idx);
-    Tensor tCsA           = thr_copy_A.partition_S(sA);
-    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
+    auto smem_tiled_copy_a = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma);
+    auto thr_copy_A        = smem_tiled_copy_a.get_thread_slice(thread_idx);
+    Tensor tCsA            = thr_copy_A.partition_S(sA);
+    Tensor tCrA_copy_view  = thr_copy_A.retile_D(tCrA);
     CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));  // M
 
-    auto thr_copy_B       = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma).get_thread_slice(thread_idx);
-    Tensor tCsB           = thr_copy_B.partition_S(sB);
-    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
+    auto smem_tiled_copy_b = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma);
+    auto thr_copy_B        = smem_tiled_copy_b.get_thread_slice(thread_idx);
+    Tensor tCsB            = thr_copy_B.partition_S(sB);
+    Tensor tCrB_copy_view  = thr_copy_B.retile_D(tCrB);
     CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
 
     //
@@ -239,8 +241,8 @@ struct CollectiveMma<
     __syncthreads();
 
     // Load A, B smem->rmem for k=0
-    copy(tCsA(_,_,0), tCrA_copy_view(_,_,0));
-    copy(tCsB(_,_,0), tCrB_copy_view(_,_,0));
+    copy(smem_tiled_copy_a, tCsA(_,_,0), tCrA_copy_view(_,_,0));
+    copy(smem_tiled_copy_b, tCsB(_,_,0), tCrB_copy_view(_,_,0));
     //
     // Mainloop
     //
@@ -266,8 +268,8 @@ struct CollectiveMma<
 
         // Load A, B smem->rmem for k+1
         int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX;  // static
-        copy(tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next));
-        copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next));
+        copy(smem_tiled_copy_a, tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next));
+        copy(smem_tiled_copy_b, tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next));
         if (k_block == 0)
         {
           // Copy gmem to rmem
@@ -515,14 +517,16 @@ struct CollectiveMma<
     // Copy Atom retiling
     //
 
-    auto thr_copy_A       = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma).get_thread_slice(thread_idx);
-    Tensor tCsA           = thr_copy_A.partition_S(sA);
-    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
+    auto smem_tiled_copy_a = make_tiled_copy_A(SmemCopyAtomA{}, tiled_mma);
+    auto thr_copy_A        = smem_tiled_copy_a.get_thread_slice(thread_idx);
+    Tensor tCsA            = thr_copy_A.partition_S(sA);
+    Tensor tCrA_copy_view  = thr_copy_A.retile_D(tCrA);
     CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));  // M
 
-    auto thr_copy_B       = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma).get_thread_slice(thread_idx);
-    Tensor tCsB           = thr_copy_B.partition_S(sB);
-    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
+    auto smem_tiled_copy_b = make_tiled_copy_B(SmemCopyAtomB{}, tiled_mma);
+    auto thr_copy_B        = smem_tiled_copy_b.get_thread_slice(thread_idx);
+    Tensor tCsB            = thr_copy_B.partition_S(sB);
+    Tensor tCrB_copy_view  = thr_copy_B.retile_D(tCrB);
     CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
 
     //
@@ -536,8 +540,8 @@ struct CollectiveMma<
     __syncthreads();
 
     // Load A, B smem->rmem for k=0
-    copy(tCsA(_,_,0), tCrA_copy_view(_,_,0));
-    copy(tCsB(_,_,0), tCrB_copy_view(_,_,0));
+    copy(smem_tiled_copy_a, tCsA(_,_,0), tCrA_copy_view(_,_,0));
+    copy(smem_tiled_copy_b, tCsB(_,_,0), tCrB_copy_view(_,_,0));
     //
     // Mainloop
     //
@@ -563,8 +567,8 @@ struct CollectiveMma<
 
         // Load A, B smem->rmem for k+1
         int k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX;  // static
-        copy(tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next));
-        copy(tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next));
+        copy(smem_tiled_copy_a, tCsA(_,_,k_block_next), tCrA_copy_view(_,_,k_block_next));
+        copy(smem_tiled_copy_b, tCsB(_,_,k_block_next), tCrB_copy_view(_,_,k_block_next));
         if (k_block == 0)
         {
           if (k_tile_count <= 0) {