@ -39,7 +39,7 @@
|
||||
#include <cute/algorithm/functional.hpp>
|
||||
#include <cute/algorithm/gemm.hpp>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/tensor_impl.hpp>
|
||||
|
||||
namespace cute
|
||||
{
|
||||
@ -76,29 +76,15 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
|
||||
using TypeB = typename TB::value_type;
|
||||
using TypeC = typename TC::value_type;
|
||||
|
||||
// Original, static size of the problem
|
||||
auto M = size<0>(sC);
|
||||
auto N = size<1>(sC);
|
||||
auto K = size<1>(sA);
|
||||
|
||||
// Block size of the compute tile
|
||||
auto BLK_M = tile_size<0>(thr_mma);
|
||||
auto BLK_N = tile_size<1>(thr_mma);
|
||||
auto BLK_K = tile_size<2>(thr_mma);
|
||||
|
||||
//
|
||||
// MMA Partitioning
|
||||
//
|
||||
|
||||
// Round the layout extents up to BLK_X to satisfy MMA partitioning safety
|
||||
Tensor rounded_sA = sA.compose(make_shape(round_up(M, BLK_M), round_up(K, BLK_K)));
|
||||
Tensor rounded_sB = sB.compose(make_shape(round_up(N, BLK_N), round_up(K, BLK_K)));
|
||||
Tensor rounded_sC = sC.compose(make_shape(round_up(M, BLK_M), round_up(N, BLK_N)));
|
||||
// Partition the sA, sB, and sC tiles across the threads for the MMA
|
||||
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
// Partition the sA and sB tiles across the threads for the MMA
|
||||
Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCsB = thr_mma.partition_B(rounded_sB); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tCsC = thr_mma.partition_C(rounded_sC); // (MMA,MMA_M,MMA_N)
|
||||
// Create register tensors for the MMA to operate on
|
||||
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
|
||||
@ -109,9 +95,6 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
|
||||
print(" sA: "); print( sA); print("\n");
|
||||
print(" sB: "); print( sB); print("\n");
|
||||
print(" sC: "); print( sC); print("\n");
|
||||
print("r_sA: "); print(rounded_sA); print("\n");
|
||||
print("r_sB: "); print(rounded_sB); print("\n");
|
||||
print("r_sC: "); print(rounded_sC); print("\n");
|
||||
print(thr_mma);
|
||||
print("tCsA: "); print(tCsA); print("\n");
|
||||
print("tCsB: "); print(tCsB); print("\n");
|
||||
@ -127,8 +110,8 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
|
||||
//
|
||||
|
||||
// Create coordinate tensors for the problem
|
||||
Tensor cA = make_identity_tensor(shape(rounded_sA)); // (M,K) -> (m,k)
|
||||
Tensor cB = make_identity_tensor(shape(rounded_sB)); // (N,K) -> (n,k)
|
||||
Tensor cA = make_identity_tensor(shape(sA)); // (M,K) -> (m,k)
|
||||
Tensor cB = make_identity_tensor(shape(sB)); // (N,K) -> (n,k)
|
||||
|
||||
// Repeat partitioning with thr_mma
|
||||
Tensor tCcA = thr_mma.partition_A(cA); // (MMA,MMA_M,MMA_K) -> (m,k)
|
||||
@ -222,7 +205,7 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
|
||||
//
|
||||
|
||||
// Create coordinate tensors for the problem
|
||||
Tensor cC = make_identity_tensor(shape(rounded_sC)); // (M,N) -> (m,n)
|
||||
Tensor cC = make_identity_tensor(shape(sC)); // (M,N) -> (m,n)
|
||||
// Repeat partitioning with thr_mma
|
||||
Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) -> (m,n)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user