CUTLASS 3.5.1 (#1623)

* CUTLASS 3.5.1

* updates, optimizations, fixes
This commit is contained in:
Vijay Thakkar
2024-07-29 08:46:24 -04:00
committed by GitHub
parent 56b46e2d13
commit be60a0b272
312 changed files with 19793 additions and 6775 deletions

View File

@ -39,7 +39,7 @@
#include <cute/algorithm/functional.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cute/tensor.hpp>
#include <cute/tensor_impl.hpp>
namespace cute
{
@ -76,29 +76,15 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
using TypeB = typename TB::value_type;
using TypeC = typename TC::value_type;
// Original, static size of the problem
auto M = size<0>(sC);
auto N = size<1>(sC);
auto K = size<1>(sA);
// Block size of the compute tile
auto BLK_M = tile_size<0>(thr_mma);
auto BLK_N = tile_size<1>(thr_mma);
auto BLK_K = tile_size<2>(thr_mma);
//
// MMA Partitioning
//
// Round the layout extents up to BLK_X to satisfy MMA partitioning safety
Tensor rounded_sA = sA.compose(make_shape(round_up(M, BLK_M), round_up(K, BLK_K)));
Tensor rounded_sB = sB.compose(make_shape(round_up(N, BLK_N), round_up(K, BLK_K)));
Tensor rounded_sC = sC.compose(make_shape(round_up(M, BLK_M), round_up(N, BLK_N)));
// Partition the sA, sB, and sC tiles across the threads for the MMA
Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K)
Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N)
// Partition the sA and sB tiles across the threads for the MMA
Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K)
Tensor tCsB = thr_mma.partition_B(rounded_sB); // (MMA,MMA_N,MMA_K)
Tensor tCsC = thr_mma.partition_C(rounded_sC); // (MMA,MMA_M,MMA_N)
// Create register tensors for the MMA to operate on
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
@ -109,9 +95,6 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
print(" sA: "); print( sA); print("\n");
print(" sB: "); print( sB); print("\n");
print(" sC: "); print( sC); print("\n");
print("r_sA: "); print(rounded_sA); print("\n");
print("r_sB: "); print(rounded_sB); print("\n");
print("r_sC: "); print(rounded_sC); print("\n");
print(thr_mma);
print("tCsA: "); print(tCsA); print("\n");
print("tCsB: "); print(tCsB); print("\n");
@ -127,8 +110,8 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
//
// Create coordinate tensors for the problem
Tensor cA = make_identity_tensor(shape(rounded_sA)); // (M,K) -> (m,k)
Tensor cB = make_identity_tensor(shape(rounded_sB)); // (N,K) -> (n,k)
Tensor cA = make_identity_tensor(shape(sA)); // (M,K) -> (m,k)
Tensor cB = make_identity_tensor(shape(sB)); // (N,K) -> (n,k)
// Repeat partitioning with thr_mma
Tensor tCcA = thr_mma.partition_A(cA); // (MMA,MMA_M,MMA_K) -> (m,k)
@ -222,7 +205,7 @@ cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
//
// Create coordinate tensors for the problem
Tensor cC = make_identity_tensor(shape(rounded_sC)); // (M,N) -> (m,n)
Tensor cC = make_identity_tensor(shape(sC)); // (M,N) -> (m,n)
// Repeat partitioning with thr_mma
Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) -> (m,n)