diff --git a/CHANGELOG.md b/CHANGELOG.md
index 367d6935..b92893e8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,22 @@
# CUTLASS 2.x
+## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08)
+ * [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
+ * Fast Tensor Core operations:
+ * Maximum performance via [`mma.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
+ * Tensor Float 32, BFloat16, and double-precision data types
+ * Mixed integer data types (int8, int4, bin1)
+ * Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution)
+ * Features:
+ * SDK examples showing GEMM fused with bias+relu and fused GEMM+GEMM
+ * Complex-valued GEMMs targeting NVIDIA Ampere Tensor Cores in double-precision and Tensor Float 32
+ * Gaussian complex GEMMs using 3m complex multiply algorithm
+ * Universal GEMM kernel supporting two batch modes and two algorithms for parallel reductions
+ * Policy updates:
+ * [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit) needed to enable NVIDIA Ampere Architecture features
+ * Disabled F16C by default for compatibility - enable on cmake command line with `-DCUTLASS_ENABLE_F16C=ON`
+
## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06)
* BLAS-style host-side API added to [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
* API to launch compiled kernel instances for GEMM and planar complex GEMM
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1b7bbc48..b6583747 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
@@ -32,7 +32,7 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")
-project(CUTLASS VERSION 2.1.0 LANGUAGES CXX)
+project(CUTLASS VERSION 2.2.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
find_package(Doxygen QUIET)
@@ -84,7 +84,7 @@ endif()
set(CUTLASS_NVCC_ARCHS_SUPPORTED "")
if (NOT CUDA_VERSION VERSION_LESS 7.5)
- list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 50)
+ list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 53)
endif()
if (NOT CUDA_VERSION VERSION_LESS 8.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 60 61)
@@ -98,6 +98,9 @@ endif()
if (NOT CUDA_VERSION VERSION_LESS 10.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 75)
endif()
+if (NOT CUDA_VERSION VERSION_LESS 11.0)
+ list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 80)
+endif()
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")
@@ -154,7 +157,7 @@ endif()
set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.")
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
-set(CUTLASS_ENABLE_F16C ON CACHE BOOL "Enable F16C x86 extensions in host code.")
+set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.")
#
# CUTLASS generator cmake configuration
@@ -248,8 +251,8 @@ if(CUDA_COMPILER MATCHES "[Cc]lang")
endif()
list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-path=${CUDA_TOOLKIT_ROOT_DIR})
- list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm=-pragma-unroll-threshold=100000)
- list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm=-unroll-threshold=5000)
+ list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -pragma-unroll-threshold=100000)
+ list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -unroll-threshold=5000)
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unused-command-line-argument)
string(REPLACE "." ";" CUDA_VERSION_PARTS ${CMAKE_CUDA_COMPILER_VERSION})
@@ -271,7 +274,7 @@ function(cutlass_apply_cuda_gencode_flags TARGET)
set(NVCC_FLAGS)
set(CLANG_FLAGS)
foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED})
- list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH})
+ list(APPEND CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH})
set(CODES)
if(CUTLASS_NVCC_EMBED_CUBIN)
list(APPEND CODES sm_${ARCH})
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index fc95674d..f8778b80 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -9,15 +9,17 @@ This is the official list of CUTLASS developers and contributors.
## DEVELOPERS
Andrew Kerr
Haicheng Wu
-Naila Farooqui
+Manish Gupta
Dustyn Blasig
Pradeep Ramani
-Manish Gupta
-Aditya Atluri
+Naila Farooqui
+Piotr Majcher
Paul Springer
-David Tanner
-Scott Yokim
Jin Wang
+Scott Yokim
+Markus Hohnerbach
+Aditya Atluri
+David Tanner
## CONTRIBUTORS
Timothy Costa
@@ -25,12 +27,10 @@ Julien Demouth
Brian Fahs
Michael Goldfarb
Mostafa Hagog
-Markus Hohnerbach
Fei Hu
Alan Kaatz
Tina Li
Timmy Liu
-Piotr Majcher
Duane Merrill
Kevin Siu
Markus Tavenrath
diff --git a/CUDA.cmake b/CUDA.cmake
index d1eb4dbc..b8b343a7 100644
--- a/CUDA.cmake
+++ b/CUDA.cmake
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
@@ -206,14 +206,14 @@ include_directories(SYSTEM ${CUDA_INCLUDE_DIRS})
function(cutlass_correct_source_file_language_property)
if(CUDA_COMPILER MATCHES "clang")
foreach(File ${ARGN})
- if(${File} MATCHES ".*\.cu$")
+ if(File MATCHES ".*\.cu$")
set_source_files_properties(${File} PROPERTIES LANGUAGE CXX)
endif()
endforeach()
endif()
endfunction()
-set(CUTLASS_UNITY_BUILD_ENABLED ON CACHE BOOL "Enable combined source compilation")
+set(CUTLASS_UNITY_BUILD_ENABLED OFF CACHE BOOL "Enable combined source compilation")
set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files")
function(cutlass_unify_source_files TARGET_ARGS_VAR)
diff --git a/LICENSE.txt b/LICENSE.txt
index 283345b5..64a49d68 100644
--- a/LICENSE.txt
+++ b/LICENSE.txt
@@ -1,4 +1,4 @@
-Copyright (c) 2017 - 2019, NVIDIA CORPORATION. All rights reserved.
+Copyright (c) 2017 - 2020, NVIDIA CORPORATION. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
diff --git a/README.md b/README.md
index dd1c4c65..c1507c03 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@

-# CUTLASS 2.1
+# CUTLASS 2.2
-_CUTLASS 2.1 - April 2020_
+_CUTLASS 2.2 - June 2020_
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
@@ -17,14 +17,28 @@ and applications.
To support a wide variety of applications, CUTLASS provides extensive support for
mixed-precision computations, providing specialized data-movement and
multiply-accumulate abstractions for half-precision floating
-point (FP16), single-precision floating point (FP32), double-precision floating
+point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
+single-precision floating point (FP32), double-precision floating
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
-Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply operations for
+
+Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply operations
targeting the programmable, high-throughput _Tensor Cores_ implemented by
-NVIDIA's Volta and Turing architectures.
+NVIDIA's Volta, Turing, and Ampere architectures.
See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
+See the [functionality listing](media/docs/functionality.md) for the list of operations
+supported at each level of the execution model hierarchy.
+
+# What's New in CUTLASS 2.2
+
+CUTLASS 2.2 is a significant update to CUTLASS adding:
+
+- Coverage of [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
+- Tensor Core-accelerated GEMMs targeting Tensor Float 32, BFloat16, and double-precision data types
+- Deep software pipelines using asynchronous copy
+- Intended to be compiled with [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit)
+
# What's New in CUTLASS 2.1
CUTLASS 2.1 is a minor update to CUTLASS 2.0 adding:
@@ -32,7 +46,6 @@ CUTLASS 2.1 is a minor update to CUTLASS 2.0 adding:
- [Planar complex GEMM kernels](/examples/10_planar_complex/planar_complex.cu) targeting Volta and Turing Tensor Cores
- BLAS-style API to launch kernels compiled into the [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
-
# What's New in CUTLASS 2.0
CUTLASS 2.0 is a substantial refactoring from the previous version, intended to offer:
@@ -43,9 +56,6 @@ CUTLASS 2.0 is a substantial refactoring from the previous version, intended to
**See the [CHANGELOG](CHANGELOG.md) for more details.**
-See the [functionality listing](media/docs/functionality.md) for the list of operations
-supported at each level of the execution model hierarchy.
-
# Performance

@@ -53,15 +63,15 @@ supported at each level of the execution model hierarchy.
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
they exhibit performance comparable to cuBLAS for scalar GEMM
computations. The above figure shows CUTLASS performance relative to cuBLAS
-for large matrix dimensions on an NVIDIA GeForce 2080 Ti and an NVIDIA TitanV
-using CUDA 10.2. Tensor Core operations are implemented using CUDA's
+for large matrix dimensions on an NVIDIA GeForce 2080 Ti, an NVIDIA A100, and an NVIDIA TitanV
+using CUDA 11.0 Toolkit. Tensor Core operations are implemented using CUDA's
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
# Compatibility
CUTLASS requires a C++11 host compiler and
-performs best when built with the [CUDA 10.2 Toolkit](https://developer.nvidia.com/cuda-toolkit).
-It is compatible with CUDA 9.2, CUDA 10.0, and CUDA 10.1.
+performs best when built with the [CUDA 11.0 Toolkit](https://developer.nvidia.com/cuda-toolkit).
+It is compatible with CUDA 9.2, CUDA 10.0, CUDA 10.1, and CUDA 10.2.
We have tested the following environments.
@@ -70,27 +80,28 @@ We have tested the following environments.
| Windows 10 | Microsoft Visual Studio 2015|
| | Microsoft Visual Studio 2017|
| Ubuntu 16.04 | GCC 5.4.0 |
-| Ubuntu 18.04 | GCC 7.3.0 |
+| Ubuntu 18.04 | GCC 7.5.0 |
Additionally, CUTLASS may be built with clang.
See [these instructions](media/docs/quickstart.md#clang) for more details.
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
-any Maxwell-, Pascal-, Volta-, or Turing- architecture NVIDIA GPU.
+any Maxwell-, Pascal-, Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.
-|**GPU**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**|
-|---|---|---|
-|NVIDIA GeForce 1080|9.2| |
-|NVIDIA TitanXP|9.2| |
-|NVIDIA Tesla P100|9.2| |
-|NVIDIA Tesla V100|9.2|10.1|
-|NVIDIA TitanV|9.2|10.1|
-|NVIDIA GeForce RTX 2080 TI, 2080, 2070|10.0|10.2|
-|NVIDIA Tesla T4|10.0|10.2|
+|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**|
+|---|---|---|---|
+|NVIDIA Tesla P100|6.0|9.2| |
+|NVIDIA GeForce 1080|6.1|9.2| |
+|NVIDIA TitanXP|6.1|9.2| |
+|NVIDIA Tesla V100|7.0|9.2|10.1|
+|NVIDIA TitanV|7.0|9.2|10.1|
+|NVIDIA GeForce RTX 2080 TI, 2080, 2070|7.5|10.0|10.2|
+|NVIDIA Tesla T4|7.5|10.0|10.2|
+|NVIDIA A100|8.0|11.0|11.0|
# Documentation
-CUTLASS 2.1 is described in the following documents and the accompanying
+CUTLASS 2.2 is described in the following documents and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).
- [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS
@@ -124,7 +135,7 @@ $ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc
```
Create a build directory within the CUTLASS project, then run CMake. By default CUTLASS will build kernels
-for CUDA architecture versions 5.0, 6.0, 6.1, 7.0 and 7.5. To reduce compile time you can specify
+for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, and 8.0. To reduce compile time you can specify
the architectures to build CUTLASS for by changing the CMake configuration setting
`CUTLASS_NVCC_ARCHS`.
@@ -210,6 +221,10 @@ examples/
10_planar_complex/ # example demonstrating planar complex GEMM kernels
11_planar_complex_array/ # example demonstrating planar complex kernels with batch-specific problem sizes
+
+ 12_gemm_bias_relu/ # example demonstrating GEMM fused with bias and relu
+
+   13_fused_two_gemms/     # example demonstrating two GEMMs fused in one kernel
```
### Tools
@@ -255,29 +270,32 @@ $ make cutlass_profiler -j
Example command line for profiling SGEMM kernels is as follows:
```
-$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=4352 --n=4096 --k=4096
+$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=3456 --n=4096 --k=4096
=============================
Problem ID: 1
- Provider: CUTLASS
- Operation: cutlass_simt_sgemm_128x128_nn
+ Provider: CUTLASS
+ OperationKind: gemm
+ Operation: cutlass_simt_sgemm_128x128_8x2_nn_align1
- Disposition: Passed
- Status: Success
+ Status: Success
+ Verification: ON
+ Disposition: Passed
- Arguments: --m=4352 --n=4096 --k=4096 --A=f32:column --B=f32:column --C=f32:column --alpha=1 --beta=0 \
- --split_k_slices=1 --batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 \
- --stages=2 --warps_m=2 --warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 \
- --max_cc=1024
+ cuBLAS: Passed
- Bytes: 52428800 bytes
- FLOPs: 146064539648 flops
+ Arguments: --m=3456 --n=4096 --k=4096 --A=f32:column --B=f32:column --C=f32:column --alpha=1 --beta=0 --split_k_slices=1 \
+ --batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 --stages=2 --warps_m=4 \
+ --warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 --max_cc=1024
- Runtime: 10.5424 ms
- Memory: 4.63158 GiB/s
+ Bytes: 180355072 bytes
+ FLOPs: 115992428544 flops
- Math: 13854.9 GFLOP/s
+ Runtime: 6.73655 ms
+ Memory: 24.934 GiB/s
+
+ Math: 17218.4 GFLOP/s
```
[Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md)
diff --git a/cmake/nop.cu b/cmake/nop.cu
index 571c6c7c..518a582b 100644
--- a/cmake/nop.cu
+++ b/cmake/nop.cu
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/cuBLAS.cmake b/cuBLAS.cmake
index d7f330cf..4c73a1db 100644
--- a/cuBLAS.cmake
+++ b/cuBLAS.cmake
@@ -10,28 +10,35 @@ if((DEFINED CUTLASS_ENABLE_CUBLAS AND NOT CUTLASS_ENABLE_CUBLAS) OR
message(STATUS "cuBLAS Disabled.")
elseif(NOT TARGET cublas)
-
+
find_path(
- _CUBLAS_INCLUDE_DIR cublas.h
- PATHS
- ${CUDA_TOOLKIT_ROOT_DIR}/include
- $ENV{CUBLAS_PATH}/include
- $ENV{CUDA_PATH}/include
- ${CUBLAS_PATH}/include
- /usr/include)
+ _CUBLAS_INCLUDE_DIR
+ NAMES cublas.h
+ HINTS
+ ${CUBLAS_INCLUDE_PATH}
+ ENV CUBLAS_INCLUDE_PATH
+ ${CUBLAS_PATH}
+ ENV CUBLAS_PATH
+ ${CUDA_TOOLKIT_ROOT_DIR}
+ PATH_SUFFIXES
+ include
+ )
find_library(
- _CUBLAS_LIBRARY cublas
+ _CUBLAS_LIBRARY
+ NAMES cublas
HINTS
- ${CUDA_TOOLKIT_ROOT_DIR}/lib64
- ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
- $ENV{CUBLAS_PATH}/lib64
- $ENV{CUBLAS_PATH}/lib/x64
- $ENV{CUDA_PATH}/lib64
- $ENV{CUDA_PATH}/lib/x64
- ${CUBLAS_PATH}/lib64
- ${CUBLAS_PATH}/lib/x64
- /usr/lib/x86_64-linux-gnu)
+ ${CUBLAS_LIBRARY_PATH}
+ ENV CUBLAS_LIBRARY_PATH
+ ${_CUBLAS_INCLUDE_DIR}/..
+ ${CUBLAS_PATH}
+ ENV CUBLAS_PATH
+ ${CUDA_TOOLKIT_ROOT_DIR}
+ PATH_SUFFIXES
+ lib64
+ lib/x64
+ lib
+ )
if(_CUBLAS_INCLUDE_DIR AND _CUBLAS_LIBRARY)
@@ -79,17 +86,20 @@ if(CUTLASS_ENABLE_CUBLAS AND NOT TARGET cublas)
$)
find_library(
- _CUBLASLT_LIBRARY cublasLt
+ _CUBLASLT_LIBRARY
+ NAMES cublasLt
HINTS
- ${CUDA_TOOLKIT_ROOT_DIR}/lib64
- ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
- $ENV{CUBLAS_PATH}/lib64
- $ENV{CUBLAS_PATH}/lib/x64
- $ENV{CUDA_PATH}/lib64
- $ENV{CUDA_PATH}/lib/x64
- ${CUBLAS_PATH}/lib64
- ${CUBLAS_PATH}/lib/x64
- /usr/lib/x86_64-linux-gnu)
+ ${CUBLAS_LIBRARY_PATH}
+ ENV CUBLAS_LIBRARY_PATH
+ ${_CUBLAS_INCLUDE_DIR}/..
+ ${CUBLAS_PATH}
+ ENV CUBLAS_PATH
+ ${CUDA_TOOLKIT_ROOT_DIR}
+ PATH_SUFFIXES
+ lib64
+ lib/x64
+ lib
+ )
if(_CUBLASLT_LIBRARY AND NOT TARGET cublasLt)
@@ -106,6 +116,8 @@ if(CUTLASS_ENABLE_CUBLAS AND NOT TARGET cublas)
add_library(nvidia::cublasLt ALIAS cublasLt)
+ target_link_libraries(cublas INTERFACE cublasLt)
+
endif()
endif()
diff --git a/examples/00_basic_gemm/CMakeLists.txt b/examples/00_basic_gemm/CMakeLists.txt
index 5b833b85..9ae257d9 100644
--- a/examples/00_basic_gemm/CMakeLists.txt
+++ b/examples/00_basic_gemm/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
diff --git a/examples/00_basic_gemm/basic_gemm.cu b/examples/00_basic_gemm/basic_gemm.cu
index 41564632..bda012ab 100644
--- a/examples/00_basic_gemm/basic_gemm.cu
+++ b/examples/00_basic_gemm/basic_gemm.cu
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/examples/01_cutlass_utilities/CMakeLists.txt b/examples/01_cutlass_utilities/CMakeLists.txt
index 2dfa083c..5f22b7b1 100644
--- a/examples/01_cutlass_utilities/CMakeLists.txt
+++ b/examples/01_cutlass_utilities/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
diff --git a/examples/01_cutlass_utilities/cutlass_utilities.cu b/examples/01_cutlass_utilities/cutlass_utilities.cu
index 0b6aa386..d1eaa57f 100644
--- a/examples/01_cutlass_utilities/cutlass_utilities.cu
+++ b/examples/01_cutlass_utilities/cutlass_utilities.cu
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/examples/02_dump_reg_shmem/CMakeLists.txt b/examples/02_dump_reg_shmem/CMakeLists.txt
index 4e9af4fb..5e6112e0 100644
--- a/examples/02_dump_reg_shmem/CMakeLists.txt
+++ b/examples/02_dump_reg_shmem/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
diff --git a/examples/02_dump_reg_shmem/dump_reg_shmem.cu b/examples/02_dump_reg_shmem/dump_reg_shmem.cu
index 39d58db8..ed712aa8 100644
--- a/examples/02_dump_reg_shmem/dump_reg_shmem.cu
+++ b/examples/02_dump_reg_shmem/dump_reg_shmem.cu
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
*modification, are permitted provided that the following conditions are met:
diff --git a/examples/03_visualize_layout/CMakeLists.txt b/examples/03_visualize_layout/CMakeLists.txt
index 81211df9..5a08c0f8 100644
--- a/examples/03_visualize_layout/CMakeLists.txt
+++ b/examples/03_visualize_layout/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
diff --git a/examples/03_visualize_layout/options.h b/examples/03_visualize_layout/options.h
index c72b1228..dd7de198 100644
--- a/examples/03_visualize_layout/options.h
+++ b/examples/03_visualize_layout/options.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/examples/03_visualize_layout/register_layout.cu b/examples/03_visualize_layout/register_layout.cu
index 655d1f37..0d2b25eb 100644
--- a/examples/03_visualize_layout/register_layout.cu
+++ b/examples/03_visualize_layout/register_layout.cu
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -34,6 +34,8 @@
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor_op_multiplicand_sm70.h"
#include "cutlass/layout/tensor_op_multiplicand_sm75.h"
+#include "cutlass/layout/tensor_op_multiplicand_sm80.h"
+
#include "visualize_layout.h"
#include "register_layout.h"
@@ -59,18 +61,40 @@ void RegisterLayouts(std::map
// Integer matrix multiply.int4 8832 TN kblock128
{"TensorOpMultiplicand<4,128>",
new VisualizeLayout>},
+ // Integer matrix multiply.int4 16864 TN kblock256
+ {"TensorOpMultiplicand<4,256>",
+ new VisualizeLayout>},
// Integer matrix multiply 8816 Interleaved-32
{"TensorOpMultiplicand<8,32>",
new VisualizeLayout>},
// Integer matrix multiply 8816 TN kblock64
{"TensorOpMultiplicand<8,64>",
new VisualizeLayout>},
+ {"TensorOpMultiplicand<8,128>",
+ new VisualizeLayout>},
// Matrix Multiply 1688 TN kblock32
{"TensorOpMultiplicand<16,32>",
new VisualizeLayout>},
// Matrix multiply 1688 NT
{"TensorOpMultiplicand<16,64>",
new VisualizeLayout>},
+ // Matrix multiply 1688.TF32 TN kblock16
+ {"TensorOpMultiplicand<32,16>",
+ new VisualizeLayout>},
+ // Matrix multiply 1688.TF32 TN kblock32
+ {"TensorOpMultiplicand<32,32>",
+ new VisualizeLayout>},
+ // Matrix multiply 1688 NT
+ {"TensorOpMultiplicandCongruous<32,32>",
+ new VisualizeLayout<
+ cutlass::layout::TensorOpMultiplicandCongruous<32, 32>>},
+ // Matrix multiply 884 NT
+ {"TensorOpMultiplicandCongruous<64,16>",
+ new VisualizeLayout<
+ cutlass::layout::TensorOpMultiplicandCongruous<64, 16>>},
+ // Matrix multiply 884 TN
+ {"TensorOpMultiplicand64bCrosswise",
+ new VisualizeLayout},
{"TensorOpMultiplicandCongruous<128,4>",
new VisualizeLayout<
cutlass::layout::TensorOpMultiplicandCongruous<128, 4>>},
@@ -82,7 +106,7 @@ void RegisterLayouts(std::map
cutlass::layout::VoltaTensorOpMultiplicandCongruous<16>>},
{"VoltaTensorOpMultiplicandCrosswise<16,32>",
new VisualizeLayout<
- cutlass::layout::VoltaTensorOpMultiplicandCrosswise<16, 32>>},
+ cutlass::layout::VoltaTensorOpMultiplicandCrosswise<16, 32>>}
};
for (auto layout : layout_pairs) {
diff --git a/examples/03_visualize_layout/register_layout.h b/examples/03_visualize_layout/register_layout.h
index fee911f7..1518e433 100644
--- a/examples/03_visualize_layout/register_layout.h
+++ b/examples/03_visualize_layout/register_layout.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/examples/03_visualize_layout/visualize_layout.cpp b/examples/03_visualize_layout/visualize_layout.cpp
index 8908d2c1..a0f27181 100644
--- a/examples/03_visualize_layout/visualize_layout.cpp
+++ b/examples/03_visualize_layout/visualize_layout.cpp
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -65,14 +65,26 @@ void print_usage(std::ostream &out) {
"--extent=64,64 --vectorize=32 --output-shape=256,4\n"
<< "$ 03_visualize_layout \"TensorOpMultiplicand<4,128>\" "
"--extent=128,32 --vectorize=32 --output-shape=256,4\n"
+ << "$ 03_visualize_layout \"TensorOpMultiplicand<4,256>\" "
+ "--extent=256,16 --vectorize=32 --output-shape=256,4\n"
<< "$ 03_visualize_layout \"TensorOpMultiplicand<8,32>\" "
"--extent=32,64 --vectorize=16 --output-shape=128,4\n"
<< "$ 03_visualize_layout \"TensorOpMultiplicand<8,64>\" "
"--extent=64,32 --vectorize=16 --output-shape=128,4\n"
+ << "$ 03_visualize_layout \"TensorOpMultiplicand<8,128>\" "
+ "--extent=128,16 --vectorize=16 --output-shape=128,4\n"
<< "$ 03_visualize_layout \"TensorOpMultiplicand<16,32>\" "
"--extent=32,32 --vectorize=8 --output-shape=64,4\n"
<< "$ 03_visualize_layout \"TensorOpMultiplicand<16,64>\" "
"--extent=64,16 --vectorize=8 --output-shape=64,4\n"
+ << "$ 03_visualize_layout \"TensorOpMultiplicand<32,16>\" "
+ "--extent=16,32 --vectorize=4 --output-shape=32,4\n"
+ << "$ 03_visualize_layout \"TensorOpMultiplicand<32,32>\" "
+ "--extent=32,16 --vectorize=4 --output-shape=32,4\n"
+ << "$ 03_visualize_layout \"TensorOpMultiplicandCongruous<32,32>\" "
+ "--extent=32,16 --vectorize=4 --output-shape=32,4\n"
+ << "$ 03_visualize_layout \"TensorOpMultiplicandCongruous<64, 16>\" "
+ "--extent=16,16 --vectorize=2 --output-shape=16,4\n"
<< "$ 03_visualize_layout \"VoltaTensorOpMultiplicandCrosswise<16,32>\" "
"--extent=32,64 --vectorize=4 --output-shape=64,4\n"
<< "$ 03_visualize_layout \"VotlaTensorOpMultiplicandCongruous<16>\" "
diff --git a/examples/03_visualize_layout/visualize_layout.h b/examples/03_visualize_layout/visualize_layout.h
index 031916c7..4093d277 100644
--- a/examples/03_visualize_layout/visualize_layout.h
+++ b/examples/03_visualize_layout/visualize_layout.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/examples/04_tile_iterator/CMakeLists.txt b/examples/04_tile_iterator/CMakeLists.txt
index cef15624..cd32e228 100644
--- a/examples/04_tile_iterator/CMakeLists.txt
+++ b/examples/04_tile_iterator/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
diff --git a/examples/04_tile_iterator/tile_iterator.cu b/examples/04_tile_iterator/tile_iterator.cu
index e6315760..5c56f33b 100644
--- a/examples/04_tile_iterator/tile_iterator.cu
+++ b/examples/04_tile_iterator/tile_iterator.cu
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/examples/05_batched_gemm/CMakeLists.txt b/examples/05_batched_gemm/CMakeLists.txt
index 6c9bf504..6cd0ca8d 100644
--- a/examples/05_batched_gemm/CMakeLists.txt
+++ b/examples/05_batched_gemm/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
diff --git a/examples/05_batched_gemm/batched_gemm.cu b/examples/05_batched_gemm/batched_gemm.cu
index d1fecda6..a9d8a9c6 100644
--- a/examples/05_batched_gemm/batched_gemm.cu
+++ b/examples/05_batched_gemm/batched_gemm.cu
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/examples/06_splitK_gemm/CMakeLists.txt b/examples/06_splitK_gemm/CMakeLists.txt
index 750c6205..7b30ae16 100644
--- a/examples/06_splitK_gemm/CMakeLists.txt
+++ b/examples/06_splitK_gemm/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
diff --git a/examples/06_splitK_gemm/splitk_gemm.cu b/examples/06_splitK_gemm/splitk_gemm.cu
index 5fb513cb..f0e1d578 100644
--- a/examples/06_splitK_gemm/splitk_gemm.cu
+++ b/examples/06_splitK_gemm/splitk_gemm.cu
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -39,7 +39,7 @@ inner product (1/16th of output), they accumulate to single output matrix.
Writing a single high performance matrix multiplication kernel is hard but do-able. Whereas writing
high performance kernels at scale which works for multiple problem sizes with good abstractions is
-really hard. CUTLASS solves this problem by providing simplified abstractions (knobs) to compose
+really hard. CUTLASS solves this problem by providing simplified abstractions to compose
multiple sections of gemm kernel. When used properly, the kernels can hit peak performance of GPU
easily.
@@ -144,7 +144,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M =
using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4
// This code section describes how threadblocks are scheduled on GPU
-using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; // <- ??
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
// This code section describes ?
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
@@ -172,17 +172,7 @@ using Gemm = cutlass::gemm::device::GemmSplitKParallel;
-int main() {
-
- //
- // Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1.
- //
- // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
- //
- if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
- std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
- return -1;
- }
+int run() {
cudaDeviceProp props;
@@ -316,11 +306,30 @@ int main() {
tensor_ref_d.sync_host();
// Check if output from CUTLASS kernel and reference kernel are equal or not
- std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(),
- tensor_ref_d.host_view())
- ? "Passed"
- : "Failed")
- << std::endl;
+ bool passed = cutlass::reference::host::TensorEquals(
+ tensor_d.host_view(),
+ tensor_ref_d.host_view());
- CUTLASS_CHECK(status);
+ std::cout << (passed ? "Passed" : "Failed") << std::endl;
+
+ return (passed ? 0 : -1);
}
+
+int main() {
+
+ //
+ // Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1.
+ //
+ // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
+ //
+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
+ std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
+
+ // Returning zero, so this test passes when built with older CUDA Toolkits. Its actions are no-ops.
+ return 0;
+ }
+ else {
+ return run();
+ }
+}
+
diff --git a/examples/07_volta_tensorop_gemm/CMakeLists.txt b/examples/07_volta_tensorop_gemm/CMakeLists.txt
index 56dfce9e..82e81722 100644
--- a/examples/07_volta_tensorop_gemm/CMakeLists.txt
+++ b/examples/07_volta_tensorop_gemm/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
diff --git a/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu b/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu
index 447cc1cc..208c4f64 100644
--- a/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu
+++ b/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -156,7 +156,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M =
using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4
// This code section describes how threadblocks are scheduled on GPU
-using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; // <- ??
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
// This code section describes ?
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
@@ -188,15 +188,7 @@ using Gemm = cutlass::gemm::device::Gemm;
-int main() {
-
- // Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1.
- //
- // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
- if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
- std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
- return -1;
- }
+int run() {
cudaDeviceProp props;
@@ -223,7 +215,7 @@ int main() {
cutlass::HostTensor tensor_a(
problem_size.mk()); // <- Create matrix A with dimensions M x K
cutlass::HostTensor tensor_b(
- problem_size.nk()); // <- Create matrix B with dimensions N x K
+ problem_size.kn()); // <- Create matrix B with dimensions K x N
cutlass::HostTensor tensor_c(
problem_size.mn()); // <- Create matrix C with dimensions M x N
cutlass::HostTensor tensor_d(
@@ -326,12 +318,28 @@ int main() {
tensor_ref_d.sync_host();
// Check if output from CUTLASS kernel and reference kernel are equal or not
- std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(),
- tensor_ref_d.host_view())
- ? "Passed"
- : "Failed")
- << std::endl;
+ bool passed = cutlass::reference::host::TensorEquals(
+ tensor_d.host_view(),
+ tensor_ref_d.host_view());
- CUTLASS_CHECK(status);
- return 0;
+ std::cout << (passed ? "Passed" : "Failed") << std::endl;
+
+ return (passed ? 0 : -1);
}
+
+int main() {
+
+ // Volta Tensor Core operations exposed with mma.sync are first available in CUDA 10.1.
+ //
+ // CUTLASS must be compiled with CUDA 10.1 Toolkit to run these examples.
+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
+ std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
+
+ // Returning zero when built on older Toolkits so tests pass. The actions of this SDK example are no-op.
+ return 0;
+ }
+ else {
+ return run();
+ }
+}
+
diff --git a/examples/08_turing_tensorop_gemm/CMakeLists.txt b/examples/08_turing_tensorop_gemm/CMakeLists.txt
index 9e011a1e..b4e4fe82 100644
--- a/examples/08_turing_tensorop_gemm/CMakeLists.txt
+++ b/examples/08_turing_tensorop_gemm/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
diff --git a/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu b/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu
index 3440d82f..d7ba8331 100644
--- a/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu
+++ b/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -150,12 +150,12 @@ using SmArch = cutlass::arch::Sm75;
using ShapeMMAThreadBlock =
cutlass::gemm::GemmShape<128, 256, 64>; // <- threadblock tile M = 128, N = 256, K = 64
// This code section describes tile size a warp will compute
-using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M = 64, N = 64, K = 16
+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M = 64, N = 64, K = 64
// This code section describes the size of MMA op
using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 16>; // <- MMA Op tile M = 8, N = 8, K = 16
// This code section describes how threadblocks are scheduled on GPU
-using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; // <- ??
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
// This code section describes the epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
@@ -186,7 +186,7 @@ using Gemm = cutlass::gemm::device::Gemm;
-int main() {
+int run() {
// Turing Tensor Core operations exposed with mma.sync and ldmatrix are first available
// in CUDA 10.2.
@@ -222,7 +222,7 @@ int main() {
cutlass::HostTensor tensor_a(
problem_size.mk()); // <- Create matrix A with dimensions M x K
cutlass::HostTensor tensor_b(
- problem_size.nk()); // <- Create matrix B with dimensions N x K
+ problem_size.kn()); // <- Create matrix B with dimensions K x N
cutlass::HostTensor tensor_c(
problem_size.mn()); // <- Create matrix C with dimensions M x N
cutlass::HostTensor tensor_d(
@@ -325,12 +325,28 @@ int main() {
tensor_ref_d.sync_host();
// Check if output from CUTLASS kernel and reference kernel are equal or not
- std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(),
- tensor_ref_d.host_view())
- ? "Passed"
- : "Failed")
- << std::endl;
+ bool passed = cutlass::reference::host::TensorEquals(
+ tensor_d.host_view(),
+ tensor_ref_d.host_view());
- CUTLASS_CHECK(status);
- return 0;
+ std::cout << (passed ? "Passed" : "Failed") << std::endl;
+
+ return (passed ? 0 : -1);
}
+
+int main() {
+ // Turing Tensor Core operations exposed with mma.sync and ldmatrix are first available
+ // in CUDA 10.2.
+ //
+ // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
+ std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
+
+ // Returning zero so this test passes when built on older Toolkits.
+ return 0;
+ }
+ else {
+ return run();
+ }
+}
+
diff --git a/examples/10_planar_complex/planar_complex.cu b/examples/10_planar_complex/planar_complex.cu
index 7fc92870..b7318b99 100644
--- a/examples/10_planar_complex/planar_complex.cu
+++ b/examples/10_planar_complex/planar_complex.cu
@@ -500,7 +500,9 @@ int main(int argc, char const **args) {
if (props.major < 7) {
std::cerr << "Volta Tensor Core operations must be run on a machine with compute capability at least 70."
<< std::endl;
- return -1;
+
+ // Returning zero so this test passes on older architectures even though its actions are no-op.
+ return 0;
}
else if (props.major == 7 && props.minor <= 2) {
//
@@ -508,7 +510,9 @@ int main(int argc, char const **args) {
//
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
- return -1;
+
+ // Returning zero so this test passes on older Toolkits even though its actions are no-op.
+ return 0;
}
}
else if (props.major == 7 && props.minor >= 5) {
@@ -517,7 +521,9 @@ int main(int argc, char const **args) {
//
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
- return -1;
+
+ // Returning zero so this test passes on older Toolkits even though its actions are no-op.
+ return 0;
}
}
diff --git a/examples/11_planar_complex_array/planar_complex_array.cu b/examples/11_planar_complex_array/planar_complex_array.cu
index 3003a900..6a027053 100644
--- a/examples/11_planar_complex_array/planar_complex_array.cu
+++ b/examples/11_planar_complex_array/planar_complex_array.cu
@@ -560,7 +560,9 @@ int main(int argc, char const **args) {
if (props.major < 7) {
std::cerr << "Tensor Core operations must be run on a machine with compute capability at least 70."
<< std::endl;
- return -1;
+
+ // Returning zero so this passes on older architectures. Its actions are no-op.
+ return 0;
}
else if (props.major == 7 && props.minor <= 2) {
//
@@ -568,7 +570,9 @@ int main(int argc, char const **args) {
//
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))) {
std::cerr << "Volta Tensor Core operations must be compiled with CUDA 10.1 Toolkit or later." << std::endl;
- return -1;
+
+ // Returning zero so this passes on older Toolkits. Its actions are no-op.
+ return 0;
}
}
else if (props.major == 7 && props.minor >= 5) {
@@ -577,7 +581,9 @@ int main(int argc, char const **args) {
//
if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
- return -1;
+
+ // Returning zero so this passes on older Toolkits. Its actions are no-op.
+ return 0;
}
}
diff --git a/examples/12_gemm_bias_relu/CMakeLists.txt b/examples/12_gemm_bias_relu/CMakeLists.txt
new file mode 100644
index 00000000..fb78d77f
--- /dev/null
+++ b/examples/12_gemm_bias_relu/CMakeLists.txt
@@ -0,0 +1,27 @@
+# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without modification, are permitted
+# provided that the following conditions are met:
+# * Redistributions of source code must retain the above copyright notice, this list of
+# conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice, this list of
+# conditions and the following disclaimer in the documentation and/or other materials
+# provided with the distribution.
+# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+# to endorse or promote products derived from this software without specific prior written
+# permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+cutlass_example_add_executable(
+ 12_gemm_bias_relu
+ gemm_bias_relu.cu
+ )
+
diff --git a/examples/12_gemm_bias_relu/gemm_bias_relu.cu b/examples/12_gemm_bias_relu/gemm_bias_relu.cu
new file mode 100644
index 00000000..7faaa98a
--- /dev/null
+++ b/examples/12_gemm_bias_relu/gemm_bias_relu.cu
@@ -0,0 +1,282 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/**
+*/
+
+#include
+#include
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/device/gemm.h"
+#include "cutlass/epilogue/thread/linear_combination_relu.h"
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/reference/device/gemm.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/tensor_view_io.h"
+#include "helper.h"
+
+// The code section below describes datatype for input, output matrices and computation between
+// elements in input matrices.
+using ElementAccumulator = float; // <- data type of accumulator
+using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
+using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A
+using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B
+using ElementOutput = float; // <- data type of elements in output matrix D
+
+// The code section below describes matrix layout of input and output matrices. Column Major for
+// Matrix A, Column Major for Matrix B and Row Major for Matrix C
+using LayoutInputA = cutlass::layout::ColumnMajor;
+using LayoutInputB = cutlass::layout::ColumnMajor;
+using LayoutOutput = cutlass::layout::RowMajor;
+
+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
+using MMAOp = cutlass::arch::OpClassTensorOp;
+
+// This code section describes CUDA SM architecture number
+using SmArch = cutlass::arch::Sm75;
+
+// This code section describes the tile size a thread block will compute
+using ShapeMMAThreadBlock =
+ cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N = 128, K = 32
+// This code section describes tile size a warp will compute
+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32
+// This code section describes the size of MMA op
+using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8
+
+// This code section describes how threadblocks are scheduled on GPU
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
+
+// Define the epilogue operation as LinearCombinationRelu. This is approximately equal to
+//
+// d_ij = max(0, alpha * sum_k(a_ik * b_kj) + beta * c_ij )
+//
+using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
+ ElementOutput, // <- data type of output matrix
+ 128 / cutlass::sizeof_bits::value, // <- this is the number of elements per
+ // vectorized memory access. For half
+ // precision, it's 8 elements. This becomes
+ // the vector width of math instructions in
+ // epilogue too
+ ElementAccumulator, // <- data type of accumulator
+ ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function
+
+// Number of pipelines you want to use
+constexpr int NumStages = 2;
+
+using Gemm = cutlass::gemm::device::Gemm;
+
+int run() {
+
+ cudaDeviceProp props;
+
+ cudaError_t error = cudaGetDeviceProperties(&props, 0);
+ if (error != cudaSuccess) {
+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
+ return -1;
+ }
+
+ if (!(props.major * 10 + props.minor >= 75)) {
+ std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
+ << std::endl;
+ // Returning zero so this test passes on older architectures. Its actions are no-ops.
+ return 0;
+ }
+
+ const int length_m = 5120;
+ const int length_n = 4096;
+ const int length_k = 4096;
+
+ // Create a tuple of problem size for matrix multiplication
+ cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
+
+ // Initialize tensors using CUTLASS helper functions
+ cutlass::HostTensor tensor_a(
+ problem_size.mk()); // <- Create matrix A with dimensions M x K
+ cutlass::HostTensor tensor_b(
+ problem_size.kn()); // <- Create matrix B with dimensions K x N
+
+ cutlass::HostTensor tensor_c_bias(
+ {problem_size.m(), 1}); // <- Create matrix C with dimensions M x 1
+
+ cutlass::HostTensor tensor_d(
+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
+ // CUTLASS kernel
+ cutlass::HostTensor tensor_ref_d(
+ problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from
+ // reference kernel
+
+ // Fill input and output matrices on host using CUTLASS helper functions
+ cutlass::reference::host::TensorFillRandomUniform(
+ tensor_a.host_view(),
+ 1,
+ ElementInputA(4),
+ ElementInputA(-4),
+ 0); // <- Fill matrix A on host with uniform-distribution random data
+ cutlass::reference::host::TensorFillRandomUniform(
+ tensor_b.host_view(),
+ 1,
+ ElementInputB(4),
+ ElementInputB(-4),
+ 0); // <- Fill matrix B on host with uniform-distribution random data
+ cutlass::reference::host::TensorFillRandomUniform(
+ tensor_c_bias.host_view(),
+ 1,
+ ElementOutput(4),
+ ElementOutput(-4),
+ 0); // <- Fill matrix C on host with uniform-distribution random data
+ cutlass::reference::host::TensorFill(
+ tensor_d.host_view()); // <- fill matrix D on host with zeros
+ cutlass::reference::host::TensorFill(
+ tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros
+
+ // Copy data from host to GPU
+ tensor_a.sync_device();
+ tensor_b.sync_device();
+ tensor_c_bias.sync_device();
+ tensor_d.sync_device();
+ tensor_ref_d.sync_device();
+
+ // Initialize alpha and beta for dot product computation
+ ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
+ ElementComputeEpilogue beta = ElementComputeEpilogue(0);
+
+ // Split K dimension into 1 partitions
+ int split_k_slices = 1;
+
+ // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch
+ // instantiated CUTLASS kernel
+ typename Gemm::Arguments arguments{
+ problem_size, // <- problem size of matrix multiplication
+ tensor_a.device_ref(), // <- reference to matrix A on device
+ tensor_b.device_ref(), // <- reference to matrix B on device
+
+ {tensor_c_bias.device_data(), 0}, // <- the C matrix is treated as the bias vector. We can enable the GEMM
+ // to project away the N dimension by setting the stride to zero.
+
+ tensor_d.device_ref(), // <- reference to matrix D on device
+ {alpha, beta}, // <- tuple of alpha and beta
+ split_k_slices}; // <- k-dimension split factor
+
+ // Using the arguments, query for extra workspace required for matrix multiplication computation
+ size_t workspace_size = Gemm::get_workspace_size(arguments);
+
+ // Allocate workspace memory
+ cutlass::device_memory::allocation workspace(workspace_size);
+
+ // Instantiate CUTLASS kernel depending on templates
+ Gemm gemm_op;
+
+ // Initialize CUTLASS kernel with arguments and workspace pointer
+ cutlass::Status status = gemm_op.initialize(arguments, workspace.get());
+ CUTLASS_CHECK(status);
+
+ // Launch initialized CUTLASS kernel
+ status = gemm_op();
+ CUTLASS_CHECK(status);
+
+ //
+ // Create instantiation for device reference gemm kernel
+ //
+
+ cutlass::reference::device::Gemm
+ gemm_device_reference;
+
+ // Launch device reference to compute strictly the product A * B
+ gemm_device_reference(
+ problem_size,
+ alpha,
+ tensor_a.device_ref(),
+ tensor_b.device_ref(),
+ 0,
+ tensor_c_bias.device_ref(),
+ tensor_ref_d.device_ref());
+
+ // Wait for kernels to finish
+ cudaDeviceSynchronize();
+
+ // Copy output data from CUTLASS and reference kernel to host for comparison
+ tensor_d.sync_host();
+ tensor_ref_d.sync_host();
+
+ // Compute bias + relu in host code
+ for (int i = 0; i < problem_size.m(); ++i) {
+ for (int j = 0; j < problem_size.n(); ++j) {
+ tensor_ref_d.at({i, j}) = std::max(
+ ElementOutput(0),
+ ElementOutput(tensor_ref_d.at({i, j}) + beta * tensor_c_bias.at({i, 0}))
+ );
+ }
+ }
+
+ // Check if output from CUTLASS kernel and reference kernel are equal or not
+ std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(),
+ tensor_ref_d.host_view())
+ ? "Passed"
+ : "Failed")
+ << std::endl;
+
+ CUTLASS_CHECK(status);
+ return 0;
+}
+
+int main() {
+ // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
+ //
+ // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
+ std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
+
+ // Returning zero so this test passes on older Toolkits. Its actions are no-op.
+ return 0;
+ }
+ else {
+ return run();
+ }
+}
+
diff --git a/examples/13_fused_two_gemms/CMakeLists.txt b/examples/13_fused_two_gemms/CMakeLists.txt
new file mode 100644
index 00000000..ba51537c
--- /dev/null
+++ b/examples/13_fused_two_gemms/CMakeLists.txt
@@ -0,0 +1,33 @@
+# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without modification, are permitted
+# provided that the following conditions are met:
+# * Redistributions of source code must retain the above copyright notice, this list of
+# conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice, this list of
+# conditions and the following disclaimer in the documentation and/or other materials
+# provided with the distribution.
+# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+# to endorse or promote products derived from this software without specific prior written
+# permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+cutlass_example_add_executable(
+ 13_fused_two_gemms
+ fused_gemm.cu
+ )
+
+target_include_directories(
+ 13_fused_two_gemms
+ PRIVATE
+ .
+ )
+
diff --git a/examples/13_fused_two_gemms/b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h b/examples/13_fused_two_gemms/b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h
new file mode 100644
index 00000000..10a0d4bf
--- /dev/null
+++ b/examples/13_fused_two_gemms/b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h
@@ -0,0 +1,190 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+#pragma once
+
+#include
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/device/gemm.h"
+
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/tensor_view_io.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/gemm.h"
+
+#include "device/b2b_gemm.h"
+#include "b2b_gemm_run.h"
+
+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
+
+////////////////////////////////////////////////////////////////////////////////
+
+void run_nonfused_gemm_f16() {
+
+ using ElementOutput = cutlass::half_t;
+ using ElementAccumulator = cutlass::half_t;
+ using ElementCompute = cutlass::half_t;
+
+ cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
+ cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
+ ElementCompute alpha0 = ElementCompute(2);
+ ElementCompute beta0 = ElementCompute(0);
+ ElementCompute alpha1 = ElementCompute(2);
+ ElementCompute beta1 = ElementCompute(1);
+
+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
+ using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
+ using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>;
+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
+
+ using Gemm0 = cutlass::gemm::device::Gemm<
+ cutlass::half_t,
+ cutlass::layout::RowMajor,
+ cutlass::half_t,
+ cutlass::layout::ColumnMajor,
+ ElementOutput,
+ cutlass::layout::RowMajor,
+ ElementAccumulator,
+ cutlass::arch::OpClassTensorOp,
+ cutlass::arch::Sm75,
+ ThreadblockShape0,
+ WarpShape0,
+ InstructionShape,
+ cutlass::epilogue::thread::LinearCombinationRelu<
+ ElementOutput,
+ 128 / cutlass::sizeof_bits::value,
+ ElementAccumulator,
+ ElementCompute
+ >,
+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
+ 2
+ >;
+ using Gemm1 = cutlass::gemm::device::Gemm<
+ cutlass::half_t,
+ cutlass::layout::RowMajor,
+ cutlass::half_t,
+ cutlass::layout::ColumnMajor,
+ ElementOutput,
+ cutlass::layout::RowMajor,
+ ElementAccumulator,
+ cutlass::arch::OpClassTensorOp,
+ cutlass::arch::Sm75,
+ ThreadblockShape1,
+ WarpShape1,
+ InstructionShape,
+ cutlass::epilogue::thread::LinearCombinationRelu<
+ ElementOutput,
+ 128 / cutlass::sizeof_bits::value,
+ ElementAccumulator,
+ ElementCompute
+ >,
+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
+ 2
+ >;
+
+ B2bNonFusedGemmRun nonFusedGemm;
+
+ std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n";
+ bool pass = nonFusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
+ if(pass)
+ std::cout << "Pass\n";
+ else
+ std::cout << "Fail\n";
+}
+
+void run_fused_gemm_f16() {
+
+ using ElementOutput = cutlass::half_t;
+ using ElementAccumulator = cutlass::half_t;
+ using ElementCompute = cutlass::half_t;
+
+ cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
+ cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
+ ElementCompute alpha0 = ElementCompute(2);
+ ElementCompute beta0 = ElementCompute(0);
+ ElementCompute alpha1 = ElementCompute(2);
+ ElementCompute beta1 = ElementCompute(1);
+
+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
+ using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>;
+ using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 32>;
+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
+
+ using EpilogueOutputOp0 =
+ cutlass::epilogue::thread::LinearCombinationRelu<
+ ElementOutput,
+ InstructionShape::kM * InstructionShape::kN / 32,
+ ElementAccumulator,
+ ElementCompute
+ >;
+
+ using EpilogueOutputOp1 =
+ cutlass::epilogue::thread::LinearCombinationRelu<
+ ElementOutput,
+ 128 / cutlass::sizeof_bits::value,
+ ElementAccumulator,
+ ElementCompute
+ >;
+
+
+
+ using B2bGemm = cutlass::gemm::device::B2bGemm<
+ cutlass::half_t,
+ cutlass::layout::RowMajor,
+ cutlass::half_t,
+ cutlass::layout::ColumnMajor,
+ ElementOutput,
+ cutlass::layout::RowMajor,
+ ElementAccumulator,
+ cutlass::arch::OpClassTensorOp,
+ cutlass::arch::Sm75,
+ ThreadblockShape0,
+ ThreadblockShape1,
+ WarpShape0,
+ WarpShape1,
+ InstructionShape,
+ EpilogueOutputOp0,
+ EpilogueOutputOp1,
+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
+ 2
+ >;
+
+ B2bFusedGemmRun fusedGemm;
+
+ std::cout << "Running Fused back-to-back FP16 TN GEMMs...\n";
+ bool passed = fusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
+ if(passed)
+ std::cout << "Pass\n";
+ else
+ std::cout << "Fail\n";
+
+}
+////////////////////////////////////////////////////////////////////////////////
+
+#endif //#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
diff --git a/examples/13_fused_two_gemms/b2b_gemm_run.h b/examples/13_fused_two_gemms/b2b_gemm_run.h
new file mode 100644
index 00000000..053064d7
--- /dev/null
+++ b/examples/13_fused_two_gemms/b2b_gemm_run.h
@@ -0,0 +1,608 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+#pragma once
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/tensor_view_io.h"
+#include "cutlass/util/distribution.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_norm.h"
+#include "cutlass/util/reference/device/gemm.h"
+#include "cutlass/util/reference/device/tensor_relu.h"
+
+#include "helper.h"
+
+#define CHECK_GT(val1, val2) \
+ if((val1) <= (val2)) \
+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
+#define CHECK_TRUE(val) \
+ if(!(val)) \
+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
+
+////////////////////////////////////////////////////////////////////////////////
+
+template <typename Gemm0_, typename Gemm1_>
+struct B2bNonFusedGemmRun
+{
+
+ using Gemm0 = Gemm0_;
+ using Gemm1 = Gemm1_;
+ using ElementAccumulator = typename Gemm0::ElementAccumulator;
+ using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute;
+
+ /// Initialization
+ cutlass::Distribution::Kind init_A;
+ cutlass::Distribution::Kind init_B;
+ cutlass::Distribution::Kind init_C;
+ uint64_t seed;
+
+ //
+ // Methods
+ //
+
+ B2bNonFusedGemmRun(
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
+ uint64_t seed_ = 2080
+ ):
+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
+
+ /// Helper to initialize a tensor view
+  template <typename Element, typename Layout>
+  bool initialize_tensor(
+    cutlass::TensorView<Element, Layout> view,
+ cutlass::Distribution::Kind dist_kind,
+ uint64_t seed) {
+
+ if (dist_kind == cutlass::Distribution::Uniform) {
+
+ cutlass::reference::host::TensorFillRandomUniform(
+ view, seed, 2, -2, 0);
+ }
+ else if (dist_kind == cutlass::Distribution::Identity) {
+
+ cutlass::reference::host::TensorFillIdentity(view);
+ }
+ else if (dist_kind == cutlass::Distribution::Gaussian) {
+
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
+ }
+ else if (dist_kind == cutlass::Distribution::Sequential) {
+
+ cutlass::reference::host::BlockFillSequential(
+ view.data(), view.capacity());
+ }
+ else {
+ // TODO: Implement the rest
+ std::cerr << "Not implemented\n";
+ return false;
+ }
+
+ return true;
+ }
+
+
+
+
+ /// Executes one test
+ bool run(
+ cutlass::gemm::GemmCoord problem_size_0,
+ cutlass::gemm::GemmCoord problem_size_1,
+ ElementCompute alpha0 = ElementCompute(1),
+ ElementCompute beta0 = ElementCompute(0),
+ ElementCompute alpha1 = ElementCompute(1),
+ ElementCompute beta1 = ElementCompute(0),
+ bool relu = true) {
+
+ //
+ // Allocate the GEMM workspace
+ //
+
+ cutlass::HostTensor<
+ typename Gemm0::ElementA,
+ typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
+
+ cutlass::HostTensor<
+ typename Gemm0::ElementB,
+ typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
+
+ cutlass::HostTensor<
+ typename Gemm0::ElementC,
+ typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
+
+ cutlass::HostTensor<
+ typename Gemm0::ElementC,
+ typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
+
+ cutlass::HostTensor<
+ typename Gemm0::ElementC,
+ typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
+
+ cutlass::HostTensor<
+ typename Gemm1::ElementB,
+ typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
+
+ cutlass::HostTensor<
+ typename Gemm1::ElementC,
+ typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
+
+ cutlass::HostTensor<
+ typename Gemm1::ElementC,
+ typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
+
+ cutlass::HostTensor<
+ typename Gemm1::ElementC,
+ typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
+
+
+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
+
+ cutlass::reference::host::TensorFill(
+ tensor_D0.host_view());
+ cutlass::reference::host::TensorFill(
+ tensor_D1.host_view());
+ cutlass::reference::host::TensorFill(
+ reference_D0.host_view());
+ cutlass::reference::host::TensorFill(
+ reference_D1.host_view());
+
+ tensor_A0.sync_device();
+ tensor_B0.sync_device();
+ tensor_C0.sync_device();
+ tensor_D0.sync_device();
+ tensor_B1.sync_device();
+ tensor_C1.sync_device();
+ tensor_D1.sync_device();
+ reference_D0.sync_device();
+ reference_D1.sync_device();
+
+ //
+ // Initialize the GEMM operator
+ //
+
+ typename Gemm0::Arguments arguments_0{
+ problem_size_0,
+ tensor_A0.device_ref(),
+ tensor_B0.device_ref(),
+ tensor_C0.device_ref(),
+ tensor_D0.device_ref(),
+ {alpha0, beta0}
+ };
+
+ typename Gemm1::Arguments arguments_1{
+ problem_size_1,
+ tensor_D0.device_ref(),
+ tensor_B1.device_ref(),
+ tensor_C1.device_ref(),
+ tensor_D1.device_ref(),
+ {alpha1, beta1}
+ };
+
+
+ Gemm0 gemm_op_0;
+ Gemm1 gemm_op_1;
+
+ cutlass::Status status = gemm_op_0.initialize(arguments_0);
+
+ CUTLASS_CHECK(status);
+
+ status = gemm_op_1.initialize(arguments_1);
+
+ CUTLASS_CHECK(status);
+ //
+ // Run the GEMM
+ //
+
+ cudaEvent_t start, stop1, stop2;
+ cudaEventCreate(&start);
+ cudaEventCreate(&stop1);
+ cudaEventCreate(&stop2);
+
+ cudaEventRecord(start);
+
+ for(int i = 0; i < 100; i++) {
+ status = gemm_op_0();
+
+ CUTLASS_CHECK(status);
+ }
+ cudaEventRecord(stop1);
+ for(int i = 0; i < 100; i++) {
+
+ status = gemm_op_1();
+
+ CUTLASS_CHECK(status);
+ }
+
+ cudaEventRecord(stop2);
+ cudaDeviceSynchronize();
+ float gemm0Time, gemm1Time, totalTime;
+ cudaEventElapsedTime(&gemm0Time, start, stop1);
+ cudaEventElapsedTime(&gemm1Time, stop1, stop2);
+ cudaEventElapsedTime(&totalTime, start, stop2);
+ std::cout << "gemm 0 time " << gemm0Time / 100.0 << " ms\n";
+ std::cout << "gemm 1 time " << gemm1Time / 100.0 << " ms\n";
+ std::cout << "total time " << totalTime / 100.0 << " ms\n";
+
+ tensor_D0.sync_host();
+ tensor_D1.sync_host();
+
+ //
+ // Verify
+ //
+ cutlass::reference::device::Gemm<
+ typename Gemm0::ElementA, typename Gemm0::LayoutA,
+ typename Gemm0::ElementB, typename Gemm0::LayoutB,
+ typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute,
+ ElementAccumulator, typename Gemm0::Operator>
+ reference_gemm_0;
+
+ cutlass::reference::device::Gemm<
+ typename Gemm1::ElementA, typename Gemm1::LayoutA,
+ typename Gemm1::ElementB, typename Gemm1::LayoutB,
+ typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute,
+ ElementAccumulator, typename Gemm1::Operator>
+ reference_gemm_1;
+
+ reference_gemm_0(
+ problem_size_0,
+ alpha0,
+ tensor_A0.device_ref(),
+ tensor_B0.device_ref(),
+ beta0,
+ tensor_C0.device_ref(),
+ reference_D0.device_ref()
+ );
+
+ if(relu) {
+ cutlass::reference::device::TensorReLu(reference_D0.device_view());
+ }
+
+ reference_gemm_1(
+ problem_size_1,
+ alpha1,
+ reference_D0.device_ref(),
+ tensor_B1.device_ref(),
+ beta1,
+ tensor_C1.device_ref(),
+ reference_D1.device_ref()
+ );
+
+ if(relu) {
+ cutlass::reference::device::TensorReLu(reference_D1.device_view());
+ }
+
+ // Wait for kernels to finish
+ cudaDeviceSynchronize();
+ reference_D0.sync_host();
+ reference_D1.sync_host();
+
+
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
+
+ bool passed = cutlass::reference::host::TensorEquals(
+ reference_D1.host_view(),
+ tensor_D1.host_view());
+
+ CHECK_TRUE(passed);
+ if (!passed) {
+
+ std::stringstream fname;
+
+ fname << "error_B2bGemm_device_nonfused.txt";
+ std::cerr << "Dumping results in " << fname.str() << "\n";
+
+ std::ofstream file(fname.str());
+
+ file
+ << "A0 =\n" << tensor_A0.host_view()
+ << "\nB0 =\n" << tensor_B0.host_view()
+ << "\nC0 =\n" << tensor_C0.host_view()
+ << "\nD0 =\n" << tensor_D0.host_view()
+ << "\nB1 =\n" << tensor_B1.host_view()
+ << "\nC1 =\n" << tensor_C1.host_view()
+ << "\n\nReference =\n" << reference_D1.host_view()
+ << "\nComputed =\n" << tensor_D1.host_view();
+ }
+
+ return passed;
+ }
+};
+
+template <typename B2bGemm_>
+struct B2bFusedGemmRun
+{
+
+ using B2bGemm = B2bGemm_;
+ using ElementAccumulator = typename B2bGemm::ElementAccumulator;
+ using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute;
+
+ /// Initialization
+ cutlass::Distribution::Kind init_A;
+ cutlass::Distribution::Kind init_B;
+ cutlass::Distribution::Kind init_C;
+ uint64_t seed;
+
+ //
+ // Methods
+ //
+
+ B2bFusedGemmRun(
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
+ uint64_t seed_ = 2080
+ ):
+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
+
+ /// Helper to initialize a tensor view
+  template <typename Element, typename Layout>
+  bool initialize_tensor(
+    cutlass::TensorView<Element, Layout> view,
+ cutlass::Distribution::Kind dist_kind,
+ uint64_t seed) {
+
+ if (dist_kind == cutlass::Distribution::Uniform) {
+
+ cutlass::reference::host::TensorFillRandomUniform(
+ view, seed, 2, -2, 0);
+ }
+ else if (dist_kind == cutlass::Distribution::Identity) {
+
+ cutlass::reference::host::TensorFillIdentity(view);
+ }
+ else if (dist_kind == cutlass::Distribution::Gaussian) {
+
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
+ }
+ else if (dist_kind == cutlass::Distribution::Sequential) {
+
+ cutlass::reference::host::BlockFillSequential(
+ view.data(), view.capacity());
+ }
+ else {
+ // TODO: Implement the rest
+ std::cerr << "Not implemented\n";
+ return false;
+ }
+
+ return true;
+ }
+
+
+
+
+ /// Executes one test
+ bool run(
+ cutlass::gemm::GemmCoord problem_size_0,
+ cutlass::gemm::GemmCoord problem_size_1,
+ ElementCompute alpha0 = ElementCompute(1),
+ ElementCompute beta0 = ElementCompute(0),
+ ElementCompute alpha1 = ElementCompute(1),
+ ElementCompute beta1 = ElementCompute(0),
+ bool relu = true) {
+
+ //
+ // Allocate the GEMM workspace
+ //
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementA,
+ typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementB,
+ typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
+
+// cutlass::HostTensor<
+// typename B2bGemm::ElementC,
+// typename B2bGemm::LayoutC> tensor_D0(problem_size_0.mn());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementB,
+ typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
+
+
+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
+
+ cutlass::reference::host::TensorFill(
+ tensor_D1.host_view());
+ cutlass::reference::host::TensorFill(
+ reference_D0.host_view());
+ cutlass::reference::host::TensorFill(
+ reference_D1.host_view());
+
+ tensor_A0.sync_device();
+ tensor_B0.sync_device();
+ tensor_C0.sync_device();
+ tensor_B1.sync_device();
+ tensor_C1.sync_device();
+ tensor_D1.sync_device();
+ reference_D0.sync_device();
+ reference_D1.sync_device();
+
+ //
+ // Initialize the GEMM operator
+ //
+
+ typename B2bGemm::Arguments arguments{
+ problem_size_0,
+ problem_size_1,
+ tensor_A0.device_ref(),
+ tensor_B0.device_ref(),
+ tensor_C0.device_ref(),
+ tensor_B1.device_ref(),
+ tensor_C1.device_ref(),
+ tensor_D1.device_ref(),
+ {alpha0, beta0},
+ {alpha1, beta1},
+ };
+
+ B2bGemm b2b_gemm_op;
+
+ cutlass::Status status = b2b_gemm_op.initialize(arguments);
+
+ CUTLASS_CHECK(status);
+
+ //
+ // Run the GEMM
+ //
+
+ cudaEvent_t start, stop;
+ cudaEventCreate(&start);
+ cudaEventCreate(&stop);
+
+ cudaEventRecord(start);
+
+ for(int i = 0; i < 100; i++) {
+ status = b2b_gemm_op();
+
+ CUTLASS_CHECK(status);
+ }
+
+ cudaEventRecord(stop);
+ cudaDeviceSynchronize();
+ float gemmTime;
+ cudaEventElapsedTime(&gemmTime, start, stop);
+ std::cout << "time " << gemmTime / 100.0 << " ms\n";
+
+ //tensor_D0.sync_host();
+ tensor_D1.sync_host();
+
+ //
+ // Verify
+ //
+ cutlass::reference::device::Gemm<
+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
+ typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
+ ElementAccumulator, typename B2bGemm::Operator>
+ reference_gemm_0, reference_gemm_1;
+
+ reference_gemm_0(
+ problem_size_0,
+ alpha0,
+ tensor_A0.device_ref(),
+ tensor_B0.device_ref(),
+ beta0,
+ tensor_C0.device_ref(),
+ reference_D0.device_ref()
+ );
+
+ if(relu) {
+ cutlass::reference::device::TensorReLu(reference_D0.device_view());
+ }
+
+ reference_gemm_1(
+ problem_size_1,
+ alpha1,
+ reference_D0.device_ref(),
+ tensor_B1.device_ref(),
+ beta1,
+ tensor_C1.device_ref(),
+ reference_D1.device_ref()
+ );
+
+ if(relu) {
+ cutlass::reference::device::TensorReLu(reference_D1.device_view());
+ }
+
+ cudaDeviceSynchronize();
+ reference_D0.sync_host();
+ reference_D1.sync_host();
+
+
+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
+
+ bool passed = cutlass::reference::host::TensorEquals(
+ reference_D1.host_view(),
+ tensor_D1.host_view());
+
+ CHECK_TRUE(passed);
+ if (!passed) {
+
+ std::stringstream fname;
+
+ fname << "error_B2bGemm_device_fused.txt";
+ std::cerr << "Dumping results in " << fname.str() << "\n";
+
+ std::ofstream file(fname.str());
+
+ file
+ << "A0 =\n" << tensor_A0.host_view()
+ << "\nB0 =\n" << tensor_B0.host_view()
+ << "\nC0 =\n" << tensor_C0.host_view()
+// << "\nD0 =\n" << tensor_D0.host_view()
+ << "\nB1 =\n" << tensor_B1.host_view()
+ << "\nC1 =\n" << tensor_C1.host_view()
+ << "\n\nReference =\n" << reference_D1.host_view()
+ << "\nComputed =\n" << tensor_D1.host_view();
+ }
+
+ return passed;
+ }
+
+};
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/13_fused_two_gemms/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h b/examples/13_fused_two_gemms/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h
new file mode 100644
index 00000000..1c3f15c2
--- /dev/null
+++ b/examples/13_fused_two_gemms/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h
@@ -0,0 +1,190 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+#pragma once
+
+#include <iostream>
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/device/gemm.h"
+
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/tensor_view_io.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/gemm.h"
+
+#include "device/b2b_gemm.h"
+#include "b2b_interleaved_gemm_run.h"
+
+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
+
+////////////////////////////////////////////////////////////////////////////////
+
+void run_nonfused_gemm_s8() {
+
+ using ElementOutput = int8_t;
+ using ElementAccumulator = int32_t;
+ using ElementCompute = float;
+
+ cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
+ cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
+ ElementCompute alpha0 = ElementCompute(2);
+ ElementCompute beta0 = ElementCompute(0);
+ ElementCompute alpha1 = ElementCompute(2);
+ ElementCompute beta1 = ElementCompute(1);
+
+ using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>;
+ using WarpShape0 = cutlass::gemm::GemmShape<32, 32, 64>;
+ using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 64>;
+ using WarpShape1 = cutlass::gemm::GemmShape<32, 32, 64>;
+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
+
+ using Gemm0 = cutlass::gemm::device::Gemm<
+ int8_t,
+ cutlass::layout::ColumnMajorInterleaved<32>,
+ int8_t,
+ cutlass::layout::RowMajorInterleaved<32>,
+ ElementOutput,
+ cutlass::layout::ColumnMajorInterleaved<32>,
+ ElementAccumulator,
+ cutlass::arch::OpClassTensorOp,
+ cutlass::arch::Sm75,
+ ThreadblockShape0,
+ WarpShape0,
+ InstructionShape,
+ cutlass::epilogue::thread::LinearCombinationRelu<
+ ElementOutput,
+      64 / cutlass::sizeof_bits<ElementOutput>::value,
+ ElementAccumulator,
+ ElementCompute
+ >,
+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
+ 2
+ >;
+ using Gemm1 = cutlass::gemm::device::Gemm<
+ int8_t,
+ cutlass::layout::ColumnMajorInterleaved<32>,
+ int8_t,
+ cutlass::layout::RowMajorInterleaved<32>,
+ ElementOutput,
+ cutlass::layout::ColumnMajorInterleaved<32>,
+ ElementAccumulator,
+ cutlass::arch::OpClassTensorOp,
+ cutlass::arch::Sm75,
+ ThreadblockShape1,
+ WarpShape1,
+ InstructionShape,
+ cutlass::epilogue::thread::LinearCombinationRelu<
+ ElementOutput,
+      64 / cutlass::sizeof_bits<ElementOutput>::value,
+ ElementAccumulator,
+ ElementCompute
+ >,
+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
+ 2
+ >;
+
+  B2bInterleavedNonFusedGemmRun<Gemm0, Gemm1> nonFusedGemm;
+
+ std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n";
+ bool pass = nonFusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
+ if(pass)
+ std::cout << "Pass\n";
+ else
+ std::cout << "Fail\n";
+}
+
+void run_fused_gemm_s8() {
+
+ using ElementOutput = int8_t;
+ using ElementAccumulator = int32_t;
+ using ElementCompute = float;
+
+ cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576);
+ cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64);
+ ElementCompute alpha0 = ElementCompute(2);
+ ElementCompute beta0 = ElementCompute(0);
+ ElementCompute alpha1 = ElementCompute(2);
+ ElementCompute beta1 = ElementCompute(1);
+
+ using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>;
+ using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>;
+ using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>;
+ using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>;
+ using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
+
+ using EpilogueOutputOp0 =
+ cutlass::epilogue::thread::LinearCombinationRelu<
+ ElementOutput,
+ InstructionShape::kM * InstructionShape::kN / 32,
+ ElementAccumulator,
+ ElementCompute
+ >;
+
+ using EpilogueOutputOp1 =
+ cutlass::epilogue::thread::LinearCombinationRelu<
+ ElementOutput,
+      64 / cutlass::sizeof_bits<ElementOutput>::value,
+ ElementAccumulator,
+ ElementCompute
+ >;
+
+
+
+ using B2bGemm = cutlass::gemm::device::B2bGemm<
+ int8_t,
+ cutlass::layout::ColumnMajorInterleaved<32>,
+ int8_t,
+ cutlass::layout::RowMajorInterleaved<32>,
+ ElementOutput,
+ cutlass::layout::ColumnMajorInterleaved<32>,
+ ElementAccumulator,
+ cutlass::arch::OpClassTensorOp,
+ cutlass::arch::Sm75,
+ ThreadblockShape0,
+ ThreadblockShape1,
+ WarpShape0,
+ WarpShape1,
+ InstructionShape,
+ EpilogueOutputOp0,
+ EpilogueOutputOp1,
+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
+ 2
+ >;
+
+  B2bInterleavedFusedGemmRun<B2bGemm> fusedGemm;
+
+ std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs...\n";
+ bool passed = fusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1);
+ if(passed)
+ std::cout << "Pass\n";
+ else
+ std::cout << "Fail\n";
+
+}
+////////////////////////////////////////////////////////////////////////////////
+
+#endif // #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
diff --git a/examples/13_fused_two_gemms/b2b_interleaved_gemm_run.h b/examples/13_fused_two_gemms/b2b_interleaved_gemm_run.h
new file mode 100644
index 00000000..906cabb4
--- /dev/null
+++ b/examples/13_fused_two_gemms/b2b_interleaved_gemm_run.h
@@ -0,0 +1,633 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+#pragma once
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/tensor_view_io.h"
+#include "cutlass/util/distribution.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_norm.h"
+#include "cutlass/util/host_reorder.h"
+#include "cutlass/util/reference/device/gemm.h"
+#include "helper.h"
+
+#define CHECK_GT(val1, val2) \
+ if((val1) <= (val2)) \
+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";
+#define CHECK_TRUE(val) \
+ if(!(val)) \
+ std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";
+
+template <typename Gemm0_, typename Gemm1_>
+struct B2bInterleavedNonFusedGemmRun
+{
+
+ using Gemm0 = Gemm0_;
+ using Gemm1 = Gemm1_;
+ using ElementAccumulator = typename Gemm0::ElementAccumulator;
+ using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute;
+
+ /// Initialization
+ cutlass::Distribution::Kind init_A;
+ cutlass::Distribution::Kind init_B;
+ cutlass::Distribution::Kind init_C;
+ uint64_t seed;
+
+ //
+ // Methods
+ //
+
+ B2bInterleavedNonFusedGemmRun(
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
+ uint64_t seed_ = 2080
+ ):
+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
+
+ /// Helper to initialize a tensor view
+  template <typename Element, typename Layout>
+  bool initialize_tensor(
+    cutlass::TensorView<Element, Layout> view,
+ cutlass::Distribution::Kind dist_kind,
+ uint64_t seed) {
+
+ if (dist_kind == cutlass::Distribution::Uniform) {
+
+ cutlass::reference::host::TensorFillRandomUniform(
+ view, seed, 2, -2, 0);
+ }
+ else if (dist_kind == cutlass::Distribution::Identity) {
+
+ cutlass::reference::host::TensorFillIdentity(view);
+ }
+ else if (dist_kind == cutlass::Distribution::Sequential) {
+
+ cutlass::reference::host::BlockFillSequential(
+ view.data(), view.capacity());
+ }
+ else {
+ // TODO: Implement the rest
+ std::cerr << "Not implemented\n";
+ return false;
+ }
+
+ return true;
+ }
+
+
+
+
+ /// Executes one test
+ bool run(
+ cutlass::gemm::GemmCoord problem_size_0,
+ cutlass::gemm::GemmCoord problem_size_1,
+ ElementCompute alpha0 = ElementCompute(1),
+ ElementCompute beta0 = ElementCompute(0),
+ ElementCompute alpha1 = ElementCompute(1),
+ ElementCompute beta1 = ElementCompute(0),
+ bool relu = true) {
+
+ //
+ // Allocate the GEMM workspace
+ //
+
+ cutlass::HostTensor<
+ typename Gemm0::ElementA,
+ typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk());
+
+ cutlass::HostTensor<
+ typename Gemm0::ElementB,
+ typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn());
+
+ cutlass::HostTensor<
+ typename Gemm0::ElementB,
+ typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn());
+
+ cutlass::HostTensor<
+ typename Gemm0::ElementC,
+ typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn());
+
+ cutlass::HostTensor<
+ typename Gemm0::ElementC,
+ typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn());
+
+ cutlass::HostTensor<
+ typename Gemm0::ElementC,
+ typename Gemm0::LayoutC> reference_D0(problem_size_0.mn());
+
+ cutlass::HostTensor<
+ typename Gemm1::ElementB,
+ typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn());
+
+ cutlass::HostTensor<
+ typename Gemm1::ElementB,
+ typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn());
+
+ cutlass::HostTensor<
+ typename Gemm1::ElementC,
+ typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn());
+
+ cutlass::HostTensor<
+ typename Gemm1::ElementC,
+ typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn());
+
+ cutlass::HostTensor<
+ typename Gemm1::ElementC,
+ typename Gemm1::LayoutC> reference_D1(problem_size_1.mn());
+
+
+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
+
+ //Reorder B0 and B1
+    cutlass::reorder_column<32>(
+ tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
+    cutlass::reorder_column<32>(
+ tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
+
+ cutlass::reference::host::TensorFill(
+ tensor_D0.host_view());
+ cutlass::reference::host::TensorFill(
+ tensor_D1.host_view());
+ cutlass::reference::host::TensorFill(
+ reference_D0.host_view());
+ cutlass::reference::host::TensorFill(
+ reference_D1.host_view());
+
+ tensor_A0.sync_device();
+ tensor_B0.sync_device();
+ tensor_B0_reordered.sync_device();
+ tensor_C0.sync_device();
+ tensor_D0.sync_device();
+ tensor_B1.sync_device();
+ tensor_B1_reordered.sync_device();
+ tensor_C1.sync_device();
+ tensor_D1.sync_device();
+ reference_D0.sync_device();
+ reference_D1.sync_device();
+
+ //
+ // Initialize the GEMM operator
+ //
+
+ typename Gemm0::Arguments arguments_0{
+ problem_size_0,
+ tensor_A0.device_ref(),
+ tensor_B0_reordered.device_ref(),
+ tensor_C0.device_ref(),
+ tensor_D0.device_ref(),
+ {alpha0, beta0}
+ };
+
+ typename Gemm1::Arguments arguments_1{
+ problem_size_1,
+ tensor_D0.device_ref(),
+ tensor_B1_reordered.device_ref(),
+ tensor_C1.device_ref(),
+ tensor_D1.device_ref(),
+ {alpha1, beta1}
+ };
+
+
+ Gemm0 gemm_op_0;
+ Gemm1 gemm_op_1;
+
+ cutlass::Status status = gemm_op_0.initialize(arguments_0);
+
+ CUTLASS_CHECK(status);
+
+ status = gemm_op_1.initialize(arguments_1);
+
+ CUTLASS_CHECK(status);
+ //
+ // Run the GEMM
+ //
+ cudaEvent_t start, stop1, stop2;
+ cudaEventCreate(&start);
+ cudaEventCreate(&stop1);
+ cudaEventCreate(&stop2);
+
+ cudaEventRecord(start);
+
+ for(int i = 0; i < 100; i++) {
+ status = gemm_op_0();
+
+ CUTLASS_CHECK(status);
+ }
+ cudaEventRecord(stop1);
+
+ for(int i = 0; i < 100; i++) {
+ status = gemm_op_1();
+
+ CUTLASS_CHECK(status);
+ }
+
+ cudaEventRecord(stop2);
+ cudaDeviceSynchronize();
+ float gemm0Time, gemm1Time, totalTime;
+ cudaEventElapsedTime(&gemm0Time, start, stop1);
+ cudaEventElapsedTime(&gemm1Time, stop1, stop2);
+ cudaEventElapsedTime(&totalTime, start, stop2);
+ std::cout << "gemm 0 time " << gemm0Time / 100.0 << " ms\n";
+ std::cout << "gemm 1 time " << gemm1Time / 100.0 << " ms\n";
+ std::cout << "total time " << totalTime / 100.0 << " ms\n";
+
+ tensor_D0.sync_host();
+ tensor_D1.sync_host();
+
+ //
+ // Verify
+ //
+ cutlass::reference::device::Gemm<
+ typename Gemm0::ElementA, typename Gemm0::LayoutA,
+ typename Gemm0::ElementB, typename Gemm0::LayoutB,
+ typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute,
+ ElementAccumulator, typename Gemm0::Operator>
+ reference_gemm_0;
+
+ cutlass::reference::device::Gemm<
+ typename Gemm1::ElementA, typename Gemm1::LayoutA,
+ typename Gemm1::ElementB, typename Gemm1::LayoutB,
+ typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute,
+ ElementAccumulator, typename Gemm1::Operator>
+ reference_gemm_1;
+
+ reference_gemm_0(
+ problem_size_0,
+ alpha0,
+ tensor_A0.device_ref(),
+ tensor_B0.device_ref(),
+ beta0,
+ tensor_C0.device_ref(),
+ reference_D0.device_ref()
+ );
+
+ if(relu) {
+ cutlass::reference::device::TensorReLu(reference_D0.device_view());
+ }
+
+ reference_gemm_1(
+ problem_size_1,
+ alpha1,
+ tensor_D0.device_ref(),
+ tensor_B1.device_ref(),
+ beta1,
+ tensor_C1.device_ref(),
+ reference_D1.device_ref()
+ );
+
+ if(relu) {
+ cutlass::reference::device::TensorReLu(reference_D1.device_view());
+ }
+
+ cudaDeviceSynchronize();
+ reference_D0.sync_host();
+ reference_D1.sync_host();
+
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
+
+ bool passed = cutlass::reference::host::TensorEquals(
+ reference_D1.host_view(),
+ tensor_D1.host_view());
+
+ CHECK_TRUE(passed);
+ if (!passed) {
+
+ std::stringstream fname;
+
+ fname << "error_B2bGemm_device_interleaved_nonfused.txt";
+ std::cerr << "Dumping results in " << fname.str() << "\n";
+
+ std::ofstream file(fname.str());
+
+ file
+ << "A0 =\n" << tensor_A0.host_view()
+ << "\nB0 =\n" << tensor_B0.host_view()
+ << "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
+ << "\nC0 =\n" << tensor_C0.host_view()
+ << "\nD0 =\n" << tensor_D0.host_view()
+ << "\nB1 =\n" << tensor_B1.host_view()
+ << "\nB1_reordered =\n" << tensor_B1_reordered.host_view()
+ << "\nC1 =\n" << tensor_C1.host_view()
+ << "\n\nReference =\n" << reference_D1.host_view()
+ << "\nComputed =\n" << tensor_D1.host_view();
+ }
+
+ return passed;
+ }
+};
+
+template
+struct B2bInterleavedFusedGemmRun
+{
+
+ using B2bGemm = B2bGemm_;
+ using ElementAccumulator = typename B2bGemm::ElementAccumulator;
+ using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute;
+
+ /// Initialization
+ cutlass::Distribution::Kind init_A;
+ cutlass::Distribution::Kind init_B;
+ cutlass::Distribution::Kind init_C;
+ uint64_t seed;
+
+ //
+ // Methods
+ //
+
+ B2bInterleavedFusedGemmRun(
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
+ uint64_t seed_ = 2080
+ ):
+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
+
+ /// Helper to initialize a tensor view
+ template
+ bool initialize_tensor(
+ cutlass::TensorView view,
+ cutlass::Distribution::Kind dist_kind,
+ uint64_t seed) {
+
+ if (dist_kind == cutlass::Distribution::Uniform) {
+
+ cutlass::reference::host::TensorFillRandomUniform(
+ view, seed, 2, -2, 0);
+ }
+ else if (dist_kind == cutlass::Distribution::Identity) {
+
+ cutlass::reference::host::TensorFillIdentity(view);
+ }
+ else if (dist_kind == cutlass::Distribution::Sequential) {
+
+ cutlass::reference::host::BlockFillSequential(
+ view.data(), view.capacity());
+ }
+ else {
+ // TODO: Implement the rest
+ std::cerr << "Not implemented\n";
+ return false;
+ }
+
+ return true;
+ }
+
+
+
+
+ /// Executes one test
+ bool run(
+ cutlass::gemm::GemmCoord problem_size_0,
+ cutlass::gemm::GemmCoord problem_size_1,
+ ElementCompute alpha0 = ElementCompute(1),
+ ElementCompute beta0 = ElementCompute(0),
+ ElementCompute alpha1 = ElementCompute(1),
+ ElementCompute beta1 = ElementCompute(0),
+ bool relu = true) {
+
+ //
+ // Allocate the GEMM workspace
+ //
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementA,
+ typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementB,
+ typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementB,
+ typename B2bGemm::LayoutB> tensor_B0_reordered(problem_size_0.kn());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn());
+
+// cutlass::HostTensor<
+// typename B2bGemm::ElementC,
+// typename B2bGemm::LayoutC> tensor_D0(problem_size_0.mn());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementB,
+ typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementB,
+ typename B2bGemm::LayoutB> tensor_B1_reordered(problem_size_1.kn());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn());
+
+ cutlass::HostTensor<
+ typename B2bGemm::ElementC,
+ typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn());
+
+
+ CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
+ CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
+ CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
+ CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
+ CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
+
+  //Reorder B0 and B1
+ cutlass::reorder_column(
+ tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0);
+ cutlass::reorder_column(
+ tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1);
+
+ cutlass::reference::host::TensorFill(
+ tensor_D1.host_view());
+ cutlass::reference::host::TensorFill(
+ reference_D0.host_view());
+ cutlass::reference::host::TensorFill(
+ reference_D1.host_view());
+
+ tensor_A0.sync_device();
+ tensor_B0.sync_device();
+ tensor_B0_reordered.sync_device();
+ tensor_C0.sync_device();
+ //tensor_D0.sync_device();
+ tensor_B1.sync_device();
+ tensor_B1_reordered.sync_device();
+ tensor_C1.sync_device();
+ tensor_D1.sync_device();
+ reference_D0.sync_device();
+ reference_D1.sync_device();
+
+ //
+ // Initialize the GEMM operator
+ //
+
+ typename B2bGemm::Arguments arguments{
+ problem_size_0,
+ problem_size_1,
+ tensor_A0.device_ref(),
+ tensor_B0_reordered.device_ref(),
+ tensor_C0.device_ref(),
+ tensor_B1_reordered.device_ref(),
+ tensor_C1.device_ref(),
+ tensor_D1.device_ref(),
+ {alpha0, beta0},
+ {alpha1, beta1},
+ 1, /*threadblock_swizzle_k_tile*/
+ };
+
+ B2bGemm b2b_gemm_op;
+
+ cutlass::Status status = b2b_gemm_op.initialize(arguments);
+
+ CUTLASS_CHECK(status);
+
+ //
+ // Run the GEMM
+ //
+
+ cudaEvent_t start, stop;
+ cudaEventCreate(&start);
+ cudaEventCreate(&stop);
+
+ cudaEventRecord(start);
+
+ for(int i = 0; i < 100; i++) {
+ status = b2b_gemm_op();
+
+ CUTLASS_CHECK(status);
+ }
+
+ cudaEventRecord(stop);
+ cudaDeviceSynchronize();
+ float gemmTime;
+ cudaEventElapsedTime(&gemmTime, start, stop);
+ std::cout << "time " << gemmTime / 100.0 << " ms\n";
+
+ //tensor_D0.sync_host();
+ tensor_D1.sync_host();
+
+ //
+ // Verify
+ //
+ cutlass::reference::device::Gemm<
+ typename B2bGemm::ElementA, typename B2bGemm::LayoutA,
+ typename B2bGemm::ElementB, typename B2bGemm::LayoutB,
+ typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute,
+ ElementAccumulator, typename B2bGemm::Operator>
+ reference_gemm_0, reference_gemm_1;
+
+ reference_gemm_0(
+ problem_size_0,
+ alpha0,
+ tensor_A0.device_ref(),
+ tensor_B0.device_ref(),
+ beta0,
+ tensor_C0.device_ref(),
+ reference_D0.device_ref()
+ );
+
+ if(relu) {
+ cutlass::reference::device::TensorReLu(reference_D0.device_view());
+ }
+
+ reference_gemm_1(
+ problem_size_1,
+ alpha1,
+ reference_D0.device_ref(),
+ tensor_B1.device_ref(),
+ beta1,
+ tensor_C1.device_ref(),
+ reference_D1.device_ref()
+ );
+
+
+ if(relu) {
+ cutlass::reference::device::TensorReLu(reference_D1.device_view());
+ }
+
+ cudaDeviceSynchronize();
+ reference_D0.sync_host();
+ reference_D1.sync_host();
+
+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
+ CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
+ CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
+
+ bool passed = cutlass::reference::host::TensorEquals(
+ reference_D1.host_view(),
+ tensor_D1.host_view());
+
+ CHECK_TRUE(passed);
+ if (!passed) {
+
+ std::stringstream fname;
+
+ fname << "error_B2bGemm_device_interleaved_fused.txt";
+ std::cerr << "Dumping results in " << fname.str() << "\n";
+
+ std::ofstream file(fname.str());
+
+ file
+ << "A0 =\n" << tensor_A0.host_view()
+ << "\nB0 =\n" << tensor_B0.host_view()
+ << "\nB0_reordered =\n" << tensor_B0_reordered.host_view()
+ << "\nC0 =\n" << tensor_C0.host_view()
+// << "\nD0 =\n" << tensor_D0.host_view()
+ << "\nB1 =\n" << tensor_B1.host_view()
+ << "\nB1_reordered =\n" << tensor_B1_reordered.host_view()
+ << "\nC1 =\n" << tensor_C1.host_view()
+ << "\n\nReference =\n" << reference_D1.host_view()
+ << "\nComputed =\n" << tensor_D1.host_view();
+ }
+
+ return passed;
+ }
+
+};
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/13_fused_two_gemms/device/b2b_gemm.h b/examples/13_fused_two_gemms/device/b2b_gemm.h
new file mode 100644
index 00000000..3f161435
--- /dev/null
+++ b/examples/13_fused_two_gemms/device/b2b_gemm.h
@@ -0,0 +1,439 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/arch/arch.h"
+#include "cutlass/device_kernel.h"
+
+#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
+
+#include "cutlass/gemm/device/default_gemm_configuration.h"
+#include "cutlass/epilogue/thread/linear_combination_relu.h"
+
+#include "kernel/b2b_gemm.h"
+#include "kernel/default_b2b_gemm.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace device {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+ /// Element type for A matrix operand
+ typename ElementA_,
+ /// Layout type for A matrix operand
+ typename LayoutA_,
+ /// Element type for B matrix operand
+ typename ElementB_,
+ /// Layout type for B matrix operand
+ typename LayoutB_,
+ /// Element type for C and D matrix operands
+ typename ElementC_,
+ /// Layout type for C and D matrix operands
+ typename LayoutC_,
+ /// Element type for internal accumulation
+ typename ElementAccumulator_ = ElementC_,
+ /// Operator class tag
+ typename OperatorClass_ = arch::OpClassSimt,
+ /// Tag indicating architecture to tune for
+ typename ArchTag_ = arch::Sm70,
+ /// Threadblock-level tile size (concept: GemmShape)
+ typename ThreadblockShape0_ = typename DefaultGemmConfiguration<
+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
+ ElementAccumulator_>::ThreadblockShape,
+ /// Threadblock-level tile size (concept: GemmShape)
+ typename ThreadblockShape1_ = typename DefaultGemmConfiguration<
+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
+ ElementAccumulator_>::ThreadblockShape,
+ /// Warp-level tile size (concept: GemmShape)
+ typename WarpShape0_ = typename DefaultGemmConfiguration<
+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
+ ElementAccumulator_>::WarpShape,
+ /// Warp-level tile size (concept: GemmShape)
+ typename WarpShape1_ = typename DefaultGemmConfiguration<
+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
+ ElementAccumulator_>::WarpShape,
+ /// Instruction-level tile size (concept: GemmShape)
+ typename InstructionShape_ = typename DefaultGemmConfiguration<
+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
+ ElementAccumulator_>::InstructionShape,
+ /// Epilogue output operator
+ typename EpilogueOutputOp0_ = typename DefaultGemmConfiguration<
+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
+ ElementAccumulator_>::EpilogueOutputOp,
+ /// Epilogue output operator
+ typename EpilogueOutputOp1_ = typename DefaultGemmConfiguration<
+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
+ ElementAccumulator_>::EpilogueOutputOp,
+ /// Threadblock-level swizzling operator
+ typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
+ /// Number of stages used in the pipelined mainloop
+ int Stages =
+ DefaultGemmConfiguration::kStages,
+ /// Access granularity of A matrix in units of elements
+ int AlignmentA =
+ DefaultGemmConfiguration::kAlignmentA,
+ /// Access granularity of B matrix in units of elements
+ int AlignmentB =
+ DefaultGemmConfiguration::kAlignmentB,
+ /// If true, kernel supports split-K with serial reduction
+ bool SplitKSerial = false,
+ /// Operation performed by GEMM
+ typename Operator_ = typename DefaultGemmConfiguration<
+ OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
+ ElementAccumulator_>::Operator,
+ /// Whether Beta is zero or not
+ bool IsBetaZero = false>
+class B2bGemm {
+ public:
+
+ using ElementA = ElementA_;
+ using LayoutA = LayoutA_;
+ using TensorRefA = TensorRef;
+ using ElementB = ElementB_;
+ using LayoutB = LayoutB_;
+ using TensorRefB = TensorRef;
+ using ElementC = ElementC_;
+ using LayoutC = LayoutC_;
+ using TensorRefC = TensorRef;
+ using TensorRefD = TensorRef;
+ using ElementAccumulator = ElementAccumulator_;
+ using OperatorClass = OperatorClass_;
+ using ArchTag = ArchTag_;
+ using ThreadblockShape0 = ThreadblockShape0_;
+ using ThreadblockShape1 = ThreadblockShape1_;
+ using WarpShape0 = WarpShape0_;
+ using WarpShape1 = WarpShape1_;
+ using InstructionShape = InstructionShape_;
+ using EpilogueOutputOp0 = EpilogueOutputOp0_;
+ using EpilogueOutputOp1 = EpilogueOutputOp1_;
+ using ThreadblockSwizzle = ThreadblockSwizzle_;
+ using Operator = Operator_;
+ static int const kStages = Stages;
+ static int const kAlignmentA = AlignmentA;
+ static int const kAlignmentB = AlignmentB;
+ static int const kAlignmentC = EpilogueOutputOp1::kCount;
+ static bool const kSplitKSerial = SplitKSerial;
+ static bool const kIsBetaZero = IsBetaZero;
+ static ComplexTransform const kTransformA = ComplexTransform::kNone;
+ static ComplexTransform const kTransformB = ComplexTransform::kNone;
+
+ /// Define the kernel
+ using B2bGemmKernel = typename kernel::DefaultB2bGemm<
+ ElementA,
+ LayoutA,
+ kAlignmentA,
+ ElementB,
+ LayoutB,
+ kAlignmentB,
+ ElementC,
+ LayoutC,
+ ElementAccumulator,
+ OperatorClass,
+ ArchTag,
+ ThreadblockShape0,
+ ThreadblockShape1,
+ WarpShape0,
+ WarpShape1,
+ InstructionShape,
+ EpilogueOutputOp0,
+ EpilogueOutputOp1,
+ ThreadblockSwizzle,
+ kStages,
+ kSplitKSerial,
+ Operator,
+ kIsBetaZero
+ >::B2bGemmKernel;
+
+ /// Argument structure
+ struct Arguments {
+
+ //
+ // Data members
+ //
+
+ GemmCoord problem_size_0;
+ GemmCoord problem_size_1;
+ TensorRef ref_A0;
+ TensorRef ref_B0;
+ TensorRef ref_C0;
+ TensorRef ref_B1;
+ TensorRef ref_C1;
+ TensorRef ref_D1;
+ typename EpilogueOutputOp0::Params epilogue0;
+ typename EpilogueOutputOp1::Params epilogue1;
+ int split_k_slices;
+
+ //
+ // Methods
+ //
+
+ /// Default ctor
+ CUTLASS_HOST_DEVICE
+ Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) {
+
+ }
+
+ /// Constructs an Arguments structure
+ CUTLASS_HOST_DEVICE
+ Arguments(
+ GemmCoord problem_size_0_,
+ GemmCoord problem_size_1_,
+ TensorRef ref_A0_,
+ TensorRef ref_B0_,
+ TensorRef ref_C0_,
+ TensorRef ref_B1_,
+ TensorRef ref_C1_,
+ TensorRef ref_D1_,
+ typename EpilogueOutputOp0::Params epilogue0_ =
+ typename EpilogueOutputOp0::Params(),
+ typename EpilogueOutputOp1::Params epilogue1_ =
+ typename EpilogueOutputOp1::Params(),
+ int split_k_slices_ = 1
+ ):
+ problem_size_0(problem_size_0_),
+ problem_size_1(problem_size_1_),
+ ref_A0(ref_A0_),
+ ref_B0(ref_B0_),
+ ref_C0(ref_C0_),
+ ref_B1(ref_B1_),
+ ref_C1(ref_C1_),
+ ref_D1(ref_D1_),
+ epilogue0(epilogue0_),
+ epilogue1(epilogue1_),
+ split_k_slices(split_k_slices_) {
+
+ }
+ };
+
+private:
+
+ /// Kernel parameters object
+ typename B2bGemmKernel::Params params_;
+
+public:
+
+ /// Constructs the GEMM.
+ B2bGemm() { }
+
+ /// Determines whether the GEMM can execute the given problem.
+ static Status can_implement(Arguments const &args) {
+
+ if (!kSplitKSerial && args.split_k_slices > 1) {
+ return Status::kErrorInvalidProblem;
+ }
+
+ Status status = B2bGemmKernel::can_implement(
+ args.problem_size_0,
+ args.problem_size_1,
+ args.ref_A0.non_const_ref(),
+ args.ref_B0.non_const_ref(),
+ args.ref_C0.non_const_ref(),
+ args.ref_B1.non_const_ref(),
+ args.ref_C1.non_const_ref(),
+ args.ref_D1
+ );
+
+ if (status != Status::kSuccess) {
+ return status;
+ }
+
+ return Status::kSuccess;
+ }
+
+ /// Gets the workspace size
+ static size_t get_workspace_size(Arguments const &args) {
+
+ size_t bytes = 0;
+
+ // Determine grid shape
+ ThreadblockSwizzle threadblock_swizzle;
+
+ cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
+ args.problem_size_0,
+ {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
+ args.split_k_slices);
+
+ if (kSplitKSerial && args.split_k_slices > 1) {
+
+
+ bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
+ }
+
+ return bytes;
+ }
+
+ /// Initializes GEMM state from arguments.
+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
+
+ // Determine grid shape
+ ThreadblockSwizzle threadblock_swizzle;
+
+ cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
+ args.problem_size_0,
+ {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
+ args.split_k_slices);
+// cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape(
+// args.problem_size_1,
+// {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK},
+// args.split_k_slices);
+
+ if (kSplitKSerial) {
+ if (args.split_k_slices > 1) {
+ if (!workspace) {
+ return Status::kErrorWorkspaceNull;
+ }
+
+ size_t bytes = get_workspace_size(args);
+
+ cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
+
+ if (result != cudaSuccess) {
+ return Status::kErrorInternal;
+ }
+ }
+ }
+ else {
+
+ if (args.split_k_slices > 1) {
+ return Status::kErrorInvalidProblem;
+ }
+ }
+
+ // Initialize the Params structure
+ params_ = typename B2bGemmKernel::Params{
+ args.problem_size_0,
+ args.problem_size_1,
+ grid_shape,
+ args.ref_A0.non_const_ref(),
+ args.ref_B0.non_const_ref(),
+ args.ref_C0.non_const_ref(),
+ args.ref_B1.non_const_ref(),
+ args.ref_C1.non_const_ref(),
+ args.ref_D1,
+ args.epilogue0,
+ args.epilogue1,
+ static_cast(workspace),
+ };
+
+ return Status::kSuccess;
+ }
+
+ /// Lightweight update given a subset of arguments
+ Status update(Arguments const &args, void *workspace = nullptr) {
+
+ if (kSplitKSerial && args.split_k_slices > 1) {
+ if (!workspace) {
+ return Status::kErrorWorkspaceNull;
+ }
+ }
+
+ params_.ref_A0.reset(args.ref_A.non_const_ref().data());
+ params_.ref_B0.reset(args.ref_B.non_const_ref().data());
+ params_.ref_C0.reset(args.ref_C.non_const_ref().data());
+ params_.ref_B1.reset(args.ref_B.non_const_ref().data());
+ params_.ref_C1.reset(args.ref_C.non_const_ref().data());
+ params_.ref_D1.reset(args.ref_D.data());
+ params_.output_op_0 = args.epilogue0;
+ params_.output_op_1 = args.epilogue1;
+ params_.semaphore = static_cast(workspace);
+
+ return Status::kSuccess;
+ }
+
+ /// Runs the kernel using initialized state.
+ Status run(cudaStream_t stream = nullptr) {
+
+ ThreadblockSwizzle threadblock_swizzle;
+
+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
+ dim3 block(B2bGemmKernel::kThreadCount, 1, 1);
+
+ cudaError_t result;
+
+ int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage));
+ if (smem_size >= (48 << 10)) {
+ result = cudaFuncSetAttribute(Kernel,
+ cudaFuncAttributeMaxDynamicSharedMemorySize,
+ smem_size);
+
+ if (result != cudaSuccess) {
+ return Status::kErrorInternal;
+ }
+
+ result = cudaFuncSetAttribute(
+ Kernel,
+ cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+
+ if (result != cudaSuccess) {
+ return Status::kErrorInternal;
+ }
+ }
+
+ cutlass::Kernel<<>>(params_);
+
+ result = cudaGetLastError();
+
+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
+ }
+
+ /// Runs the kernel using initialized state.
+ Status operator()(cudaStream_t stream = nullptr) {
+ return run(stream);
+ }
+
+ /// Runs the kernel using initialized state.
+ Status operator()(
+ Arguments const &args,
+ void *workspace = nullptr,
+ cudaStream_t stream = nullptr) {
+
+ Status status = initialize(args, workspace);
+
+ if (status == Status::kSuccess) {
+ status = run(stream);
+ }
+
+ return status;
+ }
+};
+
+} // namespace device
+} // namespace gemm
+} // namespace cutlass
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/13_fused_two_gemms/fused_gemm.cu b/examples/13_fused_two_gemms/fused_gemm.cu
new file mode 100644
index 00000000..8f5d4f2c
--- /dev/null
+++ b/examples/13_fused_two_gemms/fused_gemm.cu
@@ -0,0 +1,74 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/**
+*/
+
+#include "b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h"
+#include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h"
+
+int run() {
+
+ cudaDeviceProp props;
+
+ cudaError_t error = cudaGetDeviceProperties(&props, 0);
+ if (error != cudaSuccess) {
+ std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
+ return -1;
+ }
+
+ if (!(props.major * 10 + props.minor >= 75)) {
+ std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
+ << std::endl;
+
+ // Returning zero so this test passes on older Toolkits. Its actions are no-op.
+ return 0;
+ }
+
+#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
+ run_nonfused_gemm_f16();
+ run_fused_gemm_f16();
+ run_nonfused_gemm_s8();
+ run_fused_gemm_s8();
+#endif
+
+ return 0;
+}
+
+int main() {
+ // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
+ //
+  // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples.
+ if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
+ std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later." << std::endl;
+
+ // Returning zero so this test passes on older Toolkits. Its actions are no-op.
+ return 0;
+ }
+ else {
+ return run();
+ }
+}
+
diff --git a/examples/13_fused_two_gemms/kernel/b2b_gemm.h b/examples/13_fused_two_gemms/kernel/b2b_gemm.h
new file mode 100644
index 00000000..d106fa46
--- /dev/null
+++ b/examples/13_fused_two_gemms/kernel/b2b_gemm.h
@@ -0,0 +1,407 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/matrix_coord.h"
+#include "cutlass/semaphore.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace kernel {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+ typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate
+ typename Epilogue_, ///! Epilogue
+ typename ThreadblockSwizzle_, ///! Threadblock swizzling function
+ bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled.
+>
+struct B2bGemm {
+
+ using B2bMma = B2bMma_;
+ using Epilogue = Epilogue_;
+ using OutputOp0 = typename B2bMma::OutputOp;
+ using OutputOp1 = typename Epilogue::OutputOp;
+ using ThreadblockSwizzle = ThreadblockSwizzle_;
+ static bool const kSplitKSerial = SplitKSerial;
+
+ /// Warp count (concept: GemmShape)
+ using WarpCount0 = typename B2bMma::WarpCount0;
+ static int const kThreadCount = 32 * WarpCount0::kCount;
+
+ /// Parameters structure
+ struct Params {
+ cutlass::gemm::GemmCoord problem_size_0;
+ cutlass::gemm::GemmCoord problem_size_1;
+ cutlass::gemm::GemmCoord grid_tiled_shape;
+ typename B2bMma::IteratorA0::Params params_A0;
+ typename B2bMma::IteratorA0::TensorRef ref_A0;
+ typename B2bMma::IteratorB0::Params params_B0;
+ typename B2bMma::IteratorB0::TensorRef ref_B0;
+ typename Epilogue::OutputTileIterator::Params params_C0;
+ typename Epilogue::OutputTileIterator::TensorRef ref_C0;
+ typename B2bMma::IteratorB1::Params params_B1;
+ typename B2bMma::IteratorB1::TensorRef ref_B1;
+ typename Epilogue::OutputTileIterator::Params params_C1;
+ typename Epilogue::OutputTileIterator::TensorRef ref_C1;
+ typename Epilogue::OutputTileIterator::Params params_D1;
+ typename Epilogue::OutputTileIterator::TensorRef ref_D1;
+ typename OutputOp0::Params output_op_0;
+ typename OutputOp1::Params output_op_1;
+ int *semaphore;
+ int gemm_k_iterations_0;
+ int gemm_k_size_0;
+ int gemm_k_iterations_1;
+ int gemm_k_size_1;
+
+ //
+ // Methods
+ //
+
+ CUTLASS_HOST_DEVICE
+ Params(): semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0),
+ gemm_k_iterations_1(0), gemm_k_size_1(0) { }
+
+ CUTLASS_HOST_DEVICE
+ Params(
+ cutlass::gemm::GemmCoord const & problem_size_0,
+ cutlass::gemm::GemmCoord const & problem_size_1,
+ cutlass::gemm::GemmCoord const & grid_tiled_shape,
+ typename B2bMma::IteratorA0::TensorRef ref_A0,
+ typename B2bMma::IteratorB0::TensorRef ref_B0,
+ typename Epilogue::OutputTileIterator::TensorRef ref_C0,
+ typename B2bMma::IteratorB1::TensorRef ref_B1,
+ typename Epilogue::OutputTileIterator::TensorRef ref_C1,
+ typename Epilogue::OutputTileIterator::TensorRef ref_D1,
+ typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
+ typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
+ int *workspace = nullptr
+ ):
+ problem_size_0(problem_size_0),
+ problem_size_1(problem_size_1),
+ grid_tiled_shape(grid_tiled_shape),
+ params_A0(ref_A0.layout()),
+ ref_A0(ref_A0),
+ params_B0(ref_B0.layout()),
+ ref_B0(ref_B0),
+ params_C0(ref_C0.layout()),
+ ref_C0(ref_C0),
+ params_B1(ref_B1.layout()),
+ ref_B1(ref_B1),
+ params_C1(ref_C1.layout()),
+ ref_C1(ref_C1),
+ params_D1(ref_D1.layout()),
+ ref_D1(ref_D1),
+ output_op_0(output_op_0),
+ output_op_1(output_op_1) {
+
+ int total_gemm_k_iterations_0 = (problem_size_0.k() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
+ int gemm_k_iterations_0 = (total_gemm_k_iterations_0 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
+ gemm_k_size_0 = gemm_k_iterations_0 * B2bMma::Shape0::kK;
+ int total_gemm_k_iterations_1 = (problem_size_1.k() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
+ int gemm_k_iterations_1 = (total_gemm_k_iterations_1 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
+ gemm_k_size_1 = gemm_k_iterations_1 * B2bMma::Shape1::kK;
+
+ semaphore = workspace;
+ }
+ };
+
+ /// Shared memory storage structure
+ union SharedStorage {
+ typename B2bMma::B2bMmaSharedStorage main_loop;
+ typename Epilogue::SharedStorage epilogue;
+ };
+
+ //
+ // Methods
+ //
+
+ CUTLASS_HOST_DEVICE
+ B2bGemm() { }
+
+ /// Determines whether kernel satisfies alignment
+ static Status can_implement(
+ cutlass::gemm::GemmCoord const & problem_size_0,
+ cutlass::gemm::GemmCoord const & problem_size_1,
+ typename B2bMma::IteratorA0::TensorRef ref_A0,
+ typename B2bMma::IteratorB0::TensorRef ref_B0,
+ typename Epilogue::OutputTileIterator::TensorRef ref_C0,
+ typename B2bMma::IteratorB1::TensorRef ref_B1,
+ typename Epilogue::OutputTileIterator::TensorRef ref_C1,
+ typename Epilogue::OutputTileIterator::TensorRef ref_D1) {
+
+ static int const kAlignmentA = B2bMma::IteratorA0::AccessType::kElements;
+ static int const kAlignmentB = B2bMma::IteratorB0::AccessType::kElements;
+ static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
+
+ if (!TensorRef_aligned(ref_A0, kAlignmentA)) {
+ return Status::kErrorMisalignedOperand;
+ }
+
+ if (!TensorRef_aligned(ref_B0, kAlignmentB)) {
+ return Status::kErrorMisalignedOperand;
+ }
+
+ if (!TensorRef_aligned(ref_C0, kAlignmentC)) {
+ return Status::kErrorMisalignedOperand;
+ }
+
+ if (!TensorRef_aligned(ref_B1, kAlignmentB)) {
+ return Status::kErrorMisalignedOperand;
+ }
+
+ if (!TensorRef_aligned(ref_C1, kAlignmentC)) {
+ return Status::kErrorMisalignedOperand;
+ }
+
+ if (!TensorRef_aligned(ref_D1, kAlignmentC)) {
+ return Status::kErrorMisalignedOperand;
+ }
+
+ if ((problem_size_0.m() % kAlignmentA) || (problem_size_0.k() % kAlignmentA) ||
+ (problem_size_0.n() % kAlignmentB) || (problem_size_0.k() % kAlignmentB) ||
+ (problem_size_0.m() % kAlignmentC) || (problem_size_0.n() % kAlignmentC) ||
+ (problem_size_1.m() % kAlignmentA) || (problem_size_1.k() % kAlignmentA) ||
+ (problem_size_1.n() % kAlignmentB) || (problem_size_1.k() % kAlignmentB) ||
+ (problem_size_1.m() % kAlignmentC) || (problem_size_1.n() % kAlignmentC)) {
+
+ return Status::kErrorMisalignedOperand;
+ }
+
+ return Status::kSuccess;
+ }
+
+ /// Executes one GEMM
+ CUTLASS_DEVICE
+ void operator()(Params const ¶ms, SharedStorage &shared_storage) {
+
+ // Compute threadblock location
+ ThreadblockSwizzle threadblock_swizzle;
+
+ cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
+
+ // Early exit if CTA is out of range
+ if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
+ params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
+
+ return;
+ }
+
+ // Compute initial location in logical coordinates
+ cutlass::MatrixCoord tb_offset_A0{
+ threadblock_tile_offset.m() * B2bMma::Shape0::kM,
+ threadblock_tile_offset.k() * params.gemm_k_size_0,
+ };
+
+ cutlass::MatrixCoord tb_offset_B0{
+ threadblock_tile_offset.k() * params.gemm_k_size_0,
+ threadblock_tile_offset.n() * B2bMma::Shape0::kN
+ };
+
+ cutlass::MatrixCoord tb_offset_B1{
+ threadblock_tile_offset.k() * params.gemm_k_size_1,
+ threadblock_tile_offset.n() * B2bMma::Shape1::kN
+ };
+
+ // Problem size is a function of threadblock index in the K dimension
+ int problem_size_k_0 = min(
+ params.problem_size_0.k(),
+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
+
+ // Compute threadblock-scoped matrix multiply-add
+ int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK;
+
+ // Problem size is a function of threadblock index in the K dimension
+ int problem_size_k_1 = min(
+ params.problem_size_1.k(),
+ (threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
+
+ // Compute threadblock-scoped matrix multiply-add
+// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK;
+
+
+ // Compute position within threadblock
+ int thread_idx = threadIdx.x;
+
+ // Construct iterators to A and B operands
+ typename B2bMma::IteratorA0 iterator_A0(
+ params.params_A0,
+ params.ref_A0.data(),
+ {params.problem_size_0.m(), problem_size_k_0},
+ thread_idx,
+ tb_offset_A0);
+
+ typename B2bMma::IteratorB0 iterator_B0(
+ params.params_B0,
+ params.ref_B0.data(),
+ {problem_size_k_0, params.problem_size_0.n()},
+ thread_idx,
+ tb_offset_B0);
+
+ typename B2bMma::IteratorB1 iterator_B1(
+ params.params_B1,
+ params.ref_B1.data(),
+ {problem_size_k_1, params.problem_size_1.n()},
+ thread_idx,
+ tb_offset_B1);
+
+
+ // Broadcast the warp_id computed by lane 0 to ensure dependent code
+ // is compiled as warp-uniform.
+ int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0);
+ int lane_idx = threadIdx.x % 32;
+
+ //
+ // Main loop
+ //
+
+ OutputOp0 output_op_0(params.output_op_0);
+
+ // Construct thread-scoped matrix multiply
+ B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
+
+ typename B2bMma::FragmentC0 src_accum;
+ typename B2bMma::FragmentC1 accumulators;
+
+ src_accum.clear();
+ accumulators.clear();
+
+ if (!kSplitKSerial || gemm_k_iterations_0 > 0) {
+ // Compute threadblock-scoped matrix multiply-add
+ b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, iterator_B1, src_accum, output_op_0);
+ }
+
+ //
+ // Epilogue
+ //
+
+ OutputOp1 output_op_1(params.output_op_1);
+
+ //
+ // Masked tile iterators constructed from members
+ //
+
+ threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
+
+ //assume identity swizzle
+ MatrixCoord threadblock_offset(
+ threadblock_tile_offset.m() * B2bMma::Shape1::kM,
+ threadblock_tile_offset.n() * B2bMma::Shape1::kN
+ );
+
+ int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
+
+ // Construct the semaphore.
+ Semaphore semaphore(params.semaphore + block_idx, thread_idx);
+
+ // If performing a reduction via split-K, fetch the initial synchronization
+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
+
+ // Fetch the synchronization lock initially but do not block.
+ semaphore.fetch();
+
+ // Indicate which position in a serial reduction the output operator is currently updating
+ output_op_1.set_k_partition(threadblock_tile_offset.k());
+ }
+
+ // Tile iterator loading from source tensor.
+ typename Epilogue::OutputTileIterator iterator_C1(
+ params.params_C1,
+ params.ref_C1.data(),
+ params.problem_size_1.mn(),
+ thread_idx,
+ threadblock_offset
+ );
+
+ // Tile iterator writing to destination tensor.
+ typename Epilogue::OutputTileIterator iterator_D1(
+ params.params_D1,
+ params.ref_D1.data(),
+ params.problem_size_1.mn(),
+ thread_idx,
+ threadblock_offset
+ );
+
+ Epilogue epilogue(
+ shared_storage.epilogue,
+ thread_idx,
+ warp_idx,
+ lane_idx);
+
+ // Wait on the semaphore - this latency may have been covered by iterator construction
+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
+
+ // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
+ if (threadblock_tile_offset.k()) {
+ iterator_C1 = iterator_D1;
+ }
+
+ semaphore.wait(threadblock_tile_offset.k());
+
+ __threadfence();
+ }
+
+ // Execute the epilogue operator to update the destination tensor.
+ epilogue(output_op_1, iterator_D1, accumulators, iterator_C1);
+
+ //
+ // Release the semaphore
+ //
+
+ if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
+
+ int lock = 0;
+ if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
+
+ // The final threadblock resets the semaphore for subsequent grids.
+ lock = 0;
+ }
+ else {
+ // Otherwise, the semaphore is incremented
+ lock = threadblock_tile_offset.k() + 1;
+ }
+
+ __threadfence();
+ semaphore.release(lock);
+ }
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace kernel
+} // namespace gemm
+} // namespace cutlass
+
diff --git a/examples/13_fused_two_gemms/kernel/default_b2b_gemm.h b/examples/13_fused_two_gemms/kernel/default_b2b_gemm.h
new file mode 100644
index 00000000..45b2d545
--- /dev/null
+++ b/examples/13_fused_two_gemms/kernel/default_b2b_gemm.h
@@ -0,0 +1,296 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ *modification, are permitted provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ *notice, this list of conditions and the following disclaimer in the
+ *documentation and/or other materials provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its
+ *contributors may be used to endorse or promote products derived from this
+ *software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
+ *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+ *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
+ *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief
+ Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
+ the appropriate threadblock-scoped epilogue.
+
+ Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
+ accommodated by exchanging A and B operands and assuming transposed layouts. Partial
+ specializations here choose 'device::GemmTransposed' to implement this functionality.
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+
+#include "cutlass/layout/matrix.h"
+#include "cutlass/numeric_types.h"
+
+#include "cutlass/epilogue/threadblock/epilogue.h"
+#include "cutlass/epilogue/thread/linear_combination.h"
+
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/kernel/gemm_pipelined.h"
+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
+#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
+#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
+#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
+#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
+#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
+
+#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
+
+#include "kernel/b2b_gemm.h"
+#include "threadblock/default_b2b_mma.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace kernel {
+
+////////////////////////////////////////////////////////////////////////////////
+
+template <
+ /// Element type for A matrix operand
+ typename ElementA_,
+ /// Layout type for A matrix operand
+ typename LayoutA_,
+ /// Access granularity of A matrix in units of elements
+ int kAlignmentA,
+ /// Element type for B matrix operand
+ typename ElementB_,
+ /// Layout type for B matrix operand
+ typename LayoutB_,
+ /// Access granularity of B matrix in units of elements
+ int kAlignmentB,
+ /// Element type for C and D matrix operands
+ typename ElementC_,
+ /// Layout type for C and D matrix operands
+ typename LayoutC_,
+ /// Element type for internal accumulation
+ typename ElementAccumulator,
+ /// Operator class tag
+ typename OperatorClass,
+ /// Tag indicating architecture to tune for
+ typename ArchTag,
+ /// Threadblock-level tile size (concept: GemmShape)
+ typename ThreadblockShape0,
+ /// Threadblock-level tile size (concept: GemmShape)
+ typename ThreadblockShape1,
+ /// Warp-level tile size (concept: GemmShape)
+ typename WarpShape0,
+ /// Warp-level tile size (concept: GemmShape)
+ typename WarpShape1,
+ /// Warp-level tile size (concept: GemmShape)
+ typename InstructionShape,
+ /// Epilogue output operator
+ typename EpilogueOutputOp0,
+ /// Epilogue output operator
+ typename EpilogueOutputOp1,
+ /// Threadblock-level swizzling operator
+ typename ThreadblockSwizzle,
+ /// Number of stages used in the pipelined mainloop
+ int Stages,
+ /// If true, kernel is configured to support serial reduction in the epilogue
+ bool SplitKSerial,
+ /// Operation performed by GEMM
+ typename Operator,
+ /// Beta is zero or not
+ bool IsBetaZero = false
+>
+struct DefaultB2bGemm;
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Partial specialization for Turing Architecture
+template <
+ /// Element type for A matrix operand
+ typename ElementA,
+ /// Layout type for A matrix operand
+ typename LayoutA,
+ /// Access granularity of A matrix in units of elements
+ int kAlignmentA,
+ /// Element type for B matrix operand
+ typename ElementB,
+ /// Layout type for B matrix operand
+ typename LayoutB,
+ /// Access granularity of B matrix in units of elements
+ int kAlignmentB,
+ /// Element type for C and D matrix operands
+ typename ElementC,
+ /// Element type for internal accumulation
+ typename ElementAccumulator,
+ /// Threadblock-level tile size (concept: GemmShape)
+ typename ThreadblockShape0,
+ /// Threadblock-level tile size (concept: GemmShape)
+ typename ThreadblockShape1,
+ /// Warp-level tile size (concept: GemmShape)
+ typename WarpShape0,
+ /// Warp-level tile size (concept: GemmShape)
+ typename WarpShape1,
+ /// Warp-level tile size (concept: GemmShape)
+ typename InstructionShape,
+ /// Epilogue output operator
+ typename EpilogueOutputOp0,
+ /// Epilogue output operator
+ typename EpilogueOutputOp1,
+ /// Threadblock-level swizzling operator
+ typename ThreadblockSwizzle,
+ /// If true, kernel is configured to support serial reduction in the epilogue
+ bool SplitKSerial,
+ /// Operation performed by GEMM
+ typename Operator
+>
+struct DefaultB2bGemm<
+ ElementA, LayoutA, kAlignmentA,
+ ElementB, LayoutB, kAlignmentB,
+ ElementC, layout::RowMajor,
+ ElementAccumulator,
+ arch::OpClassTensorOp,
+ arch::Sm75,
+ ThreadblockShape0,
+ ThreadblockShape1,
+ WarpShape0,
+ WarpShape1,
+ InstructionShape,
+ EpilogueOutputOp0,
+ EpilogueOutputOp1,
+ ThreadblockSwizzle,
+ 2,
+ SplitKSerial,
+ Operator
+> {
+
+ /// Define the threadblock-scoped matrix multiply-accumulate
+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
+ ElementA,
+ LayoutA,
+ kAlignmentA,
+ ElementB,
+ LayoutB,
+ kAlignmentB,
+ ElementAccumulator,
+ layout::RowMajor,
+ arch::OpClassTensorOp,
+ arch::Sm75,
+ ThreadblockShape0,
+ ThreadblockShape1,
+ WarpShape0,
+ WarpShape1,
+ InstructionShape,
+ 2,
+ Operator,
+ EpilogueOutputOp0
+ >::ThreadblockB2bMma;
+
+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
+
+ /// Define the epilogue
+ using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
+ ThreadblockShape1,
+ typename B2bMma::Operator1,
+ kPartitionsK1,
+ EpilogueOutputOp1,
+ EpilogueOutputOp1::kCount
+ >::Epilogue;
+
+ /// Define the kernel-level GEMM operator.
+ using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+};
+
+
+/// Partial specialization for Turing IMMA Interleaved layout
+template <
+ /// Element type for A matrix operand
+ typename ElementA,
+ /// Access granularity of A matrix in units of elements
+ int kAlignmentA,
+ /// Element type for B matrix operand
+ typename ElementB,
+ /// Access granularity of B matrix in units of elements
+ int kAlignmentB,
+ /// Element type for C and D matrix operands
+ typename ElementC,
+ /// Threadblock-level tile size (concept: GemmShape)
+ typename ThreadblockShape0,
+ /// Threadblock-level tile size (concept: GemmShape)
+ typename ThreadblockShape1,
+ /// Warp-level tile size (concept: GemmShape)
+ typename WarpShape0,
+ /// Warp-level tile size (concept: GemmShape)
+ typename WarpShape1,
+ /// Warp-level tile size (concept: GemmShape)
+ typename InstructionShape,
+ /// Epilogue output operator
+ typename EpilogueOutputOp0,
+ /// Epilogue output operator
+ typename EpilogueOutputOp1,
+ /// Threadblock-level swizzling operator
+ typename ThreadblockSwizzle,
+ /// Number of Interleaved k
+ int InterleavedK,
+ /// If true, kernel is configured to support serial reduction in the
+ /// epilogue
+ bool SplitKSerial,
+ /// Operation performed by GEMM
+ typename Operator,
+ /// Is Beta zero or not
+ bool IsBetaZero>
+struct DefaultB2bGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
+ kAlignmentA, ElementB,
+ layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
+ ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
+ int32_t, arch::OpClassTensorOp, arch::Sm75,
+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
+ InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1,
+ ThreadblockSwizzle, 2, SplitKSerial, Operator, IsBetaZero> {
+ using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
+ using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
+ using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
+
+ using ElementAccumulator = int32_t;
+
+ /// Define the threadblock-scoped matrix multiply-accumulate
+ using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<
+ ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
+ arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1,
+ WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma;
+
+ static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK;
+
+ /// Define the epilogue for the 2nd Gemm
+ using Epilogue = typename cutlass::epilogue::threadblock::
+ DefaultInterleavedEpilogueTensorOp<
+ ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1,
+ 64 / sizeof_bits<ElementC>::value, InterleavedK,
+ IsBetaZero>::Epilogue;
+
+ /// Define the kernel-level GEMM operator.
+ using B2bGemmKernel = kernel::B2bGemm<B2bMma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace kernel
+} // namespace gemm
+} // namespace cutlass
diff --git a/examples/13_fused_two_gemms/threadblock/b2b_mma_base.h b/examples/13_fused_two_gemms/threadblock/b2b_mma_base.h
new file mode 100644
index 00000000..01cca8b7
--- /dev/null
+++ b/examples/13_fused_two_gemms/threadblock/b2b_mma_base.h
@@ -0,0 +1,230 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Template for a double-buffered threadblock-scoped GEMM kernel.
+*/
+
+#pragma once
+
+#include "cutlass/aligned_buffer.h"
+#include "cutlass/arch/memory.h"
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/matrix_shape.h"
+#include "cutlass/numeric_types.h"
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace threadblock {
+
+////////////////////////////////////////////////////////////////////////////////
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Structure to compute the matrix product targeting CUDA cores and SIMT math
+/// instructions.
+template <
+ /// Size of the Gemm problem - concept: gemm::GemmShape<>
+ typename Shape0_,
+ /// Size of the Gemm problem - concept: gemm::GemmShape<>
+ typename Shape1_,
+ /// Policy describing tuning details (concept: MmaPolicy)
+ typename Policy0_,
+ /// Policy describing tuning details (concept: MmaPolicy)
+ typename Policy1_,
+ /// Number of stages,
+ int Stages,
+ /// Used for partial specialization
+ typename Enable = bool>
+class B2bMmaBase {
+ public:
+ ///< Size of the Gemm problem - concept: gemm::GemmShape<>
+ using Shape0 = Shape0_;
+ using Shape1 = Shape1_;
+
+ ///< Policy describing tuning details
+ using Policy0 = Policy0_;
+ using Policy1 = Policy1_;
+
+ //
+ // Dependent types
+ //
+
+ /// Warp-level Mma
+ using Operator0 = typename Policy0::Operator;
+ using Operator1 = typename Policy1::Operator;
+
+ /// Shape describing the overall GEMM computed from shared memory
+ /// by each warp.
+ using WarpGemm0 = typename Policy0::Operator::Shape;
+ using WarpGemm1 = typename Policy1::Operator::Shape;
+
+ /// Shape describing the number of warps filling the CTA
+ using WarpCount0 = GemmShape<Shape0::kM / WarpGemm0::kM, Shape0::kN / WarpGemm0::kN, Shape0::kK / WarpGemm0::kK>;
+ using WarpCount1 = GemmShape<Shape1::kM / WarpGemm1::kM, Shape1::kN / WarpGemm1::kN, Shape1::kK / WarpGemm1::kK>;
+
+ /// Number of warp-level GEMM operations
+ static int const kWarpGemmIterations0 =
+ (WarpGemm0::kK / Operator0::Policy::MmaShape::kK);
+ static int const kWarpGemmIterations1 =
+ (WarpGemm1::kK / Operator1::Policy::MmaShape::kK);
+
+ /// Number of stages
+ static int const kStages = Stages;
+
+ //
+ // Nested structs
+ //
+
+ /// Shared storage object needed by threadblock-scoped GEMM
+ template<
+ typename Shape_,
+ typename Policy_
+ >
+ class SharedStorage {
+ public:
+ //
+ // Type definitions
+ //
+ using Shape = Shape_;
+ using Policy = Policy_;
+ using Operator = typename Policy::Operator;
+
+ /// Tensor reference to the A operand
+ using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
+
+ /// Tensor reference to the B operand
+ using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
+
+
+ /// Shape of the A matrix operand in shared memory
+ using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow, Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
+
+ /// Shape of the B matrix operand in shared memory
+ using ShapeB =
+ MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow, Shape::kN + Policy::SmemPaddingB::kColumn>;
+
+ public:
+ //
+ // Data members
+ //
+
+ /// Buffer for A operand
+ AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
+
+ /// Buffer for B operand
+ AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
+
+ public:
+
+ //
+ // Methods
+ //
+
+ /// Returns a layout object for the A matrix
+ CUTLASS_DEVICE
+ static typename Operator::LayoutA LayoutA() {
+ return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
+ }
+
+ /// Returns a layout object for the B matrix
+ CUTLASS_HOST_DEVICE
+ static typename Operator::LayoutB LayoutB() {
+ return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
+ }
+
+ /// Returns a TensorRef to the A operand
+ CUTLASS_HOST_DEVICE
+ TensorRefA operand_A_ref() {
+ return TensorRefA{operand_A.data(), LayoutA()};
+ }
+
+ /// Returns a TensorRef to the B operand
+ CUTLASS_HOST_DEVICE
+ TensorRefB operand_B_ref() {
+ return TensorRefB{operand_B.data(), LayoutB()};
+ }
+ };
+
+ using SharedStorage0 = SharedStorage<Shape0, Policy0>;
+ using SharedStorage1 = SharedStorage<Shape1, Policy1>;
+ union B2bMmaSharedStorage {
+ SharedStorage0 sharedStorage0;
+ SharedStorage1 sharedStorage1;
+ };
+
+
+ protected:
+
+ //
+ // Data members
+ //
+
+ /// Iterator to load a warp-scoped tile of A0 operand from shared memory
+ typename Operator0::IteratorA warp_tile_iterator_A0_;
+
+ /// Iterator to load a warp-scoped tile of B0 operand from shared memory
+ typename Operator0::IteratorB warp_tile_iterator_B0_;
+
+ /// Iterator to load a warp-scoped tile of B0 operand from shared memory
+ typename Operator1::IteratorB warp_tile_iterator_B1_;
+
+public:
+
+ /// Construct from tensor references
+ CUTLASS_DEVICE
+ B2bMmaBase(
+ ///< Shared storage needed for internal use by threadblock-scoped GEMM
+ B2bMmaSharedStorage &shared_storage,
+ ///< ID within the threadblock
+ int thread_idx,
+ ///< ID of warp
+ int warp_idx,
+ ///< ID of each thread within a warp
+ int lane_idx
+ ):
+ warp_tile_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), lane_idx),
+ warp_tile_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), lane_idx),
+ warp_tile_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), lane_idx) {
+
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace gemm
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/13_fused_two_gemms/threadblock/b2b_mma_pipelined.h b/examples/13_fused_two_gemms/threadblock/b2b_mma_pipelined.h
new file mode 100644
index 00000000..ca89cf0b
--- /dev/null
+++ b/examples/13_fused_two_gemms/threadblock/b2b_mma_pipelined.h
@@ -0,0 +1,509 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Template for a double-buffered threadblock-scoped Back-to-back fused GEMM kernel.
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/array.h"
+#include "cutlass/aligned_buffer.h"
+#include "cutlass/numeric_conversion.h"
+
+#include "cutlass/numeric_types.h"
+#include "cutlass/matrix_shape.h"
+
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
+
+#include "threadblock/b2b_mma_base.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace threadblock {
+
+////////////////////////////////////////////////////////////////////////////////////////////////
+template
+struct chk_val {
+ static_assert(a==0, "check value");
+};
+
+/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
+template <
+ /// Size of the Gemm problem - concept: gemm::GemmShape<>
+ typename Shape0_,
+ /// Iterates over tiles of A operand in global memory
+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
+ typename IteratorA0_,
+ /// Iterates over tiles of A operand in shared memory
+ /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+ typename SmemIteratorA0_,
+ /// Iterates over tiles of B operand in global memory
+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
+ typename IteratorB0_,
+ /// Iterates over tiles of B operand in shared memory
+ /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+ typename SmemIteratorB0_,
+ /// Size of the Gemm problem - concept: gemm::GemmShape<>
+ typename Shape1_,
+ /// Iterates over the intermediate accumulator tile
+ // (concept::MmaTensorOpFragmentIterator)
+ typename FragmentIteratorA1_,
+ /// Iterates over tiles of B operand in global memory
+ // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
+ typename IteratorB1_,
+ /// Iterates over tiles of B operand in shared memory
+ /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+ typename SmemIteratorB1_,
+ /// Data type of accumulator matrix
+ typename ElementC_,
+ /// Data type of accumulator matrix
+ typename LayoutC_,
+ /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...)
+ typename OutputOp_,
+ /// Policy describing tuning details (concept: MmaPipelinedPolicy)
+ typename Policy0_,
+ /// Policy describing tuning details (concept: MmaPipelinedPolicy)
+ typename Policy1_,
+ /// Transformation applied to A0 operand
+ typename TransformA0_ = NumericArrayConverter<
+ typename SmemIteratorA0_::Element,
+ typename IteratorA0_::Element,
+ IteratorA0_::Fragment::kElements>,
+ ///
+ /// Transformation applied to B0 operand
+ typename TransformB0_ = NumericArrayConverter<
+ typename SmemIteratorB0_::Element,
+ typename IteratorB0_::Element,
+ IteratorB0_::Fragment::kElements>,
+ ///
+ /// Transformation applied to B1 operand
+ typename TransformB1_ = NumericArrayConverter<
+ typename SmemIteratorB1_::Element,
+ typename IteratorB1_::Element,
+ IteratorB1_::Fragment::kElements>,
+ /// Used for partial specialization
+ typename Enable = bool
+>
+class B2bMmaPipelined : public B2bMmaBase {
+public:
+
+ ///< Base class
+ using Base = B2bMmaBase;
+
+ using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
+ using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory
+ using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory
+ using Policy0 = Policy0_; ///< Policy describing tuning details
+
+ using SmemIteratorA0 = SmemIteratorA0_;
+ using SmemIteratorB0 = SmemIteratorB0_;
+
+ using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
+ using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over intermediate accumulator tile
+ using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory
+ using Policy1 = Policy1_; ///< Policy describing tuning details
+
+ using SmemIteratorB1 = SmemIteratorB1_;
+
+
+ using ElementC = ElementC_; ///< Data type of accumulator matrix
+ using LayoutC = LayoutC_; ///< Layout of accumulator matrix
+
+ using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm
+
+ using TransformA0 = TransformA0_;
+ using TransformB0 = TransformB0_;
+ using TransformB1 = TransformB1_;
+
+ //
+ // Dependent types
+ //
+
+ /// Fragment of operand A loaded from global memory
+ using FragmentA0 = typename IteratorA0::Fragment;
+
+ /// Fragment of operand B loaded from global memory
+ using FragmentB0 = typename IteratorB0::Fragment;
+
+ /// Fragment of accumulator tile
+ using FragmentC0 = typename Policy0::Operator::FragmentC;
+
+ /// Warp-level Mma
+ using Operator0 = typename Policy0::Operator;
+
+ /// Fragment of operand B loaded from global memory
+ using FragmentB1 = typename IteratorB1::Fragment;
+
+ /// Fragment of accumulator tile
+ using FragmentC1 = typename Policy1::Operator::FragmentC;
+
+ /// Warp-level Mma
+ using Operator1 = typename Policy1::Operator;
+
+ /// Obtain the arch tag from the warp-level operator
+ using ArchTag = typename Policy0::Operator::ArchTag;
+
+ /// Complex transform on A0 operand
+ static ComplexTransform const kTransformA0 = Operator0::kTransformA;
+
+ /// Complex transform on B0 operand
+ static ComplexTransform const kTransformB0 = Operator0::kTransformB;
+
+ /// Complex transform on B1 operand
+ static ComplexTransform const kTransformB1 = Operator1::kTransformB;
+
+ // statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
+ static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
+
+private:
+
+ using WarpFragmentA0 = typename Operator0::FragmentA;
+ using WarpFragmentB0 = typename Operator0::FragmentB;
+ /// Warp Fragment of operand A1 loaded from accumulator tile
+ using WarpFragmentA1 = typename FragmentIteratorA1::Fragment;
+ using WarpFragmentB1 = typename Operator1::FragmentB;
+
+protected:
+
+ /// Iterator to write threadblock-scoped tile of A operand to shared memory
+ SmemIteratorA0 smem_iterator_A_;
+
+ /// Iterator to write threadblock-scoped tile of B0 operand to shared memory
+ SmemIteratorB0 smem_iterator_B0_;
+
+ /// Iterator to write threadblock-scoped tile of B1 operand to shared memory
+ SmemIteratorB1 smem_iterator_B1_;
+
+public:
+
+ /// Construct from tensor references
+ CUTLASS_DEVICE
+ B2bMmaPipelined(
+ typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
+ int thread_idx, ///< ID within the threadblock
+ int warp_idx, ///< ID of warp
+ int lane_idx ///< ID of each thread within a warp
+ ):
+ Base(shared_storage, thread_idx, warp_idx, lane_idx),
+ smem_iterator_A_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx),
+ smem_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), thread_idx),
+ smem_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), thread_idx) {
+
+
+ // Compute warp location within threadblock tile by mapping the warp_id to three coordinates:
+ // _m: the warp's position within the threadblock along the M dimension
+ // _n: the warp's position within the threadblock along the N dimension
+ // _k: the warp's position within the threadblock along the K dimension
+
+ //These should stay the same across different GEMM layers
+ int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);
+ int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);
+
+ int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM;
+ int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM;
+
+ //These may change across different GEMM layers
+ int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k;
+ int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k;
+
+ // Add per-warp offsets in units of warp-level tiles
+ this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k_0});
+ this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n});
+ this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n});
+ }
+
+ /// Perform a threadblock-scoped matrix multiply-accumulate
+ CUTLASS_DEVICE
+ void operator()(
+ int gemm_k_iterations_0, ///< number of iterations of the mainloop
+ FragmentC1 &accum, ///< destination accumulator tile
+ IteratorA0 iterator_A, ///< iterator over A operand in global memory
+ IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory
+ IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory
+ FragmentC0 const &src_accum, ///< source accumulator tile
+ OutputOp output_op_0, ///< epilogue operation after 1st Gemm
+ TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment
+ TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment
+ TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment
+
+ //
+ // Prologue
+ //
+
+ // Perform accumulation in the 'd' output operand
+ FragmentC0 accum0 = src_accum;
+
+ FragmentA0 tb_frag_A;
+ FragmentB0 tb_frag_B0;
+
+ tb_frag_A.clear();
+ tb_frag_B0.clear();
+
+ // The last kblock is loaded in the prolog
+ iterator_A.load(tb_frag_A);
+ iterator_B0.load(tb_frag_B0);
+
+ ++iterator_A;
+ ++iterator_B0;
+
+ this->smem_iterator_A_.store(tb_frag_A);
+ this->smem_iterator_B0_.store(tb_frag_B0);
+
+ ++this->smem_iterator_A_;
+ ++this->smem_iterator_B0_;
+
+ __syncthreads();
+
+ // Pair of fragments used to overlap shared memory loads and math instructions
+ WarpFragmentA0 warp_frag_A0[2];
+ WarpFragmentB0 warp_frag_B0[2];
+
+ this->warp_tile_iterator_A0_.set_kgroup_index(0);
+ this->warp_tile_iterator_B0_.set_kgroup_index(0);
+
+ this->warp_tile_iterator_A0_.load(warp_frag_A0[0]);
+ this->warp_tile_iterator_B0_.load(warp_frag_B0[0]);
+
+ ++this->warp_tile_iterator_A0_;
+ ++this->warp_tile_iterator_B0_;
+
+ Operator0 warp_mma0;
+
+ int smem_write_stage_idx = 1;
+
+ // Avoid reading out of bounds
+ if (gemm_k_iterations_0 <= 1) {
+ iterator_A.clear_mask();
+ iterator_B0.clear_mask();
+ }
+
+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
+ // shared memory loads (which have the tightest latency requirement).
+ iterator_A.load(tb_frag_A);
+
+ //
+ // Mainloop
+ //
+
+ // Note: The main loop does not support Base::WarpGemmIterations == 2.
+ CUTLASS_GEMM_LOOP
+ for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {
+
+ //
+ // Loop over GEMM K dimension
+ //
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {
+
+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
+ // as the case may be.
+
+ if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {
+
+ // Write fragments to shared memory
+ this->smem_iterator_A_.store(tb_frag_A);
+
+ this->smem_iterator_B0_.store(tb_frag_B0);
+
+ __syncthreads();
+
+ // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
+ // shared memory loads (which have the tightest latency requirement).
+ iterator_A.load(tb_frag_A);
+
+ ++this->smem_iterator_B0_;
+ ++this->smem_iterator_A_;
+
+
+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
+ if (smem_write_stage_idx == 1) {
+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
+ this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0});
+ }
+ else {
+ this->warp_tile_iterator_A0_.add_tile_offset(
+ {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0});
+ this->warp_tile_iterator_B0_.add_tile_offset(
+ {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0,
+ 0});
+ }
+
+ smem_write_stage_idx ^= 1;
+ }
+
+ this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
+ this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);
+
+ this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]);
+ this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]);
+
+ ++this->warp_tile_iterator_A0_;
+ ++this->warp_tile_iterator_B0_;
+
+ if (warp_mma_k == 0) {
+
+ iterator_B0.load(tb_frag_B0);
+
+ ++iterator_A;
+ ++iterator_B0;
+
+ // Avoid reading out of bounds if this was the last loop iteration
+ if (gemm_k_iterations_0 <= 2) {
+ iterator_A.clear_mask();
+ iterator_B0.clear_mask();
+ }
+ }
+
+ warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], warp_frag_B0[warp_mma_k % 2], accum0);
+ }
+ }
+
+ //2nd Gemm
+
+ /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile
+ FragmentIteratorA1 warp_tile_iterator_A1_(accum0);
+
+ //
+ // Prologue
+ //
+
+ FragmentB1 tb_frag_B1;
+
+ tb_frag_B1.clear();
+
+ // The last kblock is loaded in the prolog
+ iterator_B1.load(tb_frag_B1);
+
+ ++iterator_B1;
+
+ this->smem_iterator_B1_.store(tb_frag_B1);
+
+ ++this->smem_iterator_B1_;
+
+ __syncthreads();
+
+ // Pair of fragments used to overlap shared memory loads and math instructions
+ WarpFragmentA1 warp_frag_A1[2];
+ WarpFragmentB1 warp_frag_B1[2];
+
+ //warp_tile_iterator_A1_.set_kgroup_index(0);
+ this->warp_tile_iterator_B1_.set_kgroup_index(0);
+
+ warp_tile_iterator_A1_.load(warp_frag_A1[0], output_op_0);
+ this->warp_tile_iterator_B1_.load(warp_frag_B1[0]);
+
+ ++warp_tile_iterator_A1_;
+ ++this->warp_tile_iterator_B1_;
+
+ Operator1 warp_mma1;
+
+ smem_write_stage_idx = 1;
+
+ int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1;
+
+ // Avoid reading out of bounds
+ if (gemm_k_iterations_1 <= 1) {
+ iterator_B1.clear_mask();
+ }
+
+ //
+ // Mainloop
+ //
+
+ // Note: The main loop does not support Base::WarpGemmIterations == 2.
+ CUTLASS_PRAGMA_UNROLL
+ for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) {
+
+ //
+ // Loop over GEMM K dimension
+ //
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) {
+
+ // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
+ // as the case may be.
+
+ if (warp_mma_k == Base::kWarpGemmIterations1 - 1) {
+
+ // Write fragments to shared memory
+
+ this->smem_iterator_B1_.store(tb_frag_B1);
+
+ __syncthreads();
+ ++smem_iterator_B1_;
+
+ // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
+ if (smem_write_stage_idx == 1) {
+ smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
+ }
+ else {
+ this->warp_tile_iterator_B1_.add_tile_offset(
+ {-Base::kStages * Policy1::kPartitionsK *
+ Base::kWarpGemmIterations1,
+ 0});
+ }
+
+ smem_write_stage_idx ^= 1;
+ }
+
+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1);
+
+ warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], output_op_0);
+ this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]);
+
+
+ ++warp_tile_iterator_A1_;
+ ++this->warp_tile_iterator_B1_;
+
+ if (warp_mma_k == 0) {
+
+ iterator_B1.load(tb_frag_B1);
+ ++iterator_B1;
+
+
+ // Avoid reading out of bounds if this was the last loop iteration
+ if (gemm_k_iterations_1 <= 2) {
+ iterator_B1.clear_mask();
+ }
+ }
+
+ warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], warp_frag_B1[warp_mma_k % 2], accum);
+ }
+ }
+
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace gemm
+} // namespace cutlass
diff --git a/examples/13_fused_two_gemms/threadblock/default_b2b_mma.h b/examples/13_fused_two_gemms/threadblock/default_b2b_mma.h
new file mode 100644
index 00000000..cd1403c7
--- /dev/null
+++ b/examples/13_fused_two_gemms/threadblock/default_b2b_mma.h
@@ -0,0 +1,289 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/arch/arch.h"
+
+#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
+#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h"
+#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
+#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
+#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
+#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
+
+#include "threadblock/b2b_mma_pipelined.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace threadblock {
+
+////////////////////////////////////////////////////////////////////////////////
+
+template <
+ /// Element type for A matrix operand
+ typename ElementA_,
+ /// Layout type for A matrix operand
+ typename LayoutA_,
+ /// Access granularity of A matrix in units of elements
+ int kAlignmentA,
+ /// Element type for B matrix operand
+ typename ElementB_,
+ /// Layout type for B matrix operand
+ typename LayoutB_,
+ /// Access granularity of B matrix in units of elements
+ int kAlignmentB,
+ /// Element type for internal accumulation
+ typename ElementAccumulator_,
+ /// Layout type for C and D matrix operands
+ typename LayoutC_,
+ /// Operator class tag
+ typename OperatorClass_,
+ /// Tag indicating architecture to tune for
+ typename ArchTag_,
+ /// Threadblock-level tile size (concept: GemmShape)
+ typename ThreadblockShape0_,
+ /// Threadblock-level tile size (concept: GemmShape)
+ typename ThreadblockShape1_,
+ /// Warp-level tile size (concept: GemmShape)
+ typename WarpShape0_,
+ /// Warp-level tile size (concept: GemmShape)
+ typename WarpShape1_,
+ /// Instruction-level tile size (concept: GemmShape)
+ typename InstructionShape_,
+ /// Number of stages used in the pipelined mainloop
+ int Stages,
+ /// Operation performed by GEMM
+ typename Operator,
+ /// Epilogue output operator
+ typename EpilogueOutputOp,
+ /// Store the accumulators in row major or column major. Row major is used
+ /// when output layout is interleaved.
+ bool AccumulatorsInRowMajor = false>
+struct DefaultB2bMma;
+
+////////////////////////////////////////////////////////////////////////////////
+/// Specialization for row-major output
+template <
+ /// Element type for A matrix operand
+ typename ElementA,
+ /// Layout type for A matrix operand
+ typename LayoutA,
+ /// Access granularity of A matrix in units of elements
+ int kAlignmentA,
+ /// Element type for B matrix operand
+ typename ElementB,
+ /// Layout type for B matrix operand
+ typename LayoutB,
+ /// Access granularity of B matrix in units of elements
+ int kAlignmentB,
+ /// Element type for internal accumulation
+ typename ElementAccumulator,
+ /// Tag indicating architecture to tune for
+ typename OperatorClass,
+ /// Tag indicating architecture to tune for
+ typename ArchTag,
+ /// Threadblock-level tile size (concept: GemmShape)
+ typename ThreadblockShape0,
+ /// Threadblock-level tile size (concept: GemmShape)
+ typename ThreadblockShape1,
+ /// Warp-level tile size (concept: GemmShape)
+ typename WarpShape0,
+ /// Warp-level tile size (concept: GemmShape)
+ typename WarpShape1,
+ /// Instruction-level tile size (concept: GemmShape)
+ typename InstructionShape,
+ /// Operation performed by GEMM
+ typename Operator,
+ /// Epilogue output operator
+ typename EpilogueOutputOp>
+struct DefaultB2bMma {
+ // Define the MmaCore components
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
+ OperatorClass, 2, Operator>;
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
+ ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
+ OperatorClass, 2, Operator>;
+
+ // Define iterators over tiles from the A operand
+ using IteratorA0 =
+ cutlass::transform::threadblock::PredicatedTileIterator<
+ cutlass::MatrixShape,
+ ElementA, LayoutA, 1, typename MmaCore0::IteratorThreadMapA, kAlignmentA>;
+
+ // Define iterators over tiles from the B operand
+ using IteratorB0 =
+ cutlass::transform::threadblock::PredicatedTileIterator<
+ cutlass::MatrixShape,
+ ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB, kAlignmentB>;
+
+ // Use fragment iterator for A operand
+ using AccumulatorLayout = cutlass::layout::ColumnMajor;
+ using FragmentIteratorA1 =
+ cutlass::gemm::warp::MmaTensorOpFragmentIterator<
+ cutlass::MatrixShape, //warp shape
+ cutlass::MatrixShape, //accumulator shape
+ MmaCore1::Shape::kK, //kBlocksColumn
+ ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp, true>;
+
+ // Define iterators over tiles from the B operand
+ using IteratorB1 =
+ cutlass::transform::threadblock::PredicatedTileIterator<
+ cutlass::MatrixShape,
+ ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>;
+
+ // Define the threadblock-scoped pipelined matrix multiply
+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined<
+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
+ IteratorB0, typename MmaCore0::SmemIteratorB,
+ typename MmaCore1::Shape, FragmentIteratorA1,
+ IteratorB1, typename MmaCore1::SmemIteratorB,
+ ElementAccumulator, layout::RowMajor,
+ EpilogueOutputOp,
+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>;
+
+};
+////////////////////////////////////////////////////////////////////////////////
+
+/// Specialization for column-major-interleaved output
+template <
+ /// Element type for A matrix operand
+ typename ElementA,
+ /// Layout type for A matrix operand
+ typename LayoutA,
+ /// Access granularity of A matrix in units of elements
+ int kAlignmentA,
+ /// Element type for B matrix operand
+ typename ElementB,
+ /// Layout type for B matrix operand
+ typename LayoutB,
+ /// Access granularity of B matrix in units of elements
+ int kAlignmentB,
+ /// Element type for internal accumulation
+ typename ElementAccumulator,
+ /// Tag indicating architecture to tune for
+ typename OperatorClass,
+ /// Tag indicating architecture to tune for
+ typename ArchTag,
+ /// Threadblock-level tile size (concept: GemmShape)
+ typename ThreadblockShape0,
+ /// Threadblock-level tile size (concept: GemmShape)
+ typename ThreadblockShape1,
+ /// Warp-level tile size (concept: GemmShape)
+ typename WarpShape0,
+ /// Warp-level tile size (concept: GemmShape)
+ typename WarpShape1,
+ /// Instruction-level tile size (concept: GemmShape)
+ typename InstructionShape,
+ /// Operation performed by GEMM
+ typename Operator,
+ /// Epilogue output operator
+ typename EpilogueOutputOp,
+ /// Number of Interleaved K
+ int InterleavedK>
+struct DefaultB2bMma, OperatorClass, ArchTag,
+ ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1,
+ InstructionShape, 2, Operator, EpilogueOutputOp, true> {
+ // Define the MmaCore components
+ using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore<
+ ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA,
+ ElementB, LayoutB, ElementAccumulator,
+ layout::ColumnMajorInterleaved, OperatorClass, 2, Operator,
+ true>;
+ using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore<
+ ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA,
+ ElementB, LayoutB, ElementAccumulator,
+ layout::ColumnMajorInterleaved, OperatorClass, 2, Operator,
+ true>;
+
+ static_assert(kAlignmentA == 128 / sizeof_bits::value,
+ "Alignment must match thread data map's vector length");
+
+ static_assert(kAlignmentB ==128 / sizeof_bits::value,
+ "Alignment must match thread data map's vector length");
+
+ // Define iterators over tiles from the A operand
+ using IteratorA0 = cutlass::transform::threadblock::PredicatedTileIterator<
+ cutlass::MatrixShape, ElementA,
+ LayoutA, 1, typename MmaCore0::IteratorThreadMapA>;
+
+ // Define iterators over tiles from the B operand
+ using IteratorB0 = cutlass::transform::threadblock::PredicatedTileIterator<
+ cutlass::MatrixShape, ElementB,
+ LayoutB, 0, typename MmaCore0::IteratorThreadMapB>;
+
+ // Use fragment iterator for A operand
+ using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true
+ using FragmentIteratorA1 =
+ cutlass::gemm::warp::MmaTensorOpFragmentIterator<
+ cutlass::MatrixShape, //warp shape
+ cutlass::MatrixShape, //accumulator shape
+ MmaCore1::Shape::kK, //kBlocksColumn
+ ElementAccumulator, ElementA, AccumulatorLayout,
+ InstructionShape, EpilogueOutputOp, true /*only handle beta=0 for 1st Gemm epilogue*/>;
+
+ // Define iterators over tiles from the B operand
+ using IteratorB1 =
+ cutlass::transform::threadblock::PredicatedTileIterator<
+ cutlass::MatrixShape,
+ ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>;
+
+
+
+ // Define the threadblock-scoped pipelined matrix multiply
+ using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined<
+ typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA,
+ IteratorB0, typename MmaCore0::SmemIteratorB,
+ typename MmaCore1::Shape, FragmentIteratorA1,
+ IteratorB1, typename MmaCore1::SmemIteratorB,
+ ElementAccumulator, layout::ColumnMajorInterleaved,
+ EpilogueOutputOp,
+ typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace gemm
+} // namespace cutlass
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index d5c503e9..3da7ae45 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
@@ -60,6 +60,8 @@ foreach(EXAMPLE
08_turing_tensorop_gemm
10_planar_complex
11_planar_complex_array
+ 12_gemm_bias_relu
+ 13_fused_two_gemms
)
add_subdirectory(${EXAMPLE})
diff --git a/include/cutlass/aligned_buffer.h b/include/cutlass/aligned_buffer.h
index 3232ef87..8b3bb071 100644
--- a/include/cutlass/aligned_buffer.h
+++ b/include/cutlass/aligned_buffer.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/include/cutlass/arch/arch.h b/include/cutlass/arch/arch.h
index b38a347a..faf01cc6 100644
--- a/include/cutlass/arch/arch.h
+++ b/include/cutlass/arch/arch.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -52,6 +52,10 @@ struct Sm72 {
struct Sm75 {
static int const kMinComputeCapability = 75;
};
+struct Sm80 {
+ static int const kMinComputeCapability = 80;
+};
+
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace arch
diff --git a/include/cutlass/arch/cache_operation.h b/include/cutlass/arch/cache_operation.h
new file mode 100644
index 00000000..646b51de
--- /dev/null
+++ b/include/cutlass/arch/cache_operation.h
@@ -0,0 +1,60 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Directives related to cache operations
+*/
+#pragma once
+
+#include "cutlass/cutlass.h"
+
+namespace cutlass {
+namespace arch {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Controls PTX cache operations
+struct CacheOperation {
+ enum Kind {
+ /// Cache at all levels - accessed again
+ Always,
+ /// Cache at global level
+ Global,
+ /// Streaming - likely to be accessed once
+ Streaming,
+ /// Indicates the line will not be used again
+ LastUse,
+ /// Don't cache, and fetch again
+ Volatile,
+ /// Write back at all coherent levels
+ WriteBack,
+ /// Write through to system memory
+ WriteThrough
+ };
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace arch
+} // namespace cutlass
diff --git a/include/cutlass/arch/memory.h b/include/cutlass/arch/memory.h
index fc939053..48ef02cd 100644
--- a/include/cutlass/arch/memory.h
+++ b/include/cutlass/arch/memory.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -28,13 +28,271 @@
#pragma once
+#include "cutlass/cutlass.h"
+
namespace cutlass {
namespace arch {
/////////////////////////////////////////////////////////////////////////////////////////////////
+template <
+ /// Fragment type to store loaded data
+ typename AccessType,
+ /// The bytes of loading
+ int LoadBytes
+ >
+struct global_load;
/////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// Specializations
+//
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename AccessType>
+struct global_load<AccessType, 32> {
+ CUTLASS_DEVICE
+ global_load(AccessType &D, void const *ptr, bool pred_guard) {
+ uint4 *data = reinterpret_cast<uint4 *>(&D);
+
+ asm volatile(
+ "{\n"
+ " .reg .pred p;\n"
+ " setp.ne.b32 p, %9, 0;\n"
+ " mov.b32 %0, %10;\n"
+ " mov.b32 %1, %11;\n"
+ " mov.b32 %2, %12;\n"
+ " mov.b32 %3, %13;\n"
+ " mov.b32 %4, %14;\n"
+ " mov.b32 %5, %15;\n"
+ " mov.b32 %6, %16;\n"
+ " mov.b32 %7, %17;\n"
+ " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n"
+ " @p ld.global.v4.u32 {%4, %5, %6, %7}, [%18];\n"
+ "}\n"
+ : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w),
+ "=r"(data[1].x), "=r"(data[1].y), "=r"(data[1].z), "=r"(data[1].w)
+ : "l"(ptr), "r"((int)pred_guard), "r"(data[0].x), "r"(data[0].y),
+ "r"(data[0].z), "r"(data[0].w), "r"(data[1].x), "r"(data[1].y),
+ "r"(data[1].z), "r"(data[1].w), "l"(((uint8_t *)ptr) + 16));
+ }
+};
+
+
+template <typename AccessType>
+struct global_load<AccessType, 16> {
+ CUTLASS_DEVICE
+ global_load(AccessType &D, void const *ptr, bool pred_guard) {
+ uint4 &data = reinterpret_cast<uint4 &>(D);
+
+ asm volatile(
+ "{\n"
+ " .reg .pred p;\n"
+ " setp.ne.b32 p, %5, 0;\n"
+ " mov.b32 %0, %6;\n"
+ " mov.b32 %1, %7;\n"
+ " mov.b32 %2, %8;\n"
+ " mov.b32 %3, %9;\n"
+ " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
+ "}\n"
+ : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
+ : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
+ }
+};
+
+template <typename AccessType>
+struct global_load<AccessType, 8> {
+ CUTLASS_DEVICE
+ global_load(AccessType &D, void const *ptr, bool pred_guard) {
+ uint2 &data = reinterpret_cast<uint2 &>(D);
+
+ asm volatile(
+ "{\n"
+ " .reg .pred p;\n"
+ " setp.ne.b32 p, %3, 0;\n"
+ " mov.b32 %0, %4;\n"
+ " mov.b32 %1, %5;\n"
+ " @p ld.global.v2.u32 {%0, %1}, [%2];\n"
+ "}\n"
+ : "=r"(data.x), "=r"(data.y)
+ : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y));
+ }
+};
+
+template <typename AccessType>
+struct global_load<AccessType, 4> {
+ CUTLASS_DEVICE
+ global_load(AccessType &D, void const *ptr, bool pred_guard) {
+ unsigned &data = reinterpret_cast<unsigned &>(D);
+
+ asm volatile(
+ "{\n"
+ " .reg .pred p;\n"
+ " setp.ne.b32 p, %2, 0;\n"
+ " mov.b32 %0, %3;\n"
+ " @p ld.global.u32 %0, [%1];\n"
+ "}\n"
+ : "=r"(data)
+ : "l"(ptr), "r"((int)pred_guard), "r"(data));
+ }
+};
+
+template <typename AccessType>
+struct global_load<AccessType, 2> {
+ CUTLASS_DEVICE
+ global_load(AccessType &D, void const *ptr, bool pred_guard) {
+ uint16_t &data = reinterpret_cast<uint16_t &>(D);
+
+ asm volatile(
+ "{\n"
+ " .reg .pred p;\n"
+ " setp.ne.b32 p, %2, 0;\n"
+ " mov.b16 %0, %3;\n"
+ " @p ld.global.u16 %0, [%1];\n"
+ "}\n"
+ : "=h"(data)
+ : "l"(ptr), "r"((int)pred_guard), "h"(data));
+ }
+};
+
+template <typename AccessType>
+struct global_load<AccessType, 1> {
+ CUTLASS_DEVICE
+ global_load(AccessType &D, void const *ptr, bool pred_guard) {
+ if (pred_guard) D = *(reinterpret_cast<AccessType const *>(ptr));
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+ /// Fragment type to store loaded data
+ typename AccessType,
+ /// The bytes of loading
+ int LoadBytes
+ >
+struct global_store;
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// Specializations
+//
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename AccessType>
+struct global_store<AccessType, 32> {
+ CUTLASS_DEVICE
+ global_store(AccessType const &D, void *ptr, bool pred_guard) {
+ uint4 const *data = reinterpret_cast<uint4 const *>(&D);
+
+ asm volatile(
+ "{\n"
+ " .reg .pred p;\n"
+ " setp.ne.b32 p, %5, 0;\n"
+ " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
+ " @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n"
+ "}\n"
+ :
+ : "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
+ "r"(data[0].w), "r"((int)pred_guard), "l"(((uint8_t *)ptr) + 16),
+ "r"(data[1].x), "r"(data[1].y), "r"(data[1].z), "r"(data[1].w));
+ }
+};
+
+template <typename AccessType>
+struct global_store<AccessType, 16> {
+ CUTLASS_DEVICE
+ global_store(AccessType const &D, void *ptr, bool pred_guard) {
+ uint4 const &data = reinterpret_cast<uint4 const &>(D);
+ asm volatile(
+ "{\n"
+ " .reg .pred p;\n"
+ " setp.ne.b32 p, %5, 0;\n"
+ " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
+ "}\n"
+ :
+ : "l"(ptr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w), "r"((int)pred_guard));
+ }
+};
+
+template <typename AccessType>
+struct global_store<AccessType, 8> {
+ CUTLASS_DEVICE
+ global_store(AccessType const &D, void *ptr, bool pred_guard) {
+ uint2 const &data = reinterpret_cast<uint2 const &>(D);
+ asm volatile(
+ "{\n"
+ " .reg .pred p;\n"
+ " setp.ne.b32 p, %3, 0;\n"
+ " @p st.global.v2.u32 [%0], {%1, %2};\n"
+ "}\n"
+ :
+ : "l"(ptr), "r"(data.x), "r"(data.y), "r"((int)pred_guard));
+ }
+};
+
+template <typename AccessType>
+struct global_store<AccessType, 4> {
+ CUTLASS_DEVICE
+ global_store(AccessType const &D, void *ptr, bool pred_guard) {
+ uint32_t const &data = reinterpret_cast<uint32_t const &>(D);
+ asm volatile(
+ "{\n"
+ " .reg .pred p;\n"
+ " setp.ne.b32 p, %2, 0;\n"
+ " @p st.global.u32 [%0], %1;\n"
+ "}\n"
+ :
+ : "l"(ptr), "r"(data), "r"((int)pred_guard));
+ }
+};
+
+template <typename AccessType>
+struct global_store<AccessType, 2> {
+ CUTLASS_DEVICE
+ global_store(AccessType const &D, void *ptr, bool pred_guard) {
+ uint16_t const &data = reinterpret_cast<uint16_t const &>(D);
+ asm volatile(
+ "{\n"
+ " .reg .pred p;\n"
+ " setp.ne.b32 p, %2, 0;\n"
+ " @p st.global.u16 [%0], %1;\n"
+ "}\n"
+ :
+ : "l"(ptr), "h"(data), "r"((int)pred_guard));
+ }
+};
+
+template <typename AccessType>
+struct global_store<AccessType, 1> {
+ CUTLASS_DEVICE
+ global_store(AccessType const &D, void *ptr, bool pred_guard) {
+ if (pred_guard) *(reinterpret_cast<AccessType *>(ptr)) = D;
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
} // namespace arch
} // namespace cutlass
@@ -42,4 +300,6 @@ namespace arch {
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "memory_sm75.h"
+#include "memory_sm80.h"
+
/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/include/cutlass/arch/memory_sm75.h b/include/cutlass/arch/memory_sm75.h
index 195f8abf..3fd121b9 100644
--- a/include/cutlass/arch/memory_sm75.h
+++ b/include/cutlass/arch/memory_sm75.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -50,20 +50,20 @@ inline __device__ void ldsm(Array & D, void const* ptr);
//
/////////////////////////////////////////////////////////////////////////////////////////////////
-#if ! defined(CUDA_LDMATRIX_SUPPORTED)
- #define CUDA_LDMATRIX_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 2))
+#if (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || (__CUDACC_VER_MAJOR__ >= 11)
+
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
+#define CUDA_LDMATRIX_ACTIVATED 1
#endif
-#if ! defined(CUDA_LDMATRIX_ENABLED)
- #define CUDA_LDMATRIX_ENABLED CUDA_LDMATRIX_SUPPORTED
-#endif
-
-#if CUDA_LDMATRIX_ENABLED && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
- #define CUDA_LDMATRIX_ACTIVATED 1
+#define CUDA_LDMATRIX_SUPPORTED 1
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
-
+/*
+#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED) && (__CUDACC_VER_MAJOR__ > 10)
+ #define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED 1
+#endif
#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED)
#define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 1))
#endif
@@ -71,8 +71,9 @@ inline __device__ void ldsm(Array & D, void const* ptr);
#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_ENABLED)
#define CUDA_NVVM_GET_SMEM_POINTER_ENABLED CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED
#endif
+*/
-#if CUDA_NVVM_GET_SMEM_POINTER_ENABLED
+#if (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
extern "C" {
//
// This NVVM intrinsic is subject to change in future versions of CUDA.
@@ -85,19 +86,49 @@ inline __device__ void ldsm(Array & D, void const* ptr);
/////////////////////////////////////////////////////////////////////////////////////////////////
-#if CUDA_NVVM_GET_SMEM_POINTER_ENABLED
+/// CUTLASS helper to get SMEM pointer
+inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) {
+
+// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to
+// the previous internal intrinsics if they are available.
+#if (defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11)
+ //
+ // This NVVM intrinsic converts an address in shared memory to a plain
+ // unsigned integer. This is necessary to pass to shared memory instructions
+ // in inline PTX.
+ //
+ // In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2].
+ //
+ //__device__ size_t __cvta_generic_to_shared(void* ptr);
/// CUTLASS helper to get SMEM pointer
- inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) {
- return __nvvm_get_smem_pointer(const_cast<void *>(ptr));
- }
+ return static_cast(__cvta_generic_to_shared(ptr));
- /// CUTLASS helper to get SMEM pointer
- inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) {
- return __nvvm_get_smem_pointer(ptr);
- }
+#elif (defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)
+ return __nvvm_get_smem_pointer(ptr);
+
+#elif defined(__CUDA_ARCH__)
+
+ uint32_t smem_ptr;
+
+ asm(
+ "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
+ : "=r"(smem_ptr) : "l"(ptr));
+
+ return smem_ptr;
+
+#else
+
+ return 0;
#endif
+}
+
+/// CUTLASS helper to get SMEM pointer
+inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) {
+ return cutlass_get_smem_pointer(const_cast<void *>(ptr));
+}
+
/////////////////////////////////////////////////////////////////////////////////////////////////
template <>
@@ -235,5 +266,6 @@ inline __device__ void ldsm(
}
/////////////////////////////////////////////////////////////////////////////////////////////////
+
} // namespace arch
} // namespace cutlass
diff --git a/include/cutlass/arch/memory_sm80.h b/include/cutlass/arch/memory_sm80.h
new file mode 100644
index 00000000..04c56876
--- /dev/null
+++ b/include/cutlass/arch/memory_sm80.h
@@ -0,0 +1,238 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+ \brief Architecture-specific operators on memory added for SM80
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/arch/memory_sm75.h"
+#include "cutlass/arch/cache_operation.h"
+
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+ #define CUDA_CP_ASYNC_ACTIVATED 1
+#else
+ #define CUDA_CP_ASYNC_ACTIVATED 0
+#endif
+
+namespace cutlass {
+namespace arch {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Initiates an asynchronous copy from global memory to shared memory.
+///
+/// LDGSTS
+///
+template <
+ /// Size of the access in bytes
+ int SizeInBytes,
+ /// Cache operation
+ CacheOperation::Kind cache_op = CacheOperation::Always>
+struct cp_async;
+
+/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate
+/// the entire transfer, zeros are written to SMEM if the guard predicate is false.
+///
+/// LDGSTS
+///
+template <
+ /// Size of the access in bytes
+ int SizeInBytes,
+ /// Cache operation
+ CacheOperation::Kind cache_op = CacheOperation::Always>
+struct cp_async_zfill;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Partial specialization
+template <
+ /// Size of the access in bytes
+ int SizeInBytes>
+struct cp_async<SizeInBytes, CacheOperation::Always> {
+ /// Copy
+ CUTLASS_DEVICE
+ cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
+ #if CUDA_CP_ASYNC_ACTIVATED
+
+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);
+
+ asm volatile(
+ "{\n"
+ " .reg .pred p;\n"
+ " setp.ne.b32 p, %0, 0;\n"
+ " @p cp.async.ca.shared.global [%1], [%2], %3;\n"
+ "}\n" ::"r"((int)pred_guard),
+ "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes));
+
+ #else
+ using AccessType = Array<uint8_t, SizeInBytes>;
+
+ if (pred_guard) {
+ *static_cast(smem_ptr) = *static_cast(global_ptr);
+ }
+ #endif
+ }
+};
+
+/// Partial specialization
+template <
+ /// Size of the access in bytes
+ int SizeInBytes>
+struct cp_async_zfill<SizeInBytes, CacheOperation::Always> {
+ /// Copy with zero fill
+ CUTLASS_DEVICE
+ cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
+ #if CUDA_CP_ASYNC_ACTIVATED
+
+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);
+ int src_in_bytes = (pred_guard ? SizeInBytes : 0);
+
+ asm volatile(
+ "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
+ "l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes));
+
+ #else
+ using AccessType = Array<uint8_t, SizeInBytes>;
+
+ if (pred_guard) {
+ *static_cast(smem_ptr) = *static_cast(global_ptr);
+ }
+ else {
+ AccessType zeros;
+ zeros.clear();
+ *static_cast(smem_ptr) = zeros;
+ }
+ #endif
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Partial specialization
+template <
+ /// Size of the access in bytes
+ int SizeInBytes>
+struct cp_async<SizeInBytes, CacheOperation::Global> {
+ /// Copy
+ CUTLASS_DEVICE
+ cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
+ #if CUDA_CP_ASYNC_ACTIVATED
+
+ static_assert(SizeInBytes == 16,
+ "cp.async only supports CacheOperation::Global when access size is 16B.");
+
+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);
+
+ asm volatile(
+ "{\n"
+ " .reg .pred p;\n"
+ " setp.ne.b32 p, %0, 0;\n"
+ " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
+ "}\n" ::"r"((int)pred_guard),
+ "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes));
+
+ #else
+ using AccessType = Array<uint8_t, SizeInBytes>;
+
+ if (pred_guard) {
+ *static_cast(smem_ptr) = *static_cast(global_ptr);
+ }
+ #endif
+ }
+};
+
+/// Partial specialization
+template <
+ /// Size of the access in bytes
+ int SizeInBytes>
+struct cp_async_zfill<SizeInBytes, CacheOperation::Global> {
+ /// Copy with zero fill
+ CUTLASS_DEVICE
+ cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
+ #if CUDA_CP_ASYNC_ACTIVATED
+
+ static_assert(SizeInBytes == 16,
+ "cp.async only supports CacheOperation::Global when access size is 16B.");
+
+ unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);
+ int src_in_bytes = (pred_guard ? SizeInBytes : 0);
+
+ asm volatile(
+ "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
+ "l"(global_ptr), "n"(SizeInBytes), "r"(src_in_bytes));
+
+ #else
+ using AccessType = Array<uint8_t, SizeInBytes>;
+
+ if (pred_guard) {
+ *static_cast(smem_ptr) = *static_cast(global_ptr);
+ }
+ else {
+ AccessType zeros;
+ zeros.clear();
+ *static_cast(smem_ptr) = zeros;
+ }
+ #endif
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block.
+CUTLASS_DEVICE
+void cp_async_fence() {
+ #if CUDA_CP_ASYNC_ACTIVATED
+ asm volatile("cp.async.commit_group;\n" ::);
+ #endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Blocks until all but previous cp.async.commit_group operations have committed.
+template <int N>
+CUTLASS_DEVICE void cp_async_wait() {
+ #if CUDA_CP_ASYNC_ACTIVATED
+ asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
+ #endif
+}
+
+/// Blocks until all previous cp.async.commit_group operations have committed.
+template <>
+CUTLASS_DEVICE void cp_async_wait<0>() {
+ #if CUDA_CP_ASYNC_ACTIVATED
+ asm volatile("cp.async.wait_all;\n" ::);
+ #endif
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace arch
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
diff --git a/include/cutlass/arch/mma.h b/include/cutlass/arch/mma.h
index e59b710f..74c24695 100644
--- a/include/cutlass/arch/mma.h
+++ b/include/cutlass/arch/mma.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@@ -51,11 +51,26 @@ struct OpMultiplyAddSaturate;
/////////////////////////////////////////////////////////////////////////////////////////////////
+/// Tag indicating the input is converted to a narrower type (BF16)
+struct OpMultiplyAddFastBF16;
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Tag indicating the input is converted to a narrower type (F16)
+struct OpMultiplyAddFastF16;
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
/// Tag indicating the complex multiply-add operation
struct OpMultiplyAddComplex;
/////////////////////////////////////////////////////////////////////////////////////////////////
+/// Tag indicating the gaussian complex multiply-add operation
+struct OpMultiplyAddGaussianComplex;
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
/// Tag indicating the inner product is defined by (XOR, POPC)
struct OpXorPopc;
diff --git a/include/cutlass/arch/mma_sm50.h b/include/cutlass/arch/mma_sm50.h
index 8698a8b3..fce521dc 100644
--- a/include/cutlass/arch/mma_sm50.h
+++ b/include/cutlass/arch/mma_sm50.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/include/cutlass/arch/mma_sm60.h b/include/cutlass/arch/mma_sm60.h
index 6e513ced..ab0481ae 100644
--- a/include/cutlass/arch/mma_sm60.h
+++ b/include/cutlass/arch/mma_sm60.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/include/cutlass/arch/mma_sm61.h b/include/cutlass/arch/mma_sm61.h
index 68a1b145..9ec8857e 100644
--- a/include/cutlass/arch/mma_sm61.h
+++ b/include/cutlass/arch/mma_sm61.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/include/cutlass/arch/mma_sm70.h b/include/cutlass/arch/mma_sm70.h
index 57b50e00..b03ce2c1 100644
--- a/include/cutlass/arch/mma_sm70.h
+++ b/include/cutlass/arch/mma_sm70.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/include/cutlass/arch/mma_sm75.h b/include/cutlass/arch/mma_sm75.h
index fb8a3dc5..ef65f20b 100644
--- a/include/cutlass/arch/mma_sm75.h
+++ b/include/cutlass/arch/mma_sm75.h
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
diff --git a/include/cutlass/arch/mma_sm80.h b/include/cutlass/arch/mma_sm80.h
new file mode 100644
index 00000000..445ec388
--- /dev/null
+++ b/include/cutlass/arch/mma_sm80.h
@@ -0,0 +1,2091 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Matrix multiply
+*/
+
+#pragma once
+
+#if defined(__CUDACC_RTC__)
+#include <cuda/std/cassert>
+#else
+#include <assert.h>
+#endif
+
+#include "mma.h"
+#include "cutlass/layout/matrix.h"
+#include "cutlass/numeric_types.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
+
+#define CUTLASS_ARCH_MMA_SM80_SUPPORTED 1
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
+#define CUTLASS_ARCH_MMA_SM80_ENABLED
+#endif
+#endif
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace arch {
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// Matrix Multiply 1688 - Float BF16, FP32 accumulation
+//
+////////////////////////////////////////////////////////////////////////////////
+
+/// Matrix multiply-add operation - F32 = bf16 * bf16 + F32
+template <>
+struct Mma<
+ gemm::GemmShape<16, 8, 8>,
+ 32,
+ bfloat16_t,
+ layout::RowMajor,
+ bfloat16_t,
+ layout::ColumnMajor,
+ float,
+ layout::RowMajor,
+ OpMultiplyAdd> {
+
+ using Shape = gemm::GemmShape<16, 8, 8>;
+
+ using ElementA = bfloat16_t;
+ using LayoutA = layout::RowMajor;
+ using FragmentA = Array<bfloat16_t, 4>;
+
+ using ElementB = bfloat16_t;
+ using LayoutB = layout::ColumnMajor;
+ using FragmentB = Array<bfloat16_t, 2>;
+
+ using ElementC = float;
+ using LayoutC = layout::RowMajor;
+ using FragmentC = Array<float, 4>;
+
+ using Operator = OpMultiplyAdd;
+ using ArchTag = arch::Sm80;
+
+ CUTLASS_HOST_DEVICE
+ void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
+ FragmentC const &c) const {
+
+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
+
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
+ uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
+ uint32_t *D = reinterpret_cast<uint32_t *>(&d);
+
+ asm(
+ "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
+ "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
+ : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+ :
+ "r"(A[0]), "r"(A[1]),
+ "r"(B[0]),
+ "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
+ );
+
+#else
+ assert(0);
+#endif
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// Matrix Multiply 1684 - Float TF32
+//
+////////////////////////////////////////////////////////////////////////////////
+
+/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32
+template <>
+struct Mma<
+ gemm::GemmShape<16, 8, 4>,
+ 32,
+ tfloat32_t,
+ layout::RowMajor,
+ tfloat32_t,
+ layout::ColumnMajor,
+ float,
+ layout::RowMajor,
+ OpMultiplyAdd> {
+
+ using Shape = gemm::GemmShape<16, 8, 4>;
+
+ using ElementA = tfloat32_t;
+ using LayoutA = layout::RowMajor;
+ using FragmentA = Array<tfloat32_t, 2>;
+
+ using ElementB = tfloat32_t;
+ using LayoutB = layout::ColumnMajor;
+ using FragmentB = Array<tfloat32_t, 1>;
+
+ using ElementC = float;
+ using LayoutC = layout::RowMajor;
+ using FragmentC = Array<float, 4>;
+
+ using Operator = OpMultiplyAdd;
+ using ArchTag = arch::Sm80;
+
+ CUTLASS_HOST_DEVICE
+ void operator()(
+ FragmentC &d,
+ FragmentA const &a,
+ FragmentB const &b,
+ FragmentC const &c
+ ) const {
+
+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
+
+ uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
+ uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
+ float const *C = reinterpret_cast<float const *>(&c);
+ float *D = reinterpret_cast<float *>(&d);
+
+ asm volatile(
+ "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
+ : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
+ :
+ "r"(A[0]), "r"(A[1]),
+ "r"(B[0]),
+ "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
+ );
+
+#else
+ assert(0);
+#endif
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// Matrix Multiply 1688 - Float TF32
+//
+////////////////////////////////////////////////////////////////////////////////
+
+/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32
+template <>
+struct Mma<gemm::GemmShape<16, 8, 8>, 32, tfloat32_t, layout::RowMajor,
+           tfloat32_t, layout::ColumnMajor, float, layout::RowMajor,
+           OpMultiplyAdd> {
+  using Shape = gemm::GemmShape<16, 8, 8>;
+
+  using ElementA = tfloat32_t;
+  using LayoutA = layout::RowMajor;
+  // Per-thread A fragment: four TF32 values (the {%4..%7} operands below)
+  using FragmentA = Array<tfloat32_t, 4>;
+
+  using ElementB = tfloat32_t;
+  using LayoutB = layout::ColumnMajor;
+  // Per-thread B fragment: two TF32 values (the {%8,%9} operands below)
+  using FragmentB = Array<tfloat32_t, 2>;
+
+  using ElementC = float;
+  using LayoutC = layout::RowMajor;
+  // Per-thread accumulator fragment: four F32 values
+  using FragmentC = Array<float, 4>;
+
+  using Operator = OpMultiplyAdd;
+  using ArchTag = arch::Sm80;
+
+  /// Computes multiply-add via a single warp-synchronous mma.sync.m16n8k8
+  CUTLASS_HOST_DEVICE
+  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
+                  FragmentC const &c) const {
+
+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
+
+    // TF32 operands are presented to the instruction as 32-bit registers ("r")
+    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
+    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
+    float const *C = reinterpret_cast<float const *>(&c);
+    float *D = reinterpret_cast<float *>(&d);
+
+    asm volatile(
+        "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 "
+        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
+        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
+          "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
+
+#else
+    assert(0);
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// Matrix Multiply 16816
+//
+////////////////////////////////////////////////////////////////////////////////
+
+/// Matrix multiply-add operation: F16 = F16 * F16 + F16
+template <>
+struct Mma<
+  gemm::GemmShape<16, 8, 16>,
+  32,
+  half_t,
+  layout::RowMajor,
+  half_t,
+  layout::ColumnMajor,
+  half_t,
+  layout::RowMajor,
+  OpMultiplyAdd> {
+
+  using Shape = gemm::GemmShape<16, 8, 16>;
+
+  using ElementA = half_t;
+  using LayoutA = layout::RowMajor;
+  // Per-thread A fragment: eight halfs packed into four 32-bit registers
+  using FragmentA = Array<half_t, 8>;
+
+  using ElementB = half_t;
+  using LayoutB = layout::ColumnMajor;
+  // Per-thread B fragment: four halfs packed into two 32-bit registers
+  using FragmentB = Array<half_t, 4>;
+
+  using ElementC = half_t;
+  using LayoutC = layout::RowMajor;
+  // Per-thread accumulator fragment: four halfs packed into two 32-bit registers
+  using FragmentC = Array<half_t, 4>;
+
+  using Operator = OpMultiplyAdd;
+  using ArchTag = arch::Sm80;
+
+  /// Computes multiply-add via a single warp-synchronous mma.sync.m16n8k16
+  CUTLASS_HOST_DEVICE
+  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
+                  FragmentC const &c) const {
+
+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
+
+    // Packed half2 operands are presented as 32-bit registers ("r")
+    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
+    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
+    uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
+    uint32_t *D = reinterpret_cast<uint32_t *>(&d);
+
+    asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
+        : "=r"(D[0]), "=r"(D[1])
+        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
+          "r"(B[0]), "r"(B[1]),
+          "r"(C[0]), "r"(C[1])
+    );
+
+#else
+    assert(0);
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Matrix multiply-add operation: F32 = bf16 * bf16 + F32
+template <>
+struct Mma<
+  gemm::GemmShape<16, 8, 16>,
+  32,
+  bfloat16_t,
+  layout::RowMajor,
+  bfloat16_t,
+  layout::ColumnMajor,
+  float,
+  layout::RowMajor,
+  OpMultiplyAdd> {
+
+  using Shape = gemm::GemmShape<16, 8, 16>;
+
+  using ElementA = bfloat16_t;
+  using LayoutA = layout::RowMajor;
+  // Per-thread A fragment: eight bf16 values packed into four 32-bit registers
+  using FragmentA = Array<bfloat16_t, 8>;
+
+  using ElementB = bfloat16_t;
+  using LayoutB = layout::ColumnMajor;
+  // Per-thread B fragment: four bf16 values packed into two 32-bit registers
+  using FragmentB = Array<bfloat16_t, 4>;
+
+  using ElementC = float;
+  using LayoutC = layout::RowMajor;
+  // Per-thread accumulator fragment: four F32 values
+  using FragmentC = Array<float, 4>;
+
+  using Operator = OpMultiplyAdd;
+  using ArchTag = arch::Sm80;
+
+  /// Computes multiply-add via a single warp-synchronous mma.sync.m16n8k16
+  CUTLASS_HOST_DEVICE
+  void operator()(
+    FragmentC &d,
+    FragmentA const &a,
+    FragmentB const &b,
+    FragmentC const &c
+  ) const {
+
+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
+
+    // All operands (including the F32 accumulators) are bound as 32-bit
+    // registers ("r"); the bit patterns are what the instruction consumes.
+    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
+    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
+    uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
+    uint32_t *D = reinterpret_cast<uint32_t *>(&d);
+
+    asm volatile(
+        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
+          "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+    assert(0);
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Matrix multiply-add operation: F32 = F16 * F16 + F32
+template <>
+struct Mma<
+  gemm::GemmShape<16, 8, 16>,
+  32,
+  half_t,
+  layout::RowMajor,
+  half_t,
+  layout::ColumnMajor,
+  float,
+  layout::RowMajor,
+  OpMultiplyAdd> {
+
+  using Shape = gemm::GemmShape<16, 8, 16>;
+
+  using ElementA = half_t;
+  using LayoutA = layout::RowMajor;
+  // Per-thread A fragment: eight halfs packed into four 32-bit registers
+  using FragmentA = Array<half_t, 8>;
+
+  using ElementB = half_t;
+  using LayoutB = layout::ColumnMajor;
+  // Per-thread B fragment: four halfs packed into two 32-bit registers
+  using FragmentB = Array<half_t, 4>;
+
+  using ElementC = float;
+  using LayoutC = layout::RowMajor;
+  // Per-thread accumulator fragment: four F32 values
+  using FragmentC = Array<float, 4>;
+
+  using Operator = OpMultiplyAdd;
+  using ArchTag = arch::Sm80;
+
+  /// Computes multiply-add via a single warp-synchronous mma.sync.m16n8k16
+  CUTLASS_HOST_DEVICE
+  void operator()(
+    FragmentC &d,
+    FragmentA const &a,
+    FragmentB const &b,
+    FragmentC const &c
+  ) const {
+
+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
+
+    // All operands (including the F32 accumulators) are bound as 32-bit
+    // registers ("r"); the bit patterns are what the instruction consumes.
+    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
+    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
+    uint32_t const *C = reinterpret_cast<uint32_t const *>(&c);
+    uint32_t *D = reinterpret_cast<uint32_t *>(&d);
+
+    asm volatile(
+        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, "
+        "{%10,%11,%12,%13};\n"
+        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
+          "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+    assert(0);
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// Matrix Multiply 884 - F64
+//
+////////////////////////////////////////////////////////////////////////////////
+
+/// Matrix multiply-add operation: F64 = F64 * F64 + F64
+template <>
+struct Mma<
+  gemm::GemmShape<8,8,4>,
+  32,
+  double,
+  layout::RowMajor,
+  double,
+  layout::ColumnMajor,
+  double,
+  layout::RowMajor,
+  OpMultiplyAdd> {
+
+  using Shape = gemm::GemmShape<8,8,4>;
+
+  using ElementA = double;
+  using LayoutA = layout::RowMajor;
+  // Per-thread A fragment: a single F64 value
+  using FragmentA = Array<double, 1>;
+
+  using ElementB = double;
+  using LayoutB = layout::ColumnMajor;
+  // Per-thread B fragment: a single F64 value
+  using FragmentB = Array<double, 1>;
+
+  using ElementC = double;
+  using LayoutC = layout::RowMajor;
+  // Per-thread accumulator fragment: two F64 values
+  using FragmentC = Array<double, 2>;
+
+  using Operator = OpMultiplyAdd;
+
+  using ArchTag = arch::Sm80;
+
+  /// Computes multiply-add via a single warp-synchronous mma.sync.m8n8k4
+  CUTLASS_HOST_DEVICE
+  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
+                  FragmentC const &c) const {
+
+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
+
+    // F64 operands are bound as 64-bit registers ("l"); the bit patterns
+    // are what the instruction consumes.
+    uint64_t const & A = reinterpret_cast<uint64_t const &>(a);
+    uint64_t const & B = reinterpret_cast<uint64_t const &>(b);
+
+    uint64_t const *C = reinterpret_cast<uint64_t const *>(&c);
+    uint64_t *D = reinterpret_cast<uint64_t *>(&d);
+
+    asm volatile("mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
+      : "=l"(D[0]), "=l"(D[1])
+      : "l"(A), "l"(B), "l"(C[0]), "l"(C[1]));
+
+#else
+    assert(0);
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// Matrix Multiply 16816 - S8 input, S32 accumulation
+//
+////////////////////////////////////////////////////////////////////////////////
+
+/// Matrix multiply-add operation: S32 = S8 * S8 + S32
+template <>
+struct Mma<
+ gemm::GemmShape<16,8,16>,
+ 32,
+ int8_t,
+ layout::RowMajor,
+ int8_t,
+ layout::ColumnMajor,
+ int,
+ layout::RowMajor,
+ OpMultiplyAdd> {
+
+ using Shape = gemm::GemmShape<16,8,16>;
+
+ using ElementA = int8_t;
+ using LayoutA = layout::RowMajor;
+ using FragmentA = Array;
+
+ using ElementB = int8_t;
+ using LayoutB = layout::ColumnMajor;
+ using FragmentB = Array;
+
+ using ElementC = int;
+ using LayoutC = layout::RowMajor;
+ using FragmentC = Array;
+
+ using Operator = OpMultiplyAdd;
+
+ using ArchTag = arch::Sm80;
+
+ /// Computes multiply-add
+ CUTLASS_HOST_DEVICE
+ void operator()(
+ FragmentC &d,
+ FragmentA const &a,
+ FragmentB const &b,
+ FragmentC const &c
+ ) const {
+
+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
+ uint32_t const *A = reinterpret_cast(&a);
+ uint32_t const &B = reinterpret_cast(b);
+
+ int const *C = reinterpret_cast