CUTLASS 3.5.0 (#1411)
examples/cute/tutorial/CMakeLists.txt
@@ -29,8 +29,23 @@
 cutlass_example_add_executable(
-  sgemm_nt_1
-  sgemm_nt_1.cu
+  sgemm_1
+  sgemm_1.cu
+)
+
+cutlass_example_add_executable(
+  sgemm_2
+  sgemm_2.cu
+)
+
+cutlass_example_add_executable(
+  sgemm_sm70
+  sgemm_sm70.cu
+)
+
+cutlass_example_add_executable(
+  sgemm_sm80
+  sgemm_sm80.cu
 )

 cutlass_example_add_executable(
examples/cute/tutorial/sgemm_1.cu (new file, 469 lines)
@@ -0,0 +1,469 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#include <cstdlib>
#include <cstdio>
#include <cassert>

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

#include <cute/tensor.hpp>

#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"

template <class ProblemShape, class CtaTiler,
          class TA, class AStride, class ASmemLayout, class AThreadLayout,
          class TB, class BStride, class BSmemLayout, class BThreadLayout,
          class TC, class CStride, class CSmemLayout, class CThreadLayout,
          class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(CThreadLayout{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
            TA const* A, AStride dA, ASmemLayout sA_layout, AThreadLayout tA,
            TB const* B, BStride dB, BSmemLayout sB_layout, BThreadLayout tB,
            TC      * C, CStride dC, CSmemLayout          , CThreadLayout tC,
            Alpha alpha, Beta beta)
{
  using namespace cute;

  // Preconditions
  CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{});                   // (M, N, K)
  CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{});                   // (BLK_M, BLK_N, BLK_K)

  static_assert(is_static<AThreadLayout>::value);
  static_assert(is_static<BThreadLayout>::value);
  static_assert(is_static<CThreadLayout>::value);

  CUTE_STATIC_ASSERT_V(size(tA) == size(tB));                          // NumThreads
  CUTE_STATIC_ASSERT_V(size(tC) == size(tA));                          // NumThreads

  CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tA) == Int<0>{});  // BLK_M / THR_M
  CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tA) == Int<0>{});  // BLK_K / THR_K
  CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<0>(tB) == Int<0>{});  // BLK_N / THR_N
  CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tB) == Int<0>{});  // BLK_K / THR_K
  CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tC) == Int<0>{});  // BLK_M / THR_M
  CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<1>(tC) == Int<0>{});  // BLK_N / THR_N

  static_assert(is_static<ASmemLayout>::value);
  static_assert(is_static<BSmemLayout>::value);
  static_assert(is_static<CSmemLayout>::value);

  CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler));  // BLK_K
  CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler));  // BLK_K

  CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA));         // dA strides for shape MK
  CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB));         // dB strides for shape NK
  CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC));         // dC strides for shape MN

  //
  // Full and Tiled Tensors
  //

  // Represent the full tensors
  Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
  Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
  Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)

  // Get the appropriate blocks for this thread block
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);              // (m,n,k)
  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)

  // Shared memory buffers
  __shared__ TA smemA[cosize_v<ASmemLayout>];
  __shared__ TB smemB[cosize_v<BSmemLayout>];
  Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);            // (BLK_M,BLK_K)
  Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);            // (BLK_N,BLK_K)

  //
  // Partition the copying of A and B tiles across the threads
  //

  // TUTORIAL: Example of simple raked partitioning of ThreadLayouts tA|tB over data A|B tiles

  Tensor tAgA = local_partition(gA, tA, threadIdx.x);                  // (THR_M,THR_K,k)
  Tensor tAsA = local_partition(sA, tA, threadIdx.x);                  // (THR_M,THR_K)

  Tensor tBgB = local_partition(gB, tB, threadIdx.x);                  // (THR_N,THR_K,k)
  Tensor tBsB = local_partition(sB, tB, threadIdx.x);                  // (THR_N,THR_K)

  CUTE_STATIC_ASSERT_V(size<0>(tAgA) == size<0>(tAsA));                // THR_M
  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA));                // THR_K
  CUTE_STATIC_ASSERT_V(size<0>(tBgB) == size<0>(tBsB));                // THR_N
  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB));                // THR_K

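  // [Illustrative aside, added by the editor; not part of the original file]
  // For the default tA = (32,8) m-major thread layout over a 128x8 gA tile,
  // each thread's (THR_M,THR_K) slice is (4,1): thread t (m-coord t%32,
  // k-coord t/32) owns rows {t%32, t%32+32, t%32+64, t%32+96} of one column.
  // The "raked" pattern strides threads across the tile rather than giving
  // each thread a contiguous 4-row block.
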
  //
  // Define A/B partitioning and C accumulators
  //

  // TUTORIAL: Example of partitioning via projections of a ThreadLayout tC

  // Partition sA (M,K) by the rows of tC
  Tensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{});   // (THR_M,BLK_K)
  // Partition sB (N,K) by the cols of tC
  Tensor tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{});   // (THR_N,BLK_K)
  // Partition gC (M,N) by the tile of tC
  Tensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{});   // (THR_M,THR_N)

  // Allocate the accumulators -- same shape/layout as the partitioned data
  Tensor tCrC = make_tensor_like(tCgC);                                // (THR_M,THR_N)

  CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCgC));                // THR_M
  CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCsA));                // THR_M
  CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<1>(tCgC));                // THR_N
  CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<0>(tCsB));                // THR_N
  CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCsB));                // BLK_K

  // Clear the accumulators
  clear(tCrC);

#if 0
  if(thread0()) {
    print("  mA : "); print(  mA); print("\n");
    print("  gA : "); print(  gA); print("\n");
    print("  sA : "); print(  sA); print("\n");
    print("tAgA : "); print(tAgA); print("\n");
    print("tAsA : "); print(tAsA); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print("  mB : "); print(  mB); print("\n");
    print("  gB : "); print(  gB); print("\n");
    print("  sB : "); print(  sB); print("\n");
    print("tBgB : "); print(tBgB); print("\n");
    print("tBsB : "); print(tBsB); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print("  mC : "); print(  mC); print("\n");
    print("  gC : "); print(  gC); print("\n");
    print("tCsA : "); print(tCsA); print("\n");
    print("tCsB : "); print(tCsB); print("\n");
    print("tCgC : "); print(tCgC); print("\n");
    print("tCrC : "); print(tCrC); print("\n");
  }
#endif

#if 1

  // TUTORIAL: Example of a simple mainloop that reads tiles of data into shared memory,
  //           and then computes on those tiles.
  //   copy(.) operates on the global and shared memory via the tA|tB partitioning
  //   gemm(.) operates on the shared and register memory via the tC partitioning

  auto K_TILE_MAX = size<2>(tAgA);

  for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
  {
    // Copy gmem to smem with tA|tB thread-partitioned tensors
    copy(tAgA(_,_,k_tile), tAsA);      // A   (THR_M,THR_K) -> (THR_M,THR_K)
    copy(tBgB(_,_,k_tile), tBsB);      // B   (THR_N,THR_K) -> (THR_N,THR_K)

    // TUTORIAL: The above call to copy(tAgA(_,_,k_tile), tAsA) is equivalent to
    //   Tensor tAgAk = tAgA(_,_,k_tile);
    //   CUTE_UNROLL
    //   for (int i = 0; i < size(tAsA); ++i) {
    //     tAsA(i) = tAgAk(i);
    //   }

    cp_async_fence();        // Label the end of (potential) cp.async instructions
    cp_async_wait<0>();      // Sync on all (potential) cp.async instructions
    __syncthreads();         // Wait for all threads to write to smem

    // Compute gemm on tC thread-partitioned smem
    gemm(tCsA, tCsB, tCrC);            // (THR_M,THR_N) += (THR_M,BLK_K) * (THR_N,BLK_K)

    // TUTORIAL: The above call to gemm(tCsA, tCsB, tCrC) is equivalent to
    //   CUTE_UNROLL
    //   for (int k = 0; k < size<1>(tCsA); ++k) {
    //     CUTE_UNROLL
    //     for (int m = 0; m < size<0>(tCrC); ++m) {
    //       CUTE_UNROLL
    //       for (int n = 0; n < size<1>(tCrC); ++n) {
    //         tCrC(m,n) += tCsA(m,k) * tCsB(n,k);
    //       }
    //     }
    //   }

    __syncthreads();         // Wait for all threads to read from smem
  }

#endif

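  // [Illustrative aside, added by the editor; not part of the original file]
  // The fence/wait pair above matters when copy(.) lowers to cp.async:
  // __syncthreads() alone does not order asynchronous copies, so
  // cp_async_fence() commits the outstanding group and cp_async_wait<0>()
  // blocks until all groups complete (see the NOTE retained in the deleted
  // sgemm_nt_1.cu below).
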
  //
  // Epilogue
  //

  axpby(alpha, tCrC, beta, tCgC);

  // TUTORIAL: The above call to axpby(alpha, tCrC, beta, tCgC) is equivalent to
  //   CUTE_UNROLL
  //   for (int i = 0; i < size(tCrC); ++i) {
  //     tCgC(i) = alpha * tCrC(i) + beta * tCgC(i);
  //   }
}

// Setup params for an NT GEMM
// Use m-major smem sA, n-major smem sB, and mn-major threads tA|tB
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define NT strides (mixed)
  auto dA = make_stride(Int<1>{}, ldA);                      // (dM, dK)
  auto dB = make_stride(Int<1>{}, ldB);                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)

  // Define the smem layouts (static)
  auto sA = make_layout(make_shape(bM, bK));                 // (m,k) -> smem_idx; m-major
  auto sB = make_layout(make_shape(bN, bK));                 // (n,k) -> smem_idx; n-major
  auto sC = make_layout(make_shape(bM, bN));                 // (m,n) -> smem_idx; m-major

  // Define the thread layouts (static)
  auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));   // (m,k) -> thr_idx
  auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));   // (n,k) -> thr_idx
  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));   // (m,n) -> thr_idx

  dim3 dimBlock(size(tC));
  dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, tA,
       B, dB, sB, tB,
       C, dC, sC, tC,
       alpha, beta);
}

// Setup params for a TN GEMM
// Use k-major smem sA, k-major smem sB, and k-major threads tA|tB
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define TN strides (mixed)
  auto dA = make_stride(ldA, Int<1>{});                      // (dM, dK)
  auto dB = make_stride(ldB, Int<1>{});                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)

  // Define the smem layouts (static)
  auto sA = make_layout(make_shape(bM,bK), LayoutRight{});   // (m,k) -> smem_idx; k-major
  auto sB = make_layout(make_shape(bN,bK), LayoutRight{});   // (n,k) -> smem_idx; k-major
  auto sC = make_layout(make_shape(bM,bN));                  // (m,n) -> smem_idx; m-major

  // Define the thread layouts (static)
  auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{});  // (m,k) -> thr_idx; k-major
  auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{});  // (n,k) -> thr_idx; k-major
  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));                 // (m,n) -> thr_idx; m-major

  dim3 dimBlock(size(tC));
  dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, tA,
       B, dB, sB, tB,
       C, dC, sC, tC,
       alpha, beta);
}

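// [Illustrative aside, added by the editor; not part of the original file]
// The only functional difference between gemm_nt and gemm_tn above is the
// layout each assumes for A and B:
//   NT: dA = (1, ldA)  -- A is m-major (BLAS 'N', column-major M x K)
//   TN: dA = (ldA, 1)  -- A is k-major (BLAS 'T', i.e. A^T stored K x M)
// The smem and thread layouts are chosen to match, so consecutive threads
// still touch consecutive gmem addresses in both cases.
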
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
     Alpha alpha,
     TA const* A, int ldA,
     TB const* B, int ldB,
     Beta beta,
     TC      * C, int ldC,
     cudaStream_t stream = 0)
{
  if (transA == 'N' && transB == 'T') {
    return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  } else
  if (transA == 'T' && transB == 'N') {
    return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  }
  assert(false && "Not implemented");
}


int main(int argc, char** argv)
{
  int m = 5120;
  if (argc >= 2)
    sscanf(argv[1], "%d", &m);

  int n = 5120;
  if (argc >= 3)
    sscanf(argv[2], "%d", &n);

  int k = 4096;
  if (argc >= 4)
    sscanf(argv[3], "%d", &k);

  char transA = 'N';
  if (argc >= 5)
    sscanf(argv[4], "%c", &transA);

  char transB = 'T';
  if (argc >= 6)
    sscanf(argv[5], "%c", &transB);

  using TA = float;
  using TB = float;
  using TC = float;
  using TI = float;

  TI alpha = 1.0;
  TI beta  = 0.0;

  std::cout << "M = " << m << std::endl;
  std::cout << "N = " << n << std::endl;
  std::cout << "K = " << k << std::endl;
  std::cout << "C = A^" << transA << " B^" << transB << std::endl;

  cute::device_init(0);

  thrust::host_vector<TA> h_A(m*k);
  thrust::host_vector<TB> h_B(n*k);
  thrust::host_vector<TC> h_C(m*n);

  for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);

  thrust::device_vector<TA> d_A = h_A;
  thrust::device_vector<TB> d_B = h_B;
  thrust::device_vector<TC> d_C = h_C;

  double gflops = (2.0*m*n*k) * 1e-9;

  const int timing_iterations = 100;
  GPU_Clock timer;

  int ldA = 0, ldB = 0, ldC = m;

  if (transA == 'N') {
    ldA = m;
  } else if (transA == 'T') {
    ldA = k;
  } else {
    assert(false);
  }

  if (transB == 'N') {
    ldB = k;
  } else if (transB == 'T') {
    ldB = n;
  } else {
    assert(false);
  }

  // Run once
  d_C = h_C;
  gemm(transA, transB, m, n, k,
       alpha,
       d_A.data().get(), ldA,
       d_B.data().get(), ldB,
       beta,
       d_C.data().get(), ldC);
  CUTE_CHECK_LAST();
  thrust::host_vector<TC> cute_result = d_C;

  // Timing iterations
  timer.start();
  for (int i = 0; i < timing_iterations; ++i) {
    gemm(transA, transB, m, n, k,
         alpha,
         d_A.data().get(), ldA,
         d_B.data().get(), ldB,
         beta,
         d_C.data().get(), ldC);
  }
  double cute_time = timer.seconds() / timing_iterations;
  CUTE_CHECK_LAST();
  printf("CUTE_GEMM:     [%6.1f]GFlop/s  (%6.4f)ms\n", gflops / cute_time, cute_time*1000);

  return 0;
}

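The raked partitioning used by sgemm_1.cu can be inspected on the host, since
CuTe layouts and local_partition are host-callable. The sketch below is an
editorial aside, not part of this commit; the file name partition_demo.cu is
hypothetical and only the CuTe header already used above is assumed.

// partition_demo.cu -- print the slice of an A tile owned by one thread
#include <cute/tensor.hpp>

int main() {
  using namespace cute;
  float buf[128 * 8];                                          // one (BLK_M,BLK_K) tile of A
  Tensor tile = make_tensor(&buf[0],
                            make_layout(make_shape(Int<128>{}, Int<8>{})));  // m-major
  auto   tA   = make_layout(make_shape(Int<32>{}, Int<8>{}));  // 256 threads, m-major
  Tensor t0   = local_partition(tile, tA, 0);                  // slice for thread 0
  print(t0.layout()); print("\n");                             // expect a (4,1) slice, m-stride 32
}
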
examples/cute/tutorial/sgemm_2.cu (new file, 523 lines)
@@ -0,0 +1,523 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#include <cstdlib>
#include <cstdio>
#include <cassert>

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

#include <cute/tensor.hpp>

#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"

template <class ProblemShape, class CtaTiler,
          class TA, class AStride, class ASmemLayout, class TiledCopyA,
          class TB, class BStride, class BSmemLayout, class TiledCopyB,
          class TC, class CStride, class CSmemLayout, class TiledMma,
          class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(TiledMma{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
            TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a,
            TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b,
            TC      * C, CStride dC, CSmemLayout          , TiledMma mma,
            Alpha alpha, Beta beta)
{
  using namespace cute;

  // Preconditions
  CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{});                   // (M, N, K)
  CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{});                   // (BLK_M, BLK_N, BLK_K)

  CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma));                     // NumThreads
  CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma));                     // NumThreads

  static_assert(is_static<ASmemLayout>::value);
  static_assert(is_static<BSmemLayout>::value);
  static_assert(is_static<CSmemLayout>::value);

  CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler));  // BLK_K
  CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler));  // BLK_K

  CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA));         // dA strides for shape MK
  CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB));         // dB strides for shape NK
  CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC));         // dC strides for shape MN

  //
  // Full and Tiled Tensors
  //

  // Represent the full tensors
  Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
  Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
  Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)

  // Get the appropriate blocks for this thread block
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);              // (m,n,k)
  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)

  // Shared memory buffers
  __shared__ TA smemA[cosize_v<ASmemLayout>];
  __shared__ TB smemB[cosize_v<BSmemLayout>];
  Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);            // (BLK_M,BLK_K)
  Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);            // (BLK_N,BLK_K)

  //
  // Partition the copying of A and B tiles across the threads
  //

  // TUTORIAL: Example of partitioning via a TiledCopy

  ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
  Tensor tAgA = thr_copy_a.partition_S(gA);                            // (CPY,CPY_M,CPY_K,k)
  Tensor tAsA = thr_copy_a.partition_D(sA);                            // (CPY,CPY_M,CPY_K)
  // Allocate registers same shape/layout as partitioned data
  Tensor tArA = make_fragment_like(tAsA);                              // (CPY,CPY_M,CPY_K)

  ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
  Tensor tBgB = thr_copy_b.partition_S(gB);                            // (CPY,CPY_N,CPY_K,k)
  Tensor tBsB = thr_copy_b.partition_D(sB);                            // (CPY,CPY_N,CPY_K)
  // Allocate registers same shape/layout as partitioned data
  Tensor tBrB = make_fragment_like(tBsB);                              // (CPY,CPY_N,CPY_K)

  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA));                // CPY_M
  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tArA));                // CPY_M
  CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA));                // CPY_K
  CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tArA));                // CPY_K
  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB));                // CPY_N
  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBrB));                // CPY_N
  CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB));                // CPY_K
  CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBrB));                // CPY_K

  // Copy gmem to rmem for k_tile=0
  copy(copy_a, tAgA(_,_,_,0), tArA);
  copy(copy_b, tBgB(_,_,_,0), tBrB);

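  // [Illustrative aside, added by the editor; not part of the original file]
  // This prefetch primes a two-stage pipeline: in each mainloop iteration
  // below, the registers fetched here are written to smem, the gmem->rmem
  // load for the next k-tile is issued, and only then does the compute run --
  // so global loads for tile k+1 are in flight during the FMAs for tile k.
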
  //
  // Define A/B partitioning and C accumulators
  //

  // TUTORIAL: Example of partitioning via a TiledMMA

  ThrMMA thr_mma = mma.get_slice(threadIdx.x);
  Tensor tCsA = thr_mma.partition_A(sA);                               // (MMA,MMA_M,MMA_K)
  Tensor tCsB = thr_mma.partition_B(sB);                               // (MMA,MMA_N,MMA_K)
  Tensor tCgC = thr_mma.partition_C(gC);                               // (MMA,MMA_M,MMA_N)

  // Allocate the accumulators -- same size as the projected data
  Tensor tCrC = thr_mma.make_fragment_C(tCgC);                         // (MMA,MMA_M,MMA_N)

  CUTE_STATIC_ASSERT_V(  shape(tCrC) ==   shape(tCgC));                // (MMA,MMA_M,MMA_N)
  CUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA));                // MMA_M
  CUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB));                // MMA_N
  CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB));                // MMA_K

  // Clear the accumulators
  clear(tCrC);

#if 0
  if(thread0()) {
    print("  mA : "); print(  mA); print("\n");
    print("  gA : "); print(  gA); print("\n");
    print("  sA : "); print(  sA); print("\n");
    print("tAgA : "); print(tAgA); print("\n");
    print("tAsA : "); print(tAsA); print("\n");
    print("tArA : "); print(tArA); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print("  mB : "); print(  mB); print("\n");
    print("  gB : "); print(  gB); print("\n");
    print("  sB : "); print(  sB); print("\n");
    print("tBgB : "); print(tBgB); print("\n");
    print("tBsB : "); print(tBsB); print("\n");
    print("tBrB : "); print(tBrB); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print("  mC : "); print(  mC); print("\n");
    print("  gC : "); print(  gC); print("\n");
    print("tCsA : "); print(tCsA); print("\n");
    print("tCsB : "); print(tCsB); print("\n");
    print("tCgC : "); print(tCgC); print("\n");
    print("tCrC : "); print(tCrC); print("\n");
  }
#endif

#if 1

  // TUTORIAL: Example of an inner loop that pipelines compute with reads
  //           from global memory by staging through register and shared memory.
  //   Data is read from global to registers, then to shared via the TiledCopy partitions
  //   gemm(.) operates on the shared memory directly via the TiledMMA partitions

  auto K_TILE_MAX = size<3>(tAgA);

  for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
  {
    // Copy rmem to smem with tA|tB thread-partitioned tensors
    __syncthreads();         // Wait for all threads to consume smem
    copy(tArA, tAsA);
    copy(tBrB, tBsB);
    __syncthreads();         // Wait for all threads to produce smem

    // Copy gmem to rmem for k_tile+1 with tA|tB thread-partitioned tensors
    int k_tile_next = (k_tile + 1 < K_TILE_MAX) ? k_tile + 1 : k_tile;
    copy(copy_a, tAgA(_,_,_,k_tile_next), tArA);
    copy(copy_b, tBgB(_,_,_,k_tile_next), tBrB);

    // TUTORIAL: The above call to copy(copy_a, tAgA(_,_,_,k_tile_next), tArA) is equivalent to
    //   CUTE_UNROLL
    //   for (int k = 0; k < size<2>(tArA); ++k) {
    //     CUTE_UNROLL
    //     for (int m = 0; m < size<1>(tArA); ++m) {
    //       copy_a.call(tAgA(_,m,k,k_tile_next), tArA(_,m,k));
    //     }
    //   }

    // Compute gemm on mma-partitioned smem
    gemm(mma, tCsA, tCsB, tCrC);

    // TUTORIAL: The above call to gemm(mma, tCsA, tCsB, tCrC) is equivalent to
    //   CUTE_UNROLL
    //   for (int k = 0; k < size<2>(tCsA); ++k) {
    //     CUTE_UNROLL
    //     for (int m = 0; m < size<1>(tCrC); ++m) {
    //       CUTE_UNROLL
    //       for (int n = 0; n < size<2>(tCrC); ++n) {
    //         mma.call(tCsA(_,m,k), tCsB(_,n,k), tCrC(_,m,n));
    //       }
    //     }
    //   }
  }

#endif

  //
  // Epilogue
  //

  axpby(alpha, tCrC, beta, tCgC);
}

// Setup params for a NT GEMM
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define NT strides (mixed)
  auto dA = make_stride(Int<1>{}, ldA);                      // (dM, dK)
  auto dB = make_stride(Int<1>{}, ldB);                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)

  // Define the smem layouts (static)
  auto sA = make_layout(make_shape(bM, bK));                 // (m,k) -> smem_idx; m-major
  auto sB = make_layout(make_shape(bN, bK));                 // (n,k) -> smem_idx; n-major
  auto sC = make_layout(make_shape(bM, bN));                 // (m,n) -> smem_idx; m-major

  // Define the thread layouts (static)

  // TUTORIAL: Construct TiledCopy with a particular Copy_Atom to use and
  //           define the partitioning pattern to apply.
  // Each thread will (try to) copy 4x1 elements of type TA using 128-bit copy.
  // Use 32x8 of these threads.

  TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, TA>{},
                                    Layout<Shape<_32,_8>>{},  // Thr layout 32x8 m-major
                                    Layout<Shape< _4,_1>>{}); // Val layout  4x1 m-major
  TiledCopy copyB = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, TB>{},
                                    Layout<Shape<_32,_8>>{},  // Thr layout 32x8 n-major
                                    Layout<Shape< _4,_1>>{}); // Val layout  4x1 n-major

  // TUTORIAL: Construct TiledMMA with a particular MMA_Atom to use and
  //           define the partitioning pattern to apply.
  // Use a 1x1x1 FMA on the types TC += TA * TB. Each atom requires a single thread.
  // Reproduce that atom 16x16x1 times (m-major) across threads so that we use 256 threads.

  TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
                                 Layout<Shape<_16,_16,_1>>{});  // 16x16x1 UniversalFMA

#if 0
  print(copyA);
  print(copyB);
  print(mmaC);
#endif

#if 0
  print_latex(copyA);
  print_latex(copyB);
  print_latex(mmaC);
#endif

  dim3 dimBlock(size(mmaC));
  dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, copyA,
       B, dB, sB, copyB,
       C, dC, sC, mmaC,
       alpha, beta);
}

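// [Illustrative aside, added by the editor; not part of the original file]
// Sanity arithmetic for the tilings above: copyA covers (32*4) x (8*1) =
// 128x8 elements of TA per step -- exactly one (BLK_M,BLK_K) tile -- with
// each thread issuing a single 128-bit (4 x float) load. mmaC spreads 16x16
// single-thread FMA atoms over the 128x128 C tile, so each of the 256
// threads accumulates a (128/16) x (128/16) = 8x8 fragment of C.
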
// Setup params for a TN GEMM
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define TN strides (mixed)
  auto dA = make_stride(ldA, Int<1>{});                      // (dM, dK)
  auto dB = make_stride(ldB, Int<1>{});                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)

  // Define the smem layouts (static)
  auto sA = make_layout(make_shape ( bM, bK),
                        make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
  auto sB = make_layout(make_shape ( bN, bK),
                        make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
  auto sC = make_layout(make_shape(bM, bN));                 // (m,n) -> smem_idx

  // TUTORIAL: Construct TiledCopy to define the Copy_Atom to use and the
  //           partitioning pattern to apply.
  // Each thread will copy 1x1 elements of type TA.
  // Use 32x8 of these threads arranged in k-major.

  TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<TA>, TA>{},
                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
                                    Layout<Shape< _1,_1>>{});              // Val layout  1x1
  TiledCopy copyB = make_tiled_copy(Copy_Atom<UniversalCopy<TB>, TB>{},
                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
                                    Layout<Shape< _1,_1>>{});              // Val layout  1x1

  // TUTORIAL: Construct TiledMMA to define the MMA_Atom to use and the
  //           partitioning pattern to apply.
  // Use a 1x1x1 FMA on the types TC += TA * TB. Each atom requires a single thread.
  // Reproduce that atom 16x16x1 times (m-major) across threads so that we use 256 threads.

  TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
                                 Layout<Shape<_16,_16,_1>>{});  // 16x16x1 TiledMMA

#if 0
  print(copyA);
  print(copyB);
  print(mmaC);
#endif

#if 0
  print_latex(copyA);
  print_latex(copyB);
  print_latex(mmaC);
#endif

  dim3 dimBlock(size(mmaC));
  dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, copyA,
       B, dB, sB, copyB,
       C, dC, sC, mmaC,
       alpha, beta);
}

template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
     Alpha alpha,
     TA const* A, int ldA,
     TB const* B, int ldB,
     Beta beta,
     TC      * C, int ldC,
     cudaStream_t stream = 0)
{
  if (transA == 'N' && transB == 'T') {
    return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  } else
  if (transA == 'T' && transB == 'N') {
    return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  }
  assert(false && "Not implemented");
}


int main(int argc, char** argv)
{
  int m = 5120;
  if (argc >= 2)
    sscanf(argv[1], "%d", &m);

  int n = 5120;
  if (argc >= 3)
    sscanf(argv[2], "%d", &n);

  int k = 4096;
  if (argc >= 4)
    sscanf(argv[3], "%d", &k);

  char transA = 'N';
  if (argc >= 5)
    sscanf(argv[4], "%c", &transA);

  char transB = 'T';
  if (argc >= 6)
    sscanf(argv[5], "%c", &transB);

  using TA = float;
  using TB = float;
  using TC = float;
  using TI = float;

  TI alpha = 1.0;
  TI beta  = 0.0;

  std::cout << "M = " << m << std::endl;
  std::cout << "N = " << n << std::endl;
  std::cout << "K = " << k << std::endl;
  std::cout << "C = A^" << transA << " B^" << transB << std::endl;

  cute::device_init(0);

  thrust::host_vector<TA> h_A(m*k);
  thrust::host_vector<TB> h_B(n*k);
  thrust::host_vector<TC> h_C(m*n);

  for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);

  thrust::device_vector<TA> d_A = h_A;
  thrust::device_vector<TB> d_B = h_B;
  thrust::device_vector<TC> d_C = h_C;

  double gflops = (2.0*m*n*k) * 1e-9;

  const int timing_iterations = 100;
  GPU_Clock timer;

  int ldA = 0, ldB = 0, ldC = m;

  if (transA == 'N') {
    ldA = m;
  } else if (transA == 'T') {
    ldA = k;
  } else {
    assert(false);
  }

  if (transB == 'N') {
    ldB = k;
  } else if (transB == 'T') {
    ldB = n;
  } else {
    assert(false);
  }

  // Run once
  d_C = h_C;
  gemm(transA, transB, m, n, k,
       alpha,
       d_A.data().get(), ldA,
       d_B.data().get(), ldB,
       beta,
       d_C.data().get(), ldC);
  CUTE_CHECK_LAST();
  thrust::host_vector<TC> cute_result = d_C;

  // Timing iterations
  timer.start();
  for (int i = 0; i < timing_iterations; ++i) {
    gemm(transA, transB, m, n, k,
         alpha,
         d_A.data().get(), ldA,
         d_B.data().get(), ldB,
         beta,
         d_C.data().get(), ldC);
  }
  double cute_time = timer.seconds() / timing_iterations;
  CUTE_CHECK_LAST();
  printf("CUTE_GEMM:     [%6.1f]GFlop/s  (%6.4f)ms\n", gflops / cute_time, cute_time*1000);

  return 0;
}

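Because TiledCopy and TiledMMA are ordinary host-constructible values, their
partitioning can be inspected without launching a kernel; the tutorial's own
"#if 0" blocks do exactly this. A standalone sketch (an editorial aside, not
part of this commit; inspect.cu is a hypothetical name):

// inspect.cu -- print the copy/MMA tilings used by gemm_nt in sgemm_2.cu
#include <cute/tensor.hpp>

int main() {
  using namespace cute;
  TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, float>{},
                                    Layout<Shape<_32,_8>>{},   // 32x8 threads, m-major
                                    Layout<Shape< _4,_1>>{});  // 4x1 values per thread
  TiledMMA  mmaC  = make_tiled_mma(UniversalFMA<float,float,float>{},
                                   Layout<Shape<_16,_16,_1>>{});
  print(copyA);   // 128x8 tile moved per copy step
  print(mmaC);    // 16x16x1 grid of single-thread FMA atoms
}
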
examples/cute/tutorial/sgemm_nt_1.cu (deleted; this is the old tutorial file removed from CMakeLists above)
@@ -1,426 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

#include <cute/tensor.hpp>

#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
#  include "cutlass/util/cublas_wrappers.hpp"
#endif
#include "cutlass/util/helper_cuda.hpp"

template <class MShape, class NShape, class KShape,
          class TA, class AStride, class ABlockLayout, class AThreadLayout,
          class TB, class BStride, class BBlockLayout, class BThreadLayout,
          class TC, class CStride, class CBlockLayout, class CThreadLayout,
          class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(CThreadLayout{}))::value)
void
gemm_device(MShape M, NShape N, KShape K,
            TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
            TC      * C, CStride dC, CBlockLayout       , CThreadLayout tC,
            Alpha alpha, Beta beta)
{
  using namespace cute;
  using X = Underscore;

  // Preconditions
  CUTE_STATIC_ASSERT(is_static<ABlockLayout>::value);
  CUTE_STATIC_ASSERT(is_static<BBlockLayout>::value);
  CUTE_STATIC_ASSERT(is_static<CBlockLayout>::value);

  CUTE_STATIC_ASSERT(is_static<AThreadLayout>::value);
  CUTE_STATIC_ASSERT(is_static<BThreadLayout>::value);
  CUTE_STATIC_ASSERT(is_static<CThreadLayout>::value);

  CUTE_STATIC_ASSERT_V(size(tA) == size(tC));
  CUTE_STATIC_ASSERT_V(size(tB) == size(tC));

  //CUTE_STATIC_ASSERT_V(shape<0>(blockA) == shape<0>(blockC));      // BLK_M
  //CUTE_STATIC_ASSERT_V(shape<0>(blockB) == shape<1>(blockC));      // BLK_N
  CUTE_STATIC_ASSERT_V(shape<1>(blockA) == shape<1>(blockB));        // BLK_K

  // Shared memory buffers
  __shared__ TA smemA[cosize_v<ABlockLayout>];
  __shared__ TB smemB[cosize_v<BBlockLayout>];
  auto sA = make_tensor(make_smem_ptr(smemA), blockA);               // (BLK_M,BLK_K)
  auto sB = make_tensor(make_smem_ptr(smemB), blockB);               // (BLK_N,BLK_K)

  // Represent the full tensors
  auto mA = make_tensor(make_gmem_ptr(A), make_shape(M,K), dA);      // (M,K)
  auto mB = make_tensor(make_gmem_ptr(B), make_shape(N,K), dB);      // (N,K)
  auto mC = make_tensor(make_gmem_ptr(C), make_shape(M,N), dC);      // (M,N)

  // Get the appropriate blocks for this thread block --
  // potential for thread block locality
  auto blk_shape = make_shape(size<0>(sA), size<0>(sB), size<1>(sB));// (BLK_M,BLK_N,BLK_K)
  auto blk_coord = make_coord(blockIdx.x, blockIdx.y, _);            // (m,n,k)

  auto gA = local_tile(mA, blk_shape, blk_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
  auto gB = local_tile(mB, blk_shape, blk_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
  auto gC = local_tile(mC, blk_shape, blk_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)

  //
  // Partition the copying of A and B tiles across the threads
  //

  // TUTORIAL: Example of simple partitioning of A|B tiles over tA|tB
  //   Default is a raked partition, but can be changed with Step<X,Y> parameter

  auto tAgA = local_partition(gA, tA, threadIdx.x);                  // (THR_M,THR_K,k)
  auto tAsA = local_partition(sA, tA, threadIdx.x);                  // (THR_M,THR_K)

  auto tBgB = local_partition(gB, tB, threadIdx.x);                  // (THR_N,THR_K,k)
  auto tBsB = local_partition(sB, tB, threadIdx.x);                  // (THR_N,THR_K)

  //
  // Define C accumulators and A/B partitioning
  //

  // TUTORIAL: Example of partitioning via projections of tC

  // Partition sA (M,K) by the rows of tC
  auto tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{});   // (THR_M,BLK_K)
  // Partition sB (N,K) by the cols of tC
  auto tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{});   // (THR_N,BLK_K)
  // Partition gC (M,N) by the tile of tC
  auto tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{});   // (THR_M,THR_N)

  // Allocate the accumulators -- same size as the projected data
  auto tCrC = make_fragment_like(tCgC);                              // (THR_M,THR_N)

  // Clear the accumulators
  clear(tCrC);

#if 0
  if(thread0()) {
    print("mA\n");
    print(mA.shape()); print("\n"); print(mA.stride());
    print("\n\ngA\n");
    print(gA.shape()); print("\n"); print(gA.stride());
    print("\n\ntAgA\n");
    print(tAgA.shape()); print("\n"); print(tAgA.stride());
    print("\n\nsA\n");
    print(sA.shape()); print("\n"); print(sA.stride());
    print("\n\ntAsA\n");
    print(tAsA.shape()); print("\n"); print(tAsA.stride());
    print("\n\n");
  }
#endif

#if 0
  if(thread0()) {
    print("mB\n");
    print(mB.shape()); print("\n"); print(mB.stride());
    print("\n\ngB\n");
    print(gB.shape()); print("\n"); print(gB.stride());
    print("\n\ntBgB\n");
    print(tBgB.shape()); print("\n"); print(tBgB.stride());
    print("\n\nsB\n");
    print(sB.shape()); print("\n"); print(sB.stride());
    print("\n\ntBsB\n");
    print(tBsB.shape()); print("\n"); print(tBsB.stride());
    print("\n\n");
  }
#endif

#if 0
  if(thread0()) {
    print("mC\n");
    print(mC.shape()); print("\n"); print(mC.stride());
    print("\n\ngC\n");
    print(gC.shape()); print("\n"); print(gC.stride());
    print("\n\ntCsA\n");
    print(tCsA.shape()); print("\n"); print(tCsA.stride());
    print("\n\ntCsB\n");
    print(tCsB.shape()); print("\n"); print(tCsB.stride());
    print("\n\ntCgC\n");
    print(tCgC.shape()); print("\n"); print(tCgC.stride());
    print("\n\ntCrC\n");
    print(tCrC.shape()); print("\n"); print(tCrC.stride());
    print("\n\n");
  }
#endif

#if 1

  // TUTORIAL: Example of a very simple compute loop
  //   Data is read from global to shared memory via the tA|tB partitioning
  //   gemm(.) operates on the shared memory directly via the tC partitioning

  auto k_max = size<2>(tAgA);

  for (int k = 0; k < k_max; ++k)
  {
    // Copy gmem to smem
    copy(tAgA(_,_,k), tAsA);
    copy(tBgB(_,_,k), tBsB);

    // In case copy uses cp.async, make sure that the cp.async
    // instructions are ordered with respect to other cp.async
    // instructions (fence), then wait on all the outstanding copy
    // operations (wait<0>()).  __syncthreads() alone does not do
    // this.
    //
    // NOTE: cp_async_wait<0>() currently issues cp.async.wait_all.
    // This is equivalent to cp.async.commit_group followed by
    // cp.async.wait_group 0.  This should make the first
    // cp_async_fence() (which also issues cp.async.commit_group)
    // redundant.  The tutorial works as-is, so we'll leave the
    // redundant fence in for now and study its removal later.
    cp_async_fence();
    cp_async_wait<0>();

    __syncthreads();

    // Compute gemm on smem
    gemm(tCsA, tCsB, tCrC);

    __syncthreads();
  }

#endif

  //
  // Epilogue
  //

  axpby(alpha, tCrC, beta, tCgC);
}


template <typename TA, typename TB, typename TC,
          typename Alpha, typename Beta>
void
gemm(int m, int n, int k,
     Alpha alpha,
     TA const* A, int ldA,
     TB const* B, int ldB,
     Beta beta,
     TC      * C, int ldC,
     cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);

  // Define strides (mixed)
  auto dA = make_stride(Int<1>{}, ldA);
  auto dB = make_stride(Int<1>{}, ldB);
  auto dC = make_stride(Int<1>{}, ldC);

  // Define block sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};

  // Define the block layouts (static)
  auto sA = make_layout(make_shape(bM,bK));
  auto sB = make_layout(make_shape(bN,bK));
  auto sC = make_layout(make_shape(bM,bN));

  // Define the thread layouts (static)
  auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));
  auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));
  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));

  dim3 dimBlock(size(tC));
  dim3 dimGrid(ceil_div(size(M), size(bM)),
               ceil_div(size(N), size(bN)));
  gemm_device
      <<< dimGrid, dimBlock, 0, stream >>>
      (M,  N,  K,
       A, dA, sA, tA,
       B, dB, sB, tB,
       C, dC, sC, tC,
       alpha, beta);
}

#include <cstdlib>
#include <cstdio>
#include <cassert>

void test_gemm(int m, int n, int k)
{
  cute::device_init(0);

  std::cout << "M = " << m << std::endl;
  std::cout << "N = " << n << std::endl;
  std::cout << "K = " << k << std::endl;

  using TA = float;
  using TB = float;
  using TC = float;
  using TI = float;

  thrust::host_vector<TA> h_A(m*k);
  thrust::host_vector<TB> h_B(n*k);
  thrust::host_vector<TC> h_C(m*n);

  for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);

  thrust::device_vector<TA> d_A = h_A;
  thrust::device_vector<TB> d_B = h_B;
  thrust::device_vector<TC> d_C = h_C;

  TI alpha = 1.0;
  TI beta  = 0.0;

  double gflops = (2.0*m*n*k) * 1e-9;

  const int timing_iterations = 100;
  GPU_Clock timer;

#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
  //
  // cuBLAS
  //

  cublasHandle_t handle;
  cublasCreate(&handle);

  // Run once
  d_C = h_C;
  blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T,
                     m, n, k,
                     &alpha,
                     d_A.data().get(), m,
                     d_B.data().get(), n,
                     &beta,
                     d_C.data().get(), m);
  CUTE_CHECK_LAST();

  thrust::host_vector<TC> cublas_result = d_C;

  // Timing iterations
  timer.start();
  for (int i = 0; i < timing_iterations; ++i) {
    blam::cublas::gemm(handle, CUBLAS_OP_N, CUBLAS_OP_T,
                       m, n, k,
                       &alpha,
                       d_A.data().get(), m,
                       d_B.data().get(), n,
                       &beta,
                       d_C.data().get(), m);
  }
  double cublas_time = timer.seconds() / timing_iterations;
  CUTE_CHECK_LAST();
  printf("CUBLAS_GEMM:   [%6.1f]GFlop/s  (%6.4f)ms\n", gflops / cublas_time, cublas_time*1000);

#else

  std::cout << "Verification by comparison with cuBLAS is disabled, "
    "either because the CMake option CUTLASS_ENABLE_CUBLAS "
    "was explicitly set to OFF, or because CMake could not find cuBLAS. "
    "If you would like to enable verification with cuBLAS, "
    "please set the CMake option CUTLASS_ENABLE_CUBLAS to ON, "
    "rerun CMake, and recompile this example.\n";

#endif // CUTLASS_ENABLE_CUBLAS

  //
  // CuTe
  //

  // Run once (and check)
  d_C = h_C;
  gemm(m, n, k,
       alpha,
       d_A.data().get(), m,
       d_B.data().get(), n,
       beta,
       d_C.data().get(), m);
  CUTE_CHECK_LAST();
  thrust::host_vector<TC> cute_result = d_C;

  // Timing iterations
  timer.start();
  for (int i = 0; i < timing_iterations; ++i) {
    gemm(m, n, k,
         alpha,
         d_A.data().get(), m,
         d_B.data().get(), n,
         beta,
         d_C.data().get(), m);
  }
  double cute_time = timer.seconds() / timing_iterations;
  CUTE_CHECK_LAST();
  printf("CUTE_GEMM:     [%6.1f]GFlop/s  (%6.4f)ms\n", gflops / cute_time, cute_time*1000);

#if defined(CUTLASS_ENABLE_CUBLAS) && CUTLASS_ENABLE_CUBLAS != 0
  printf("Empirical Perf: %.1f%%\n", (cublas_time / cute_time) * 100);

  auto host_matrix_to_const_column_major_cute_tensor =
    [](const auto& X, int num_rows, int num_cols, int LDX) {
      const auto shape = cute::Shape<int, int>{num_rows, num_cols};
      const auto strides = cute::Stride<int, int>{1, LDX};
      return cute::make_tensor(X.data(), cute::make_layout(shape, strides));
    };

  const auto A_view = host_matrix_to_const_column_major_cute_tensor(h_A, m, k, m);
  // B^T is k x n, so B is n x k.
  const auto B_view = host_matrix_to_const_column_major_cute_tensor(h_B, n, k, n);
  const auto C_computed_view = host_matrix_to_const_column_major_cute_tensor(cute_result, m, n, m);
  const auto C_expected_view = host_matrix_to_const_column_major_cute_tensor(cublas_result, m, n, m);
  print_matrix_multiply_mollified_relative_error("float", A_view, B_view, C_computed_view, C_expected_view);

#endif // CUTLASS_ENABLE_CUBLAS
}


int main(int argc, char** argv)
{
  int m = 5120;
  if (argc >= 2)
    sscanf(argv[1], "%d", &m);

  int n = 5120;
  if (argc >= 3)
    sscanf(argv[2], "%d", &n);

  int k = 4096;
  if (argc >= 4)
    sscanf(argv[3], "%d", &k);

  test_gemm(m, n, k);

  return 0;
}

examples/cute/tutorial/sgemm_sm70.cu (new file, 526 lines)
@@ -0,0 +1,526 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#include <cstdlib>
#include <cstdio>
#include <cassert>

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

#include <cute/tensor.hpp>

#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"

template <class ProblemShape, class CtaTiler,
          class TA, class AStride, class ASmemLayout, class TiledCopyA,
          class TB, class BStride, class BSmemLayout, class TiledCopyB,
          class TC, class CStride, class CSmemLayout, class TiledMma,
          class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(TiledMma{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
            TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a,
            TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b,
            TC      * C, CStride dC, CSmemLayout          , TiledMma mma,
            Alpha alpha, Beta beta)
{
  using namespace cute;

// Preconditions
|
||||
CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K)
|
||||
CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K)
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads
|
||||
CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads
|
||||
|
||||
static_assert(is_static<ASmemLayout>::value);
|
||||
static_assert(is_static<BSmemLayout>::value);
|
||||
static_assert(is_static<CSmemLayout>::value);
|
||||
|
||||
  CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler));  // BLK_K
  CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler));  // BLK_K

  CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA));         // dA strides for shape MK
  CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB));         // dB strides for shape NK
  CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC));         // dC strides for shape MN

  //
  // Full and Tiled Tensors
  //

  // Represent the full tensors
  Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
  Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
  Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)

  // Get the appropriate blocks for this thread block
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);              // (m,n,k)
  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)

  // Shared memory buffers
  __shared__ TA smemA[cosize_v<ASmemLayout>];
  __shared__ TB smemB[cosize_v<BSmemLayout>];
  Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);            // (BLK_M,BLK_K)
  Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);            // (BLK_N,BLK_K)

  //
  // Partition the copying of A and B tiles across the threads
  //

  // TUTORIAL: Example of partitioning via a TiledCopy

  ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
  Tensor tAgA = thr_copy_a.partition_S(gA);                            // (CPY,CPY_M,CPY_K,k)
  Tensor tAsA = thr_copy_a.partition_D(sA);                            // (CPY,CPY_M,CPY_K)
  Tensor tArA = make_fragment_like(tAsA);                              // (CPY,CPY_M,CPY_K)

  ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
  Tensor tBgB = thr_copy_b.partition_S(gB);                            // (CPY,CPY_N,CPY_K,k)
  Tensor tBsB = thr_copy_b.partition_D(sB);                            // (CPY,CPY_N,CPY_K)
  Tensor tBrB = make_fragment_like(tBsB);                              // (CPY,CPY_N,CPY_K)

  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA));                // CPY_M
  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tArA));                // CPY_M
  CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA));                // CPY_K
  CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tArA));                // CPY_K
  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB));                // CPY_N
  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBrB));                // CPY_N
  CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB));                // CPY_K
  CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBrB));                // CPY_K

  // Copy gmem to rmem for k_tile=0
  copy(copy_a, tAgA(_,_,_,0), tArA);
  copy(copy_b, tBgB(_,_,_,0), tBrB);
  //
  // Define A/B partitioning and C accumulators
  //

  // TUTORIAL: Example of partitioning via a TiledMMA

  ThrMMA thr_mma = mma.get_slice(threadIdx.x);
  Tensor tCsA = thr_mma.partition_A(sA);                               // (MMA,MMA_M,MMA_K)
  Tensor tCsB = thr_mma.partition_B(sB);                               // (MMA,MMA_N,MMA_K)
  Tensor tCgC = thr_mma.partition_C(gC);                               // (MMA,MMA_M,MMA_N)

  // Allocate registers for pipelining
  Tensor tCrA = thr_mma.make_fragment_A(tCsA);                         // (MMA,MMA_M,MMA_K)
  Tensor tCrB = thr_mma.make_fragment_B(tCsB);                         // (MMA,MMA_N,MMA_K)
  // Allocate the accumulators -- same size as the projected data
  Tensor tCrC = thr_mma.make_fragment_C(tCgC);                         // (MMA,MMA_M,MMA_N)

  CUTE_STATIC_ASSERT_V(  shape(tCrA) == shape(tCsA));                  // (MMA,MMA_M,MMA_K)
  CUTE_STATIC_ASSERT_V(  shape(tCrB) == shape(tCsB));                  // (MMA,MMA_N,MMA_K)
  CUTE_STATIC_ASSERT_V(  shape(tCrC) == shape(tCgC));                  // (MMA,MMA_M,MMA_N)
  CUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA));                // MMA_M
  CUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB));                // MMA_N
  CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB));                // MMA_K

  // Clear the accumulators
  clear(tCrC);

#if 0
  if(thread0()) {
    print("  mA : "); print(  mA); print("\n");
    print("  gA : "); print(  gA); print("\n");
    print("  sA : "); print(  sA); print("\n");
    print("tAgA : "); print(tAgA); print("\n");
    print("tAsA : "); print(tAsA); print("\n");
    print("tArA : "); print(tArA); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print("  mB : "); print(  mB); print("\n");
    print("  gB : "); print(  gB); print("\n");
    print("  sB : "); print(  sB); print("\n");
    print("tBgB : "); print(tBgB); print("\n");
    print("tBsB : "); print(tBsB); print("\n");
    print("tBrB : "); print(tBrB); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print("  mC : "); print(  mC); print("\n");
    print("  gC : "); print(  gC); print("\n");
    print("tCsA : "); print(tCsA); print("\n");
    print("tCsB : "); print(tCsB); print("\n");
    print("tCgC : "); print(tCgC); print("\n");
    print("tCrC : "); print(tCrC); print("\n");
  }
#endif

#if 1

  // Copy rmem to smem
  copy(tArA, tAsA);
  copy(tBrB, tBsB);
  __syncthreads();

  //
  // PIPELINED MAIN LOOP
  // TUTORIAL: Example of a gemm loop that pipelines shared memory AND register memory
  //   Data is read from global to registers, then to shared via the tA|tB partitions
  //   Data is then copied from shared to registers in multiple waves via the tC partitions
  //     and gemm(.) operates on the current register wave
  //

  // Load A, B shmem->regs for k_block=0
  copy(tCsA(_,_,0), tCrA(_,_,0));
  copy(tCsB(_,_,0), tCrB(_,_,0));
  auto K_TILE_MAX  = size<3>(tAgA);
  auto K_BLOCK_MAX = size<2>(tCrA);

  CUTE_NO_UNROLL
  for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
  {
    // Pipeline the k-mode of the block registers
    CUTE_UNROLL
    for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
    {
      if (k_block == K_BLOCK_MAX - 1)
      {
        // Copy rmem to smem
        __syncthreads();
        copy(tArA, tAsA);
        copy(tBrB, tBsB);
        __syncthreads();
      }

      // Copy smem to rmem for k_block+1
      int k_block_next = (k_block + 1) % K_BLOCK_MAX;
      copy(tCsA(_,_,k_block_next), tCrA(_,_,k_block_next));
      copy(tCsB(_,_,k_block_next), tCrB(_,_,k_block_next));
      if (k_block == 0)
      {
        // Copy gmem to rmem for k_tile+1
        int k_tile_next = (k_tile + 1 < K_TILE_MAX) ? k_tile + 1 : k_tile;
        copy(copy_a, tAgA(_,_,_,k_tile_next), tArA);
        copy(copy_b, tBgB(_,_,_,k_tile_next), tBrB);
      }
      // Thread-level register gemm for k_block
      gemm(mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
    } // k_block
  } // k_tile

#endif

  //
  // Epilogue
  //

  axpby(alpha, tCrC, beta, tCgC);
}
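
// Sketch of one k_tile of the loop above with K_BLOCK_MAX = 4 (added for
// illustration; not part of the original file):
//
//   k_block 0 : smem->rmem load of block 1, gmem->rmem fetch of the next
//               k_tile, FMA on block 0
//   k_block 1 : smem->rmem load of block 2, FMA on block 1
//   k_block 2 : smem->rmem load of block 3, FMA on block 2
//   k_block 3 : rmem->smem store of the fetched tile (between barriers),
//               smem->rmem load of block 0 of the new tile, FMA on block 3
//
// Global loads, shared-memory traffic, and math thus stay in flight together.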

// Setup params for an NT GEMM
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define NT strides (mixed)
  auto dA = make_stride(Int<1>{}, ldA);                      // (dM, dK)
  auto dB = make_stride(Int<1>{}, ldB);                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)

  // Define the smem layouts (static)
  auto sA = make_layout(make_shape(bM, bK));                 // (m,k) -> smem_idx; m-major
  auto sB = make_layout(make_shape(bN, bK));                 // (n,k) -> smem_idx; n-major
  auto sC = make_layout(make_shape(bM, bN));                 // (m,n) -> smem_idx; m-major

  // Define the thread layouts (static)
  TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, TA>{},
                                    Layout<Shape<_32,_8>>{},  // Thr layout 32x8 m-major
                                    Layout<Shape< _4,_1>>{}); // Val layout  4x1 m-major
  TiledCopy copyB = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, TB>{},
                                    Layout<Shape<_32,_8>>{},  // Thr layout 32x8 n-major
                                    Layout<Shape< _4,_1>>{}); // Val layout  4x1 n-major

  TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
                                 Layout<Shape<_16,_16,_1>>{}); // 16x16x1 TiledMMA

#if 0
  print(copyA);
  print(copyB);
  print(mmaC);
#endif

#if 0
  print_latex(copyA);
  print_latex(copyB);
  print_latex(mmaC);
#endif

  dim3 dimBlock(size(mmaC));
  dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, copyA,
       B, dB, sB, copyB,
       C, dC, sC, mmaC,
       alpha, beta);
}
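
// Usage sketch (hypothetical pointers, added for illustration): with
// column-major A (m x k), column-major B (n x k, so B^T is k x n), and
// column-major C (m x n), C = alpha * A * B^T + beta * C is
//
//   gemm_nt(m, n, k, 1.0f, d_A, /*ldA=*/m, d_B, /*ldB=*/n, 0.0f, d_C, /*ldC=*/m);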

// Setup params for a TN GEMM
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define TN strides (mixed)
  auto dA = make_stride(ldA, Int<1>{});                      // (dM, dK)
  auto dB = make_stride(ldB, Int<1>{});                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)

  // Define the smem layouts (static)
  auto sA = make_layout(make_shape ( bM, bK),
                        make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
  auto sB = make_layout(make_shape ( bN, bK),
                        make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
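  // Padding the column stride to bM+1 / bN+1 (rather than a multiple of the
  // 32 shared-memory banks) staggers consecutive columns across banks, so
  // the k-major writes issued by the copies below avoid bank conflicts.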
  auto sC = make_layout(make_shape(bM, bN));                 // (m,n) -> smem_idx

  // Define the thread layouts (static)

  TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<TA>, TA>{},
                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
                                    Layout<Shape< _1,_1>>{});              // Val layout  1x1
  TiledCopy copyB = make_tiled_copy(Copy_Atom<UniversalCopy<TB>, TB>{},
                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
                                    Layout<Shape< _1,_1>>{});              // Val layout  1x1

  TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
                                 Layout<Shape<_16,_16,_1>>{}); // 16x16x1 TiledMMA

#if 0
  print(copyA);
  print(copyB);
  print(mmaC);
#endif

#if 0
  print_latex(copyA);
  print_latex(copyB);
  print_latex(mmaC);
#endif

  dim3 dimBlock(size(mmaC));
  dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, copyA,
       B, dB, sB, copyB,
       C, dC, sC, mmaC,
       alpha, beta);
}

template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
     Alpha alpha,
     TA const* A, int ldA,
     TB const* B, int ldB,
     Beta beta,
     TC      * C, int ldC,
     cudaStream_t stream = 0)
{
  if (transA == 'N' && transB == 'T') {
    return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  } else
  if (transA == 'T' && transB == 'N') {
    return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  }
  assert(false && "Not implemented");
}


int main(int argc, char** argv)
{
  cudaDeviceProp props;
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  if (props.major < 7) {
    std::cout << "This example requires a Volta GPU or newer (CC >= 70)" << std::endl;
    // Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits.
    return 0;
  }

  int m = 5120;
  if (argc >= 2)
    sscanf(argv[1], "%d", &m);

  int n = 5120;
  if (argc >= 3)
    sscanf(argv[2], "%d", &n);

  int k = 4096;
  if (argc >= 4)
    sscanf(argv[3], "%d", &k);

  char transA = 'N';
  if (argc >= 5)
    sscanf(argv[4], "%c", &transA);

  char transB = 'T';
  if (argc >= 6)
    sscanf(argv[5], "%c", &transB);

  using TA = float;
  using TB = float;
  using TC = float;
  using TI = float;

  TI alpha = 1.0;
  TI beta  = 0.0;

  std::cout << "M = " << m << std::endl;
  std::cout << "N = " << n << std::endl;
  std::cout << "K = " << k << std::endl;
  std::cout << "C = A^" << transA << " B^" << transB << std::endl;

  thrust::host_vector<TA> h_A(m*k);
  thrust::host_vector<TB> h_B(n*k);
  thrust::host_vector<TC> h_C(m*n);

  for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);

  thrust::device_vector<TA> d_A = h_A;
  thrust::device_vector<TB> d_B = h_B;
  thrust::device_vector<TC> d_C = h_C;

  double gflops = (2.0*m*n*k) * 1e-9;

  const int timing_iterations = 100;
  GPU_Clock timer;

  int ldA = 0, ldB = 0, ldC = m;

  if (transA == 'N') {
    ldA = m;
  } else if (transA == 'T') {
    ldA = k;
  } else {
    assert(false);
  }
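
  // Leading dimensions follow the BLAS convention: for transA == 'N', A is
  // stored column-major as m x k (ldA = m); for transA == 'T', as k x m
  // (ldA = k). The same logic selects ldB below.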

  if (transB == 'N') {
    ldB = k;
  } else if (transB == 'T') {
    ldB = n;
  } else {
    assert(false);
  }

  // Run once
  d_C = h_C;
  gemm(transA, transB, m, n, k,
       alpha,
       d_A.data().get(), ldA,
       d_B.data().get(), ldB,
       beta,
       d_C.data().get(), ldC);
  CUTE_CHECK_LAST();
  thrust::host_vector<TC> cute_result = d_C;

  // Timing iterations
  timer.start();
  for (int i = 0; i < timing_iterations; ++i) {
    gemm(transA, transB, m, n, k,
         alpha,
         d_A.data().get(), ldA,
         d_B.data().get(), ldB,
         beta,
         d_C.data().get(), ldC);
  }
  double cute_time = timer.seconds() / timing_iterations;
  CUTE_CHECK_LAST();
  printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);

  return 0;
}
567
examples/cute/tutorial/sgemm_sm80.cu
Normal file
@ -0,0 +1,567 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#include <cstdlib>
#include <cstdio>
#include <cassert>

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

#include <cute/tensor.hpp>

#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"

template <class ProblemShape, class CtaTiler,
          class TA, class AStride, class ASmemLayout, class TiledCopyA,
          class TB, class BStride, class BSmemLayout, class TiledCopyB,
          class TC, class CStride, class CSmemLayout, class TiledMma,
          class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(TiledMma{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
            TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a,
            TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b,
            TC      * C, CStride dC, CSmemLayout          , TiledMma mma,
            Alpha alpha, Beta beta)
{
  using namespace cute;

  // Preconditions
  CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{});                   // (M, N, K)
  CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{});                   // (BLK_M, BLK_N, BLK_K)

  CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma));                     // NumThreads
  CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma));                     // NumThreads

  static_assert(is_static<ASmemLayout>::value);
  static_assert(is_static<BSmemLayout>::value);
  static_assert(is_static<CSmemLayout>::value);

  CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler));  // BLK_K
  CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler));  // BLK_K

  CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA));         // dA strides for shape MK
  CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB));         // dB strides for shape NK
  CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC));         // dC strides for shape MN

  //
  // Full and Tiled Tensors
  //

  // Represent the full tensors
  Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
  Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
  Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)

  // Get the appropriate blocks for this thread block
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);              // (m,n,k)
  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)

  // Shared memory buffers
  __shared__ TA smemA[cosize_v<ASmemLayout>];
  __shared__ TB smemB[cosize_v<BSmemLayout>];
  Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);            // (BLK_M,BLK_K,PIPE)
  Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);            // (BLK_N,BLK_K,PIPE)

  //
  // Partition the copying of A and B tiles across the threads
  //

  ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
  Tensor tAgA = thr_copy_a.partition_S(gA);                            // (CPY,CPY_M,CPY_K,k)
  Tensor tAsA = thr_copy_a.partition_D(sA);                            // (CPY,CPY_M,CPY_K,PIPE)

  ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
  Tensor tBgB = thr_copy_b.partition_S(gB);                            // (CPY,CPY_N,CPY_K,k)
  Tensor tBsB = thr_copy_b.partition_D(sB);                            // (CPY,CPY_N,CPY_K,PIPE)

  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA));                // CPY_M
  CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA));                // CPY_K
  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB));                // CPY_N
  CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB));                // CPY_K

  //
  // PREFETCH
  //

  auto K_PIPE_MAX = size<3>(tAsA);

  // Total count of tiles
  int k_tile_count = size<3>(tAgA);
  // Current tile index in gmem to read from
  int k_tile_next = 0;

  // Start async loads for all pipes but the last
  CUTE_UNROLL
  for (int k_pipe = 0; k_pipe < K_PIPE_MAX-1; ++k_pipe) {
    copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,k_pipe));
    copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,k_pipe));
    cp_async_fence();
    --k_tile_count;
    if (k_tile_count > 0) { ++k_tile_next; }
  }

  //
  // Define A/B partitioning and C accumulators
  //

  ThrMMA thr_mma = mma.get_slice(threadIdx.x);
  Tensor tCsA = thr_mma.partition_A(sA);                               // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCsB = thr_mma.partition_B(sB);                               // (MMA,MMA_N,MMA_K,PIPE)
  Tensor tCgC = thr_mma.partition_C(gC);                               // (MMA,MMA_M,MMA_N)

  // Allocate registers for pipelining
  Tensor tCrA = thr_mma.make_fragment_A(tCsA(_,_,_,0));                // (MMA,MMA_M,MMA_K)
  Tensor tCrB = thr_mma.make_fragment_B(tCsB(_,_,_,0));                // (MMA,MMA_N,MMA_K)
  // Allocate the accumulators -- same size as the projected data
  Tensor tCrC = thr_mma.make_fragment_C(tCgC);                         // (MMA,MMA_M,MMA_N)

  CUTE_STATIC_ASSERT_V((shape(tCrA) == take<0,3>(shape(tCsA))));       // (MMA,MMA_M,MMA_K)
  CUTE_STATIC_ASSERT_V((shape(tCrB) == take<0,3>(shape(tCsB))));       // (MMA,MMA_N,MMA_K)
  CUTE_STATIC_ASSERT_V(  shape(tCrC) == shape(tCgC));                  // (MMA,MMA_M,MMA_N)
  CUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA));                // MMA_M
  CUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB));                // MMA_N
  CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB));                // MMA_K

  // Clear the accumulators
  clear(tCrC);

#if 0
  if(thread0()) {
    print("  mA : "); print(  mA); print("\n");
    print("  gA : "); print(  gA); print("\n");
    print("  sA : "); print(  sA); print("\n");
    print("tAgA : "); print(tAgA); print("\n");
    print("tAsA : "); print(tAsA); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print("  mB : "); print(  mB); print("\n");
    print("  gB : "); print(  gB); print("\n");
    print("  sB : "); print(  sB); print("\n");
    print("tBgB : "); print(tBgB); print("\n");
    print("tBsB : "); print(tBsB); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print("  mC : "); print(  mC); print("\n");
    print("  gC : "); print(  gC); print("\n");
    print("tCsA : "); print(tCsA); print("\n");
    print("tCsB : "); print(tCsB); print("\n");
    print("tCgC : "); print(tCgC); print("\n");
    print("tCrA : "); print(tCrA); print("\n");
    print("tCrB : "); print(tCrB); print("\n");
    print("tCrC : "); print(tCrC); print("\n");
  }
#endif

#if 1

  // Current pipe index in smem to read from
  int smem_pipe_read  = 0;
  // Current pipe index in smem to write to
  int smem_pipe_write = K_PIPE_MAX-1;

  // Pipe slice
  Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read);
  Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read);

  // Size of the register pipeline
  auto K_BLOCK_MAX = size<2>(tCrA);

  // PREFETCH register pipeline
  if (K_BLOCK_MAX > 1) {
    // Wait until our first prefetched tile is loaded in
    cp_async_wait<K_PIPE_MAX-2>();
    __syncthreads();

    // Prefetch the first rmem from the first k-tile
    copy(tCsA_p(_,_,Int<0>{}), tCrA(_,_,Int<0>{}));
    copy(tCsB_p(_,_,Int<0>{}), tCrB(_,_,Int<0>{}));
  }
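
  // Note: cp_async_wait<N>() returns once at most N committed cp.async
  // groups remain in flight, so with N = K_PIPE_MAX-2 the oldest outstanding
  // stage -- the one about to be read -- is guaranteed complete.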

  //
  // PIPELINED MAIN LOOP
  // TUTORIAL: Example of a gemm loop that pipelines shared memory using SM80's cp.async instructions
  //   and explicit pipelines in shared memory.
  //   Data is read from global(k_tile_next) to shared(smem_pipe_write).
  //   Data is read from shared(smem_pipe_read) to registers(k_block_next).
  //   Data is computed on registers(k_block).
  //
  //   This allows all copies and compute to overlap:
  //     Copy from gmem->smem can overlap with copies from smem->rmem and compute on rmem.
  //     Copy from smem->rmem can overlap with compute on rmem.
  //

  CUTE_NO_UNROLL
  while (k_tile_count > -(K_PIPE_MAX-1))
  {
    CUTE_UNROLL
    for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
    {
      if (k_block == K_BLOCK_MAX - 1)
      {
        // Slice the smem_pipe_read smem
        tCsA_p = tCsA(_,_,_,smem_pipe_read);
        tCsB_p = tCsB(_,_,_,smem_pipe_read);

        // Commit the smem for smem_pipe_read
        cp_async_wait<K_PIPE_MAX-2>();
        __syncthreads();
      }

      // Load A, B shmem->regs for k_block+1
      auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX;  // static
      copy(tCsA_p(_,_,k_block_next), tCrA(_,_,k_block_next));
      copy(tCsB_p(_,_,k_block_next), tCrB(_,_,k_block_next));
      // Copy gmem to smem before computing gemm on each k-pipe
      if (k_block == 0)
      {
        copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,smem_pipe_write));
        copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,smem_pipe_write));
        cp_async_fence();

        // Advance the gmem tile
        --k_tile_count;
        if (k_tile_count > 0) { ++k_tile_next; }

        // Advance the smem pipe
        smem_pipe_write = smem_pipe_read;
        ++smem_pipe_read;
        smem_pipe_read = (smem_pipe_read == K_PIPE_MAX) ? 0 : smem_pipe_read;
      }
      // Thread-level register gemm for k_block
      gemm(mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
    }
  }

#endif

  //
  // Epilogue
  //

  axpby(alpha, tCrC, beta, tCgC);
}

// Setup params for an NT GEMM
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define NT strides (mixed)
  auto dA = make_stride(Int<1>{}, ldA);                      // (dM, dK)
  auto dB = make_stride(Int<1>{}, ldB);                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
  auto bP = Int<3>{};  // Pipeline
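  // bP = 3 gives a three-stage smem pipeline: while the MMA consumes one
  // stage, up to two asynchronous gmem->smem loads can be in flight.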

  // Define the smem layouts (static)
  auto sA = make_layout(make_shape(bM, bK, bP));             // (m,k,p) -> smem_idx; m-major
  auto sB = make_layout(make_shape(bN, bK, bP));             // (n,k,p) -> smem_idx; n-major
  auto sC = make_layout(make_shape(bM, bN));                 // (m,n) -> smem_idx; m-major

  // Define the thread layouts (static)

  TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TA>{},
                                    Layout<Shape<_32,_8>>{},  // Thr layout 32x8 m-major
                                    Layout<Shape< _4,_1>>{}); // Val layout  4x1 m-major
  TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TB>{},
                                    Layout<Shape<_32,_8>>{},  // Thr layout 32x8 n-major
                                    Layout<Shape< _4,_1>>{}); // Val layout  4x1 n-major

  TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
                                 Layout<Shape<_16,_16,_1>>{}); // 16x16x1 TiledMMA

#if 0
  print(copyA);
  print(copyB);
  print(mmaC);
#endif

#if 0
  print_latex(copyA);
  print_latex(copyB);
  print_latex(mmaC);
#endif

  dim3 dimBlock(size(mmaC));
  dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, copyA,
       B, dB, sB, copyB,
       C, dC, sC, mmaC,
       alpha, beta);
}

// Setup params for a TN GEMM
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define TN strides (mixed)
  auto dA = make_stride(ldA, Int<1>{});                      // (dM, dK)
  auto dB = make_stride(ldB, Int<1>{});                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
  auto bP = Int<3>{};  // Pipeline

  // Define the smem layouts (static)
  auto sA_atom = make_layout(make_shape ( bM, bK),
                             make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
  auto sB_atom = make_layout(make_shape ( bN, bK),
                             make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
  auto sA = tile_to_shape(sA_atom, make_shape(bM, bK, bP));
  auto sB = tile_to_shape(sB_atom, make_shape(bN, bK, bP));
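  // tile_to_shape repeats the padded 2-D atom across the bP pipeline stages,
  // so each (m,k)/(n,k) stage keeps the bank-conflict-free layout.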
  auto sC = make_layout(make_shape(bM, bN));                 // (m,n) -> smem_idx

  // Define the thread layouts (static)

  TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<TA>, TA>{},
                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
                                    Layout<Shape< _1,_1>>{});              // Val layout  1x1
  TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<TB>, TB>{},
                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
                                    Layout<Shape< _1,_1>>{});              // Val layout  1x1

  TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
                                 Layout<Shape<_16,_16,_1>>{}); // 16x16x1 TiledMMA

#if 0
  print(copyA);
  print(copyB);
  print(mmaC);
#endif

#if 0
  print_latex(copyA);
  print_latex(copyB);
  print_latex(mmaC);
#endif

  dim3 dimBlock(size(mmaC));
  dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, copyA,
       B, dB, sB, copyB,
       C, dC, sC, mmaC,
       alpha, beta);
}

template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
     Alpha alpha,
     TA const* A, int ldA,
     TB const* B, int ldB,
     Beta beta,
     TC      * C, int ldC,
     cudaStream_t stream = 0)
{
  if (transA == 'N' && transB == 'T') {
    return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  } else
  if (transA == 'T' && transB == 'N') {
    return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  }
  assert(false && "Not implemented");
}


int main(int argc, char** argv)
{
  cudaDeviceProp props;
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  if (props.major < 8) {
    std::cout << "This example requires an Ampere GPU or newer (CC >= 80)" << std::endl;
    // Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits.
    return 0;
  }

  int m = 5120;
  if (argc >= 2)
    sscanf(argv[1], "%d", &m);

  int n = 5120;
  if (argc >= 3)
    sscanf(argv[2], "%d", &n);

  int k = 4096;
  if (argc >= 4)
    sscanf(argv[3], "%d", &k);

  char transA = 'N';
  if (argc >= 5)
    sscanf(argv[4], "%c", &transA);

  char transB = 'T';
  if (argc >= 6)
    sscanf(argv[5], "%c", &transB);

  using TA = float;
  using TB = float;
  using TC = float;
  using TI = float;

  TI alpha = 1.0;
  TI beta  = 0.0;

  std::cout << "M = " << m << std::endl;
  std::cout << "N = " << n << std::endl;
  std::cout << "K = " << k << std::endl;
  std::cout << "C = A^" << transA << " B^" << transB << std::endl;

  thrust::host_vector<TA> h_A(m*k);
  thrust::host_vector<TB> h_B(n*k);
  thrust::host_vector<TC> h_C(m*n);

  for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);

  thrust::device_vector<TA> d_A = h_A;
  thrust::device_vector<TB> d_B = h_B;
  thrust::device_vector<TC> d_C = h_C;

  double gflops = (2.0*m*n*k) * 1e-9;

  const int timing_iterations = 100;
  GPU_Clock timer;

  int ldA = 0, ldB = 0, ldC = m;

  if (transA == 'N') {
    ldA = m;
  } else if (transA == 'T') {
    ldA = k;
  } else {
    assert(false);
  }

  if (transB == 'N') {
    ldB = k;
  } else if (transB == 'T') {
    ldB = n;
  } else {
    assert(false);
  }

  // Run once
  d_C = h_C;
  gemm(transA, transB, m, n, k,
       alpha,
       d_A.data().get(), ldA,
       d_B.data().get(), ldB,
       beta,
       d_C.data().get(), ldC);
  CUTE_CHECK_LAST();
  thrust::host_vector<TC> cute_result = d_C;

  // Timing iterations
  timer.start();
  for (int i = 0; i < timing_iterations; ++i) {
    gemm(transA, transB, m, n, k,
         alpha,
         d_A.data().get(), ldA,
         d_B.data().get(), ldB,
         beta,
         d_C.data().get(), ldC);
  }
  double cute_time = timer.seconds() / timing_iterations;
  CUTE_CHECK_LAST();
  printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);

  return 0;
}
@ -67,7 +67,7 @@
//
// Uses local_partition() to partition a tile among threads arranged as (THR_M, THR_N).
template <class TensorS, class TensorD, class ThreadLayout>
__global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout)
{
  using namespace cute;

@ -77,12 +77,13 @@ __global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout)

  // Construct a partitioning of the tile among threads with the given thread arrangement.

  // Concept: Tensor Layout Index
  Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{}, threadIdx.x);
  Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{}, threadIdx.x);
  // Concept: Tensor ThrLayout ThrIndex
  Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{}, threadIdx.x);   // (ThrValM, ThrValN)
  Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{}, threadIdx.x);   // (ThrValM, ThrValN)

  // Construct a register-backed Tensor with the same shape as each thread's partition
  auto fragment = make_fragment_like(thr_tile_S);
  // Use make_tensor_like to try to match the layout of thr_tile_S
  Tensor fragment = make_tensor_like(thr_tile_S);                             // (ThrValM, ThrValN)

  // Copy from GMEM to RMEM and from RMEM to GMEM
  copy(thr_tile_S, fragment);
@ -95,17 +96,17 @@ __global__ void copy_kernel_vectorized(TensorS S, TensorD D, ThreadLayout, VecLayout)
/// has the precondition that pointers are aligned to the vector size.
///
template <class TensorS, class TensorD, class ThreadLayout, class VecLayout>
__global__ void copy_kernel_vectorized(TensorS S, TensorD D, ThreadLayout, VecLayout)
{
  using namespace cute;
  using Element = typename TensorS::value_type;

  // Slice the tensors to obtain a view into each tile.
  Tensor tile_S = S(make_coord(_, _), blockIdx.x, blockIdx.y);  // (BlockShape_M, BlockShape_N)
  Tensor tile_D = D(make_coord(_, _), blockIdx.x, blockIdx.y);  // (BlockShape_M, BlockShape_N)

  // Define `AccessType` which controls the size of the actual memory access.
  using AccessType = cutlass::AlignedArray<Element, size(shape(VecLayout{}))>;
  using AccessType = cutlass::AlignedArray<Element, size(VecLayout{})>;

  // A copy atom corresponds to one hardware memory access.
  using Atom = Copy_Atom<UniversalCopy<AccessType>, Element>;
@ -125,29 +126,18 @@ __global__ void copy_kernel_vectorized(TensorS S, TensorD D, ThreadLayout, VecLayout)
  // Construct a Tensor corresponding to each thread's slice.
  auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);

  Tensor thr_tile_S = thr_copy.partition_S(tile_S);
  Tensor thr_tile_D = thr_copy.partition_D(tile_D);
  Tensor thr_tile_S = thr_copy.partition_S(tile_S);             // (CopyOp, CopyM, CopyN)
  Tensor thr_tile_D = thr_copy.partition_D(tile_D);             // (CopyOp, CopyM, CopyN)

  // Construct a register-backed Tensor with the same shape as each thread's partition
  auto fragment = make_fragment_like(thr_tile_D);
  // Use make_fragment_like because the first mode is the instruction-local mode
  Tensor fragment = make_fragment_like(thr_tile_D);             // (CopyOp, CopyM, CopyN)

  // Copy from GMEM to RMEM and from RMEM to GMEM
  copy(tiled_copy, thr_tile_S, fragment);
  copy(tiled_copy, fragment, thr_tile_D);
}

/// Helper to convert a shape to a dim3
template <class Shape>
dim3 shape_to_dim3(Shape shape)
{
  using namespace cute;

  CUTE_STATIC_ASSERT_V(rank(shape) <= Int<3>{});
  auto result = append<3>(product_each(shape), 1u);

  return dim3(get<0>(result), get<1>(result), get<2>(result));
}
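
// For example (illustration only): for a shape (2, (3, 4)), product_each
// gives (2, 12), append<3> pads it to (2, 12, 1), and the helper returns
// dim3(2, 12, 1).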

/// Main function
int main(int argc, char** argv)
{
@ -161,13 +151,13 @@ int main(int argc, char** argv)
  // Define a tensor shape with dynamic extents (m, n)
  auto tensor_shape = make_shape(256, 512);

  //
  // Allocate and initialize
  //

  thrust::host_vector<Element> h_S(size(tensor_shape));
  thrust::host_vector<Element> h_D(size(tensor_shape));

  //
  // Initialize
  //

  for (size_t i = 0; i < h_S.size(); ++i) {
    h_S[i] = static_cast<Element>(i);
    h_D[i] = Element{};
@ -180,33 +170,36 @@ int main(int argc, char** argv)
  // Make tensors
  //

  Tensor tensor_S = make_tensor(make_gmem_ptr(d_S.data().get()), make_layout(tensor_shape));
  Tensor tensor_D = make_tensor(make_gmem_ptr(d_D.data().get()), make_layout(tensor_shape));
  Tensor tensor_S = make_tensor(make_gmem_ptr(thrust::raw_pointer_cast(d_S.data())), make_layout(tensor_shape));
  Tensor tensor_D = make_tensor(make_gmem_ptr(thrust::raw_pointer_cast(d_D.data())), make_layout(tensor_shape));

  //
  // Partition
  // Tile tensors
  //

  // Define a statically sized block (M, N).
  //
  // Note, by convention, capital letters are used to represent static modes.
  auto block_shape = make_shape(Int<128>{}, Int<64>{});

  if ((get<0>(tensor_shape) % get<0>(block_shape)) || (get<1>(tensor_shape) % get<1>(block_shape))) {
  if ((size<0>(tensor_shape) % size<0>(block_shape)) || (size<1>(tensor_shape) % size<1>(block_shape))) {
    std::cerr << "The tensor shape must be divisible by the block shape." << std::endl;
    return -1;
  }
  // Equivalent check to the above
  if (not weakly_compatible(block_shape, tensor_shape)) {
    std::cerr << "Expected the tensors to be weakly compatible with the block_shape." << std::endl;
    return -1;
  }

  // Tile the tensor (m, m) ==> ((M, N), m', n') where (M, N) is the static tile
  // Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static tile
  // shape, and modes (m', n') correspond to the number of tiles.
  //
  // These will be used to determine the CUDA kernel grid dimensinos.
  Tensor tiled_tensor_S = tiled_divide(tensor_S, block_shape);
  Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape);
  //
  // These will be used to determine the CUDA kernel grid dimensions.
  Tensor tiled_tensor_S = tiled_divide(tensor_S, block_shape);  // ((M, N), m', n')
  Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape);  // ((M, N), m', n')

  // Thread arrangement
  Layout thr_layout = make_layout(make_shape(Int<32>{}, Int< 8>{}));
  Layout thr_layout = make_layout(make_shape(Int<32>{}, Int<8>{}));

  // Vector dimensions
  Layout vec_layout = make_layout(make_shape(Int<4>{}, Int<1>{}));
@ -215,16 +208,16 @@ int main(int argc, char** argv)
  // Determine grid and block dimensions
  //

  dim3 gridDim = shape_to_dim3(select<1,2>(shape(tiled_tensor_D)));  // Grid shape corresponds to modes m' and n'
  dim3 blockDim(size(shape(thr_layout)));
  dim3 gridDim (size<1>(tiled_tensor_D), size<2>(tiled_tensor_D));   // Grid shape corresponds to modes m' and n'
  dim3 blockDim(size(thr_layout));

  //
  // Launch the kernel
  //
  copy_kernel_vectorized<<< gridDim, blockDim >>>(
    tiled_tensor_S,
    tiled_tensor_D,
    thr_layout,
    vec_layout);

  cudaError result = cudaDeviceSynchronize();