Compare commits

...

19 Commits

Author SHA1 Message Date
1ab1027954 Updated mma_sm80.h to avoid perf penalty due to reinterpret_cast<>. (#100)
- Updated mma_sm80.h to avoid perf penalty due to reinterpret_cast<>.
- Enhancement to CUTLASS Utility Library's HostTensorPlanarComplex template to support copy-in and copy-out
- Added test_examples target to build and test all CUTLASS examples
- Minor edits to documentation to point to GTC 2020 webinar
2020-06-15 10:47:01 -07:00
86931fef85 CUTLASS 2.2 (#96)
Adds support for NVIDIA Ampere Architecture features. CUDA 11 Toolkit recommended.
2020-06-08 16:17:35 -07:00
e33d90b361 update tools/library/CMakeLists to require python 3.6 according to #70 (#82)
#70 only updates the documentation. This commit reflects this bump in python version to the CMake configuration as well.
2020-04-08 10:54:36 -07:00
96dab34ad9 CUTLASS 2.1 (#83)
CUTLASS 2.1 contributes:
- BLAS-style host-side API added to CUTLASS Library
- Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores
- Minor enhancements and bug fixes
2020-04-07 13:51:25 -07:00
7c0cd26d13 Need Python 3.6 to use enum.auto() (#70) 2019-11-22 09:39:12 -08:00
45ecbc885b Removed redundant conjugation operations from matrix_traits. (#65) 2019-11-20 11:27:13 -08:00
8aca98f9a7 Improved formatting, clarity, and content of several documents. (#64)
* Improved formatting, clarity, and content of several documents.
2019-11-20 10:42:15 -08:00
f4d9c8f755 Clang GPU compilation requires explicit CUDACC version flags (#63) 2019-11-20 09:52:11 -08:00
fb335f6a5f CUTLASS 2.0 (#62)
CUTLASS 2.0

Substantially refactored for

- Better performance, particularly for native Turing Tensor Cores
- Robust and durable templates spanning the design space
- Encapsulated functionality embodying modern C++11 programming techniques
- Optimized containers and data types for efficient, generic, portable device code

Updates to:
- Quick start guide
- Documentation
- Utilities
- CUTLASS Profiler

Native Turing Tensor Cores
- Efficient GEMM kernels targeting Turing Tensor Cores
- Mixed-precision floating point, 8-bit integer, 4-bit integer, and binarized operands

Coverage of existing CUTLASS functionality:
- GEMM kernels targeting CUDA and Tensor Cores in NVIDIA GPUs
- Volta Tensor Cores through native mma.sync and through WMMA API
- Optimizations such as parallel reductions, threadblock rasterization, and intra-threadblock reductions
- Batched GEMM operations
- Complex-valued GEMMs

Note: this commit and all that follow require a host compiler supporting C++11 or greater.
2019-11-19 16:55:34 -08:00
b5cab177a9 Performance enhancement for Volta Tensor Cores TN layout (#53)
* Fixed performance defect with indirect access to pointer array for Volta TensorCores TN arrangement.

* Updated patch version and changelog.

* Updated patch version and changelog.

* Added link to changelog in readme.

* Fixed markdown link
2019-07-10 10:54:12 -07:00
eb41735933 Merge pull request #47 from Artem-B/cutlass-1.3-clang
Make CUTLASS compileable with Clang.
2019-05-13 10:52:45 -07:00
fb8b3a98b7 Addressed code review comments. 2019-05-10 10:24:52 -07:00
d9d357877f Added missing file (#48) 2019-05-09 14:07:52 -07:00
e18292db46 Make CUTLASS compileable with Clang.
Requires a recent clang build (r359248 or newer).

Enable compilation with clang with these options:
cmake -DCUDA_COMPILER=clang -DCMAKE_CXX_COMPILER=/path/to/clang++
2019-05-02 11:00:22 -07:00
fe3438a3c1 cutlass 1.3.1 (#46)
CUTLASS 1.3.1 patch resolves failing text with NVRTC.
2019-04-19 16:54:52 -07:00
877bdcace6 Cutlass 1.3 Release (#42)
CUTLASS 1.3 Release
- Efficient GEMM kernel targeting Volta Tensor Cores via mma.sync instruction added in CUDA 10.1.
2019-03-20 10:49:17 -07:00
19a9d64e3c Removed patch version from README.
Removed patch version from README.
2018-12-19 15:20:43 -08:00
80e6f7c860 Merge pull request #38 from NVIDIA/resolve_maxwell
Resolved issue for incorrect SGEMM on Maxwell architecture.
2018-12-19 15:17:41 -08:00
822b0952cd Resolved issue for incorrect SGEMM on Maxwell architecture. 2018-12-19 15:07:16 -08:00
5542 changed files with 667195 additions and 234441 deletions

3
.gitmodules vendored
View File

@ -1,3 +0,0 @@
[submodule "tools/external/googletest"]
path = tools/external/googletest
url = https://github.com/google/googletest.git

View File

@ -1,5 +1,65 @@
# NVIDIA CUTLASS Changelog
# CUTLASS 2.x
## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08)
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
* Fast Tensor Core operations:
* Maximum performance via [`mma.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
* Tensor Float 32, BFloat16, and double-precision data types
* Mixed integer data types (int8, int4, bin1)
* Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution)
* Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745) (free registration required)
* Features:
* SDK examples showing GEMM fused with bias+relu and fused GEMM+GEMM
* Complex-valued GEMMs targeting NVIDIA Ampere Tensor Cores in double-precision and Tensor Float 32
* Gaussian complex GEMMs using 3m complex multiply algorithm
* Universal GEMM kernel supporting two batch modes and two algorithms for parallel reductions
* Policy updates:
* [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit) needed to enable NVIDIA Ampere Architecture features
* Disabled F16C by default for compatibility - enable on cmake command line with `-DCUTLASS_ENABLE_F16C=ON`
## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06)
* BLAS-style host-side API added to [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
* API to launch compiled kernel instances for GEMM and planar complex GEMM
* Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores
* Computes complex matrix products on matrices stored as disjoint real and imaginary parts
* [SDK Examples of Planar Complex GEMMs](/examples/10_planar_complex/planar_complex.cu)
* Minor enhancements and bug fixes
## [2.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.0.0) (2019-11-19)
* Substantially refactored for
* Better performance, particularly for native Turing Tensor Cores
* Robust and durable templates spanning the design space
* Encapsulated functionality embodying modern C++11 programming techniques
* Optimized containers and data types for efficient, generic, portable device code
* Updates to:
* [Quick start guide](/media/docs/quickstart.md)
* [Documentation](/README.md#documentation)
* [Utilities](/media/docs/utilities.md)
* [CUTLASS Profiler](/media/docs/profiler.md)
* Native Turing Tensor Cores
* Efficient GEMM kernels targeting Turing Tensor Cores
* Mixed-precision floating point, 8-bit integer, 4-bit integer, and binarized operands
* Coverage of existing CUTLASS functionality
* GEMM kernels targeting CUDA and Tensor Cores in NVIDIA GPUs
* Volta Tensor Cores through native mma.sync and through WMMA API
* Optimizations such as parallel reductions, threadblock rasterization, and intra-threadblock reductions
* Batched GEMM operations
* Complex-valued GEMMs
* **Note: a host compiler supporting C++11 or greater is required.**
# CUTLASS 1.x
## [1.3.2](https://github.com/NVIDIA/cutlass/releases/tag/v1.3.2) (2019-07-09)
* Performance improvement for Volta Tensor Cores TN and TT layouts.
## [1.3.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.3.1) (2019-04-09)
* Corrected NVRTC unit tests.
## [1.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.3.0) (2019-03-20)
* Efficient GEMM kernel targeting Volta Tensor Cores via `mma.sync` instruction added in CUDA 10.1.
## [1.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.2.0) (2018-10-26)
* Parallelized reductions across threadblocks ("Split-K")
* Improved IGEMM performance
@ -41,7 +101,7 @@
## Copyright
Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
```
Redistribution and use in source and binary forms, with or without modification, are permitted

464
CMakeLists.txt Normal file → Executable file
View File

@ -1,4 +1,4 @@
# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
@ -20,40 +20,102 @@
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required(VERSION 3.3.0)
cmake_minimum_required(VERSION 3.12.4 FATAL_ERROR)
set(CUTLASS_LANGUAGES CXX)
# CMake 3.9.0 has native support for CUDA without the need of the CUDA package. Use it!
if(WIN32 AND NOT ${CMAKE_VERSION} VERSION_LESS "3.9.0")
list(APPEND CUTLASS_LANGUAGES CUDA)
set(CUTLASS_NATIVE_CUDA TRUE)
macro(cutlass_add_executable)
add_executable(${ARGN})
endmacro()
if(cutlass_LOADED)
# If CUTLASS has been previously fetched and loaded, don't do it again.
return()
else()
# FindCUDA fails to detect VS 2017 due to a changed directory format of the toolkits.
# For this configuration we need CMake >= 3.9.0 to use the native CUDA support.
if (WIN32 AND MSVC_VERSION GREATER 1800)
message(FATAL_ERROR "Please upgrade CMake to version >= 3.9.0 to support Visual Studio 2017 or higher")
endif()
# Fall back to the FindCUDA version to create an executable with CUDA files
macro(cutlass_add_executable)
cuda_add_executable(${ARGN})
endmacro()
set(cutlass_LOADED ON)
set(CUTLASS_DIR ${CMAKE_CURRENT_SOURCE_DIR} CACHE PATH "CUTLASS Repository Directory")
endif()
project(CUTLASS ${CUTLASS_LANGUAGES})
message(STATUS "CMake Version: ${CMAKE_VERSION}")
project(CUTLASS VERSION 2.2.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
find_package(Doxygen QUIET)
#
# CUTLASS 2.x requires C++11
#
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
if(CUTLASS_NATIVE_CUDA)
set(CMAKE_CUDA_STANDARD 11)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
else()
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --std=c++11)
endif()
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
set(CMAKE_INSTALL_PREFIX install CACHE PATH "Default installation location." FORCE)
endif()
message(STATUS "Default Install Location: ${CMAKE_INSTALL_PREFIX}")
set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library")
if(CUTLASS_ENABLE_HEADERS_ONLY)
set(CUTLASS_ENABLE_EXAMPLES_INIT OFF)
set(CUTLASS_ENABLE_TOOLS_INIT OFF)
else()
set(CUTLASS_ENABLE_EXAMPLES_INIT ON)
set(CUTLASS_ENABLE_TOOLS_INIT ON)
endif()
set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable CUTLASS Examples")
set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools")
if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_TOOLS_INIT})
else()
set(CUTLASS_ENABLE_TESTS_INIT OFF)
endif()
set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests")
if (CUTLASS_ENABLE_TESTS)
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake)
endif()
set(CUTLASS_NVCC_ARCHS_SUPPORTED "")
if (NOT CUDA_VERSION VERSION_LESS 7.5)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 53)
endif()
if (NOT CUDA_VERSION VERSION_LESS 8.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 60 61)
endif()
if (NOT CUDA_VERSION VERSION_LESS 9.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 70)
endif()
if (NOT CUDA_VERSION VERSION_LESS 9.2)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 72)
endif()
if (NOT CUDA_VERSION VERSION_LESS 10.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 75)
endif()
if (NOT CUDA_VERSION VERSION_LESS 11.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 80)
endif()
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")
# Special policy introduced in CMake 3.13
if (POLICY CMP0076)
cmake_policy(SET CMP0076 NEW)
endif()
# check if the configuration is supported
if( NOT CMAKE_SIZEOF_VOID_P EQUAL 8 )
message(FATAL_ERROR "CUTLASS requires a 64-bit compiler!")
endif()
find_package(CUDA)
find_package(Doxygen QUIET)
include(GNUInstallDirs)
link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)
###################################################################################################
#
@ -61,42 +123,66 @@ find_package(Doxygen QUIET)
#
###################################################################################################
find_library(CUBLAS_LIBRARY cublas HINTS
${CUDA_TOOLKIT_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
message(STATUS "CUDA Compilation Architectures: ${CUTLASS_NVCC_ARCHS_ENABLED}")
# By default we want to build in Release mode to ensure that we're getting best performance
if (NOT (CMAKE_BUILD_TYPE OR CONFIGURATION_TYPES))
# By default we want to build in Release mode to ensure that we're getting best performance.
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose build level" FORCE)
# We do support Debug or Release builds
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "RelWithDebInfo" "Release")
endif()
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CUTLASS_LIBRARY_DEBUG_POSTFIX ".debug" CACHE STRING "Default postfix value for debug libraries")
if(WIN32)
# On Windows we link against the shared (DLL) runtime. Change gtest settings to match this.
set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib even when Google Test is built as static lib" FORCE)
endif()
if (WIN32)
# Enable more warnings and treat as errors
string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX")
# Enable more warnings and treat as errors
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3 -Xcompiler=/WX)
# Disable warning on Unicode characters
string(APPEND NVCC_FLAGS " -Xcompiler /wd4819")
# Disable warning on Unicode characters
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/wd4819)
# Disable excess x86 floating point precision that can lead to results being labeled incorrectly
string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict")
# Verbose option
if (${CUTLASS_NVCC_VERBOSE})
string(APPEND NVCC_FLAGS " -v")
endif()
# Disable excess x86 floating point precision that can lead to results being labeled incorrectly
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/fp:strict)
endif(WIN32)
set(CUTLASS_NVCC_ARCHS "50;60;61;70;75" CACHE STRING "The SM architectures to build code for.")
if (${CUTLASS_NVCC_VERBOSE})
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -v)
endif()
set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.")
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.")
#
# CUTLASS generator cmake configuration
#
set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of operation name filters. Default '' means all operations are enabled.")
set(CUTLASS_LIBRARY_KERNELS "" CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.")
# Test Levels L0, L1, L2
set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")
set_property(CACHE CUTLASS_TEST_LEVEL PROPERTY STRINGS 0 1 2)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
#
# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
#
if (CUDA_VERSION VERSION_LESS 10.1)
set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT OFF)
else()
set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT ON)
endif()
set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL
"Enable PTX mma instruction for collective matrix multiply operations.")
#
# NOTE: running with asan and CUDA requires the following environment variable:
@ -111,7 +197,7 @@ set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.
# ...
#
if(ENABLE_ASAN) # https://github.com/google/sanitizers/wiki/AddressSanitizer
string(APPEND NVCC_FLAGS " --compiler-options -fsanitize=address --compiler-options -fno-omit-frame-pointer")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --compiler-options=-fsanitize=address --compiler-options=-fno-omit-frame-pointer)
string(APPEND CMAKE_EXE_LINKER_FLAGS " -fsanitize=address")
endif()
@ -121,59 +207,136 @@ endif()
#
###################################################################################################
# Set NVCC arguments
foreach(ARCH ${CUTLASS_NVCC_ARCHS})
if(CUTLASS_NVCC_EMBED_CUBIN)
string(APPEND NVCC_FLAGS " -gencode arch=compute_${ARCH},code=sm_${ARCH}")
if(CUTLASS_NVCC_EMBED_PTX)
list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-include-ptx=all)
endif()
if (CUTLASS_ENABLE_TENSOR_CORE_MMA)
list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1)
endif()
if (NOT MSVC AND CUTLASS_NVCC_KEEP)
# MSVC flow handles caching already, but for other generators we handle it here.
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files")
file(MAKE_DIRECTORY ${CUTLASS_NVCC_KEEP_DIR})
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep) # --keep-dir may not work with nvcc for some directories.
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -save-temps=${CUTLASS_NVCC_KEEP_DIR})
endif()
if (CUTLASS_ENABLE_F16C AND NOT CMAKE_CROSSCOMPILING)
list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_F16C=1)
if ((CMAKE_CXX_COMPILER_ID MATCHES "GNU") OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang"))
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-mf16c)
elseif((CMAKE_CXX_COMPILER_ID MATCHES "MSVC"))
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/arch:AVX2)
endif()
if(CUTLASS_NVCC_EMBED_PTX)
string(APPEND NVCC_FLAGS " -gencode arch=compute_${ARCH},code=compute_${ARCH}")
endif()
list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-Wconversion>)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-fno-strict-aliasing>)
# Don't leak lineinfo in release builds
if (NOT CMAKE_BUILD_TYPE MATCHES "Release")
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -gmlt)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -lineinfo)
endif()
if(CUDA_COMPILER MATCHES "[Cc]lang")
if( NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang" )
message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" )
endif()
endforeach()
if (CUTLASS_NVCC_KEEP)
string(APPEND NVCC_FLAGS " -keep")
if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0)
message(FATAL_ERROR "Clang 7.0+ required for GPU compilation")
endif()
list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-path=${CUDA_TOOLKIT_ROOT_DIR})
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -pragma-unroll-threshold=100000)
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -unroll-threshold=5000)
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unused-command-line-argument)
string(REPLACE "." ";" CUDA_VERSION_PARTS ${CMAKE_CUDA_COMPILER_VERSION})
list(GET CUDA_VERSION_PARTS 0 CUDA_VERSION_MAJOR)
list(GET CUDA_VERSION_PARTS 1 CUDA_VERSION_MINOR)
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -D__CUDACC_VER_MAJOR__=${CUDA_VERSION_MAJOR} -D__CUDACC_VER_MINOR__=${CUDA_VERSION_MINOR})
# needed for libcublasLt.so in case it's installed in the same location as libcudart.so
# dynamic linker can find it if linker sets RPATH (forced by --disable-new-tags)
# Otherwise linker uses RUNPATH and that does not propagate to loaded libs.
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wl,--disable-new-dtags)
link_libraries(nvidia::cudart)
endif()
if (WIN32 AND CUTLASS_NATIVE_CUDA)
string(APPEND NVCC_FLAGS_RELEASE " -lineinfo")
else()
string(APPEND NVCC_FLAGS " -lineinfo")
endif()
function(cutlass_apply_cuda_gencode_flags TARGET)
if (UNIX)
string(APPEND NVCC_FLAGS " -Xcompiler -Wconversion")
endif()
set(NVCC_FLAGS)
set(CLANG_FLAGS)
foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED})
list(APPEND CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH})
set(CODES)
if(CUTLASS_NVCC_EMBED_CUBIN)
list(APPEND CODES sm_${ARCH})
endif()
if(CUTLASS_NVCC_EMBED_PTX)
list(APPEND CODES compute_${ARCH})
endif()
list(JOIN CODES "," CODES_STR)
list(APPEND NVCC_FLAGS -gencode=arch=compute_${ARCH},code=[${CODES_STR}])
endforeach()
string(APPEND NVCC_FLAGS_DEBUG " -g")
string(APPEND NVCC_FLAGS_RELWITHDEBINFO " -O3")
string(APPEND NVCC_FLAGS_RELEASE " -O3")
if (CUDA_COMPILER MATCHES "[Cc]lang")
target_compile_options(
${TARGET}
PRIVATE
$<$<COMPILE_LANGUAGE:CXX>:${CLANG_FLAGS}>
)
else()
target_compile_options(
${TARGET}
PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:${NVCC_FLAGS}>
)
endif()
# define NDEBUG for release mode to disable assertions
string(APPEND NVCC_FLAGS_RELEASE " -DNDEBUG")
endfunction()
if (CUTLASS_NATIVE_CUDA)
set(CMAKE_CUDA_FLAGS "${NVCC_FLAGS}")
set(CMAKE_CUDA_FLAGS_RELEASE "${NVCC_FLAGS_RELEASE}")
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${NVCC_FLAGS_RELWITHDEBINFO}")
set(CMAKE_CUDA_FLAGS_DEBUG "${NVCC_FLAGS_DEBUG}")
else()
set(CUDA_NVCC_FLAGS ${NVCC_FLAGS})
set(CUDA_NVCC_FLAGS_DEBUG ${NVCC_FLAGS_DEBUG})
set(CUDA_NVCC_FLAGS_RELWITHDEBINFO ${NVCC_FLAGS_RELWITHDEBINFO})
set(CUDA_NVCC_FLAGS_RELEASE ${NVCC_FLAGS_RELEASE})
endif()
function(cutlass_apply_standard_compile_options TARGET)
if(CUDA_COMPILER MATCHES "[Cc]lang")
set(CUDA_COMPILE_LANGUAGE CXX)
set(_FLAGS ${CUTLASS_CUDA_FLAGS} ${CUTLASS_CUDA_CLANG_FLAGS})
set(_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} ${CUTLASS_CUDA_CLANG_FLAGS_RELEASE})
set(_FLAGS_RELWITHDEBINFO ${CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${CUTLASS_CUDA_CLANG_FLAGS_RELWITHDEBINFO})
set(_FLAGS_DEBUG ${CUTLASS_CUDA_FLAGS_DEBUG} ${CUTLASS_CUDA_CLANG_FLAGS_DEBUG})
else()
set(CUDA_COMPILE_LANGUAGE CUDA)
set(_FLAGS ${CUTLASS_CUDA_FLAGS} ${CUTLASS_CUDA_NVCC_FLAGS})
set(_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} ${CUTLASS_CUDA_NVCC_FLAGS_RELEASE})
set(_FLAGS_RELWITHDEBINFO ${CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${CUTLASS_CUDA_NVCC_FLAGS_RELWITHDEBINFO})
set(_FLAGS_DEBUG ${CUTLASS_CUDA_FLAGS_DEBUG} ${CUTLASS_CUDA_NVCC_FLAGS_DEBUG})
endif()
target_compile_options(
${TARGET}
PRIVATE
$<$<COMPILE_LANGUAGE:${CUDA_COMPILE_LANGUAGE}>:${_FLAGS}>
$<$<COMPILE_LANGUAGE:${CUDA_COMPILE_LANGUAGE}>:$<$<CONFIG:RELEASE>:${_FLAGS_RELEASE}>>
$<$<COMPILE_LANGUAGE:${CUDA_COMPILE_LANGUAGE}>:$<$<CONFIG:RELWITHDEBINFO>:${_FLAGS_RELWITHDEBINFO}>>
$<$<COMPILE_LANGUAGE:${CUDA_COMPILE_LANGUAGE}>:$<$<CONFIG:DEBUG>:${_FLAGS_DEBUG}>>
)
endfunction()
#
# The following items should eventually be pushed into cutlass/CMakeLists.txt
#
# GLOB for CUTLASS header files. Should we use a static list instead?
file(GLOB CUTLASS_GEMM RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/gemm/*.h)
file(GLOB CUTLASS_UTIL RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/util/*.h)
file(GLOB CUTLASS_DEVICE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/device/*.h)
file(GLOB CUTLASS_CORE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/*.h)
file(GLOB CUTLASS_REDUCTION RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/reduction/*.h )
file(GLOB_RECURSE CUTLASS_INCLUDE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} include/cutlass/*.h)
file(GLOB_RECURSE CUTLASS_CUTLASS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include include/cutlass/*.h)
file(GLOB_RECURSE CUTLASS_NVRTC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/test test/unit/nvrtc/kernel/*.h)
###################################################################################################
#
@ -181,32 +344,71 @@ file(GLOB CUTLASS_REDUCTION RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/reducti
#
###################################################################################################
source_group("cutlass\\gemm" FILES ${CUTLASS_GEMM})
source_group("cutlass\\util" FILES ${CUTLASS_UTIL})
source_group("cutlass\\device" FILES ${CUTLASS_DEVICE})
source_group("cutlass\\reduction" FILES ${CUTLASS_REDUCTION})
source_group("cutlass" FILES ${CUTLASS_CORE})
source_group(TREE ${CMAKE_CURRENT_SOURCE_DIR}/include REGULAR_EXPRESSION ".*\.h")
add_library(CUTLASS INTERFACE)
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
target_sources(CUTLASS INTERFACE
${CUTLASS_GEMM}
${CUTLASS_UTIL}
${CUTLASS_DEVICE}
${CUTLASS_CORE}
${CUTLASS_REDUCTION}
)
add_library(nvidia::cutlass::cutlass ALIAS CUTLASS)
set_target_properties(CUTLASS PROPERTIES EXPORT_NAME cutlass)
target_include_directories(CUTLASS INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
set(CUTLASS_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library")
set(CUTLASS_GENERATOR_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/library/)
# The following utility directory is needed even if the tools build is disabled, so it exists here.
set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/util/include CACHE INTERNAL "")
include_directories(${CUTLASS_INCLUDE_DIR})
target_compile_features(CUTLASS INTERFACE cxx_std_11)
if (NOT DEFINED CUTLASS_REVISION)
find_package(Git QUIET)
execute_process(
COMMAND ${GIT_EXECUTABLE} rev-parse --short HEAD
RESULT_VARIABLE CUTLASS_REVISION_RESULT
OUTPUT_VARIABLE CUTLASS_REVISION
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if (CUTLASS_REVISION_RESULT)
message(STATUS "CUTLASS Revision: Unable to detect, Git returned code ${CUTLASS_REVISION_RESULT}.")
else()
message(STATUS "CUTLASS Revision: ${CUTLASS_REVISION}")
endif()
endif()
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/cmake/version.h.in
${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version.h
@ONLY)
target_include_directories(
CUTLASS
INTERFACE
$<INSTALL_INTERFACE:include>
$<BUILD_INTERFACE:${CUTLASS_INCLUDE_DIR}>
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/include>
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
)
install(
DIRECTORY
${CUTLASS_INCLUDE_DIR}/
${CMAKE_CURRENT_BINARY_DIR}/include/
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)
install(
TARGETS CUTLASS
EXPORT NvidiaCutlass
PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)
################################################################################
# Create a custom target to ensure that the CUTLASS sources are visible in an IDE
add_custom_target(cutlass_ide SOURCES
${CUTLASS_GEMM}
${CUTLASS_UTIL}
${CUTLASS_DEVICE}
${CUTLASS_CORE}
${CUTLASS_REDUCTION}
)
# Doxygen is available. Generate documentation
if (DOXYGEN_FOUND)
# DOT is available. Enable graph generation in the documentation
@ -232,5 +434,55 @@ if (DOXYGEN_FOUND)
)
endif()
add_subdirectory(tools)
add_subdirectory(examples)
if(NOT WIN32)
# Add common library search paths so executables and libraries can load and run
# without LD_LIBRARY_PATH being set.
link_libraries(
"-Wl,-rpath,'$ORIGIN'"
"-Wl,-rpath,'$ORIGIN/../lib64'"
"-Wl,-rpath,'$ORIGIN/../lib'"
"-Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/lib64'"
"-Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/lib'"
)
endif()
################################################################################
include(${CMAKE_CURRENT_SOURCE_DIR}/cuBLAS.cmake)
if (CUTLASS_ENABLE_CUBLAS)
target_compile_definitions(CUTLASS INTERFACE CUTLASS_ENABLE_CUBLAS=1)
endif()
################################################################################
if(CUTLASS_ENABLE_TOOLS)
add_subdirectory(tools)
endif()
if(CUTLASS_ENABLE_EXAMPLES)
add_subdirectory(examples)
endif()
if(CUTLASS_ENABLE_TESTS)
include(CTest)
enable_testing()
add_subdirectory(test)
endif()
################################################################################
install(
FILES ${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassConfig.cmake
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/
)
install(
EXPORT NvidiaCutlass
NAMESPACE nvidia::cutlass::
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/
FILE NvidiaCutlassTargets.cmake
)
################################################################################
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassPackageConfig.cmake)

57
CONTRIBUTORS.md Normal file
View File

@ -0,0 +1,57 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS")
[README](/README.md#documentation) > **Contributors**
# CUTLASS Developers and Contributors
This is the official list of CUTLASS developers and contributors.
## DEVELOPERS
Andrew Kerr
Haicheng Wu
Manish Gupta
Dustyn Blasig
Pradeep Ramani
Naila Farooqui
Piotr Majcher
Paul Springer
Jin Wang
Scott Yokim
Markus Hohnerbach
Aditya Atluri
David Tanner
## CONTRIBUTORS
Timothy Costa
Julien Demouth
Brian Fahs
Michael Goldfarb
Mostafa Hagog
Fei Hu
Alan Kaatz
Tina Li
Timmy Liu
Duane Merrill
Kevin Siu
Markus Tavenrath
John Tran
Vicki Wang
Junkai Wu
Fung Xie
Albert Xu
Jack Yang
Xiuxia Zhang
Nick Zhao
## ACKNOWLEDGEMENTS
Girish Bharambe
Cris Cecka
Luke Durant
Olivier Giroux
Stephen Jones
Rishkul Kulkarni
Bryce Lelbach
Joel McCormack
Kyrylo Perelygin

349
CUDA.cmake Normal file
View File

@ -0,0 +1,349 @@
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice, this list of
# conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice, this list of
# conditions and the following disclaimer in the documentation and/or other materials
# provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific prior written
# permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
if(CUDA_COMPILER MATCHES "[Cc]lang")
set(CUTLASS_NATIVE_CUDA_INIT ON)
elseif(CMAKE_VERSION VERSION_LESS 3.12.4)
set(CUTLASS_NATIVE_CUDA_INIT OFF)
else()
set(CUTLASS_NATIVE_CUDA_INIT ON)
endif()
set(CUTLASS_NATIVE_CUDA ${CUTLASS_NATIVE_CUDA_INIT} CACHE BOOL "Utilize the CMake native CUDA flow")
if(NOT DEFINED ENV{CUDACXX} AND NOT DEFINED ENV{CUDA_BIN_PATH} AND DEFINED ENV{CUDA_PATH})
# For backward compatibility, allow use of CUDA_PATH.
set(ENV{CUDACXX} $ENV{CUDA_PATH}/bin/nvcc)
endif()
if(CUTLASS_NATIVE_CUDA)
enable_language(CUDA)
if(NOT CUDA_VERSION)
set(CUDA_VERSION ${CMAKE_CUDA_COMPILER_VERSION})
endif()
if(NOT CUDA_TOOLKIT_ROOT_DIR)
get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CMAKE_CUDA_COMPILER}/../.." ABSOLUTE)
endif()
else()
find_package(CUDA REQUIRED)
# We workaround missing variables with the native flow by also finding the CUDA toolkit the old way.
if(NOT CMAKE_CUDA_COMPILER_VERSION)
set(CMAKE_CUDA_COMPILER_VERSION ${CUDA_VERSION})
endif()
endif()
if (CUDA_VERSION VERSION_LESS 9.2)
message(FATAL_ERROR "CUDA 9.2+ Required, Found ${CUDA_VERSION}.")
endif()
if(NOT CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "[Cc]lang")
set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc)
message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}")
endif()
find_library(
CUDART_LIBRARY cudart
PATHS
${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES
lib/x64
lib64
lib
NO_DEFAULT_PATH
# We aren't going to search any system paths. We want to find the runtime
# in the CUDA toolkit we're building against.
)
if(NOT TARGET cudart AND CUDART_LIBRARY)
message(STATUS "CUDART: ${CUDART_LIBRARY}")
if(WIN32)
add_library(cudart STATIC IMPORTED GLOBAL)
# Even though we're linking against a .dll, in Windows you statically link against
# the .lib file found under lib/x64. The .dll will be loaded at runtime automatically
# from the PATH search.
else()
add_library(cudart SHARED IMPORTED GLOBAL)
endif()
add_library(nvidia::cudart ALIAS cudart)
set_property(
TARGET cudart
PROPERTY IMPORTED_LOCATION
${CUDART_LIBRARY}
)
elseif(TARGET cudart)
message(STATUS "CUDART: Already Found")
else()
message(STATUS "CUDART: Not Found")
endif()
find_library(
CUDA_DRIVER_LIBRARY cuda
PATHS
${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES
lib/x64
lib64
lib
lib64/stubs
lib/stubs
NO_DEFAULT_PATH
# We aren't going to search any system paths. We want to find the runtime
# in the CUDA toolkit we're building against.
)
if(NOT TARGET cuda_driver AND CUDA_DRIVER_LIBRARY)
message(STATUS "CUDA Driver: ${CUDA_DRIVER_LIBRARY}")
if(WIN32)
add_library(cuda_driver STATIC IMPORTED GLOBAL)
# Even though we're linking against a .dll, in Windows you statically link against
# the .lib file found under lib/x64. The .dll will be loaded at runtime automatically
# from the PATH search.
else()
add_library(cuda_driver SHARED IMPORTED GLOBAL)
endif()
add_library(nvidia::cuda_driver ALIAS cuda_driver)
set_property(
TARGET cuda_driver
PROPERTY IMPORTED_LOCATION
${CUDA_DRIVER_LIBRARY}
)
elseif(TARGET cuda_driver)
message(STATUS "CUDA Driver: Already Found")
else()
message(STATUS "CUDA Driver: Not Found")
endif()
find_library(
NVRTC_LIBRARY nvrtc
PATHS
${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES
lib/x64
lib64
lib
NO_DEFAULT_PATH
# We aren't going to search any system paths. We want to find the runtime
# in the CUDA toolkit we're building against.
)
if(NOT TARGET nvrtc AND NVRTC_LIBRARY)
message(STATUS "NVRTC: ${NVRTC_LIBRARY}")
if(WIN32)
add_library(nvrtc STATIC IMPORTED GLOBAL)
# Even though we're linking against a .dll, in Windows you statically link against
# the .lib file found under lib/x64. The .dll will be loaded at runtime automatically
# from the PATH search.
else()
add_library(nvrtc SHARED IMPORTED GLOBAL)
endif()
add_library(nvidia::nvrtc ALIAS nvrtc)
set_property(
TARGET nvrtc
PROPERTY IMPORTED_LOCATION
${NVRTC_LIBRARY}
)
elseif(TARGET nvrtc)
message(STATUS "NVRTC: Already Found")
else()
message(STATUS "NVRTC: Not Found")
endif()
include_directories(SYSTEM ${CUDA_INCLUDE_DIRS})
# Some platforms (e.g. Visual Studio) don't add the CUDA include directories to the system include
# paths by default, so we add it explicitly here.
function(cutlass_correct_source_file_language_property)
if(CUDA_COMPILER MATCHES "clang")
foreach(File ${ARGN})
if(File MATCHES ".*\.cu$")
set_source_files_properties(${File} PROPERTIES LANGUAGE CXX)
endif()
endforeach()
endif()
endfunction()
set(CUTLASS_UNITY_BUILD_ENABLED OFF CACHE BOOL "Enable combined source compilation")
set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files")
function(cutlass_unify_source_files TARGET_ARGS_VAR)
set(options)
set(oneValueArgs BATCH_SOURCES BATCH_SIZE)
set(multiValueArgs)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if (NOT DEFINED TARGET_ARGS_VAR)
message(FATAL_ERROR "TARGET_ARGS_VAR parameter is required")
endif()
if (__BATCH_SOURCES AND NOT DEFINED __BATCH_SIZE)
set(__BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE})
endif()
if (CUTLASS_UNITY_BUILD_ENABLED AND DEFINED __BATCH_SIZE AND __BATCH_SIZE GREATER 1)
set(CUDA_FILE_ARGS)
set(TARGET_SOURCE_ARGS)
foreach(ARG ${__UNPARSED_ARGUMENTS})
if(${ARG} MATCHES ".*\.cu$")
list(APPEND CUDA_FILE_ARGS ${ARG})
else()
list(APPEND TARGET_SOURCE_ARGS ${ARG})
endif()
endforeach()
list(LENGTH CUDA_FILE_ARGS NUM_CUDA_FILE_ARGS)
while(NUM_CUDA_FILE_ARGS GREATER 0)
list(SUBLIST CUDA_FILE_ARGS 0 ${__BATCH_SIZE} CUDA_FILE_BATCH)
string(SHA256 CUDA_FILE_BATCH_HASH "${CUDA_FILE_BATCH}")
string(SUBSTRING ${CUDA_FILE_BATCH_HASH} 0 12 CUDA_FILE_BATCH_HASH)
set(BATCH_FILE ${CMAKE_CURRENT_BINARY_DIR}/${NAME}.unity.${CUDA_FILE_BATCH_HASH}.cu)
message(STATUS "Generating ${BATCH_FILE}")
file(WRITE ${BATCH_FILE} "// Unity File - Auto Generated!\n")
foreach(CUDA_FILE ${CUDA_FILE_BATCH})
get_filename_component(CUDA_FILE_ABS_PATH ${CUDA_FILE} ABSOLUTE)
file(APPEND ${BATCH_FILE} "#include \"${CUDA_FILE_ABS_PATH}\"\n")
endforeach()
list(APPEND TARGET_SOURCE_ARGS ${BATCH_FILE})
if (NUM_CUDA_FILE_ARGS LESS_EQUAL __BATCH_SIZE)
break()
endif()
list(SUBLIST CUDA_FILE_ARGS ${__BATCH_SIZE} -1 CUDA_FILE_ARGS)
list(LENGTH CUDA_FILE_ARGS NUM_CUDA_FILE_ARGS)
endwhile()
else()
set(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
endif()
set(${TARGET_ARGS_VAR} ${TARGET_SOURCE_ARGS} PARENT_SCOPE)
endfunction()
function(cutlass_add_library NAME)
set(options)
set(oneValueArgs EXPORT_NAME)
set(multiValueArgs)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang")
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
add_library(${NAME} ${TARGET_SOURCE_ARGS})
else()
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS})
endif()
cutlass_apply_standard_compile_options(${NAME})
cutlass_apply_cuda_gencode_flags(${NAME})
target_compile_features(
${NAME}
INTERFACE
cxx_std_11
)
if(__EXPORT_NAME)
add_library(nvidia::cutlass::${__EXPORT_NAME} ALIAS ${NAME})
set_target_properties(${NAME} PROPERTIES EXPORT_NAME ${__EXPORT_NAME})
endif()
endfunction()
function(cutlass_add_executable NAME)
set(options)
set(oneValueArgs)
set(multiValueArgs)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang")
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
add_executable(${NAME} ${TARGET_SOURCE_ARGS})
else()
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
cuda_add_executable(${NAME} ${TARGET_SOURCE_ARGS})
endif()
cutlass_apply_standard_compile_options(${NAME})
cutlass_apply_cuda_gencode_flags(${NAME})
target_compile_features(
${NAME}
INTERFACE
cxx_std_11
)
endfunction()
function(cutlass_target_sources NAME)
set(options)
set(oneValueArgs)
set(multiValueArgs)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
target_sources(${NAME} ${TARGET_SOURCE_ARGS})
endfunction()

View File

@ -1,378 +0,0 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# CUTLASS
This document is intended to accompany the CUTLASS source code, to describe the interaction between
CUTLASS core components, and to identify their role in implementing GEMM computations efficiently in CUDA.
1. [Design Patterns](#S-design-patterns)
2. [General Matrix Multiply](#S-general-matrix-multiply)
3. [Core Components](#S-core-components)
4. [Utilities](#S-utilities)
5. [Optimization Strategies](#S-optimization-strategies)
# <a name="S-design-patterns"></a> 1. Design Patterns
CUTLASS strives to achieve the highest performance possible on NVIDIA GPUs while also offering a
flexible composition that an be easily applied to solve new problems related to Deep Learning and
linear algebra. Though we intend to make CUTLASS as simple and straightforward as possible, given
a tradeoff between simplicity and performance, CUTLASS chooses performance. Consequently, several
design patterns are necessary to yield a composable structure while also satisfying these performance
objectives. This section is intended to provide more detail.
* [Sequencing and Nesting](#S-patterns-sequencing-nesting)
* [Tiles and Iterators](#S-patterns-tiles-iterators)
* [Host-side Params](#S-patterns-host-side-params)
* [Composable Shared Memory](#S-patterns-composable-shared-memory)
## <a name="S-patterns-sequencing-nesting"></a> Sequencing and Nesting of Collective Primitives
CUTLASS embodies a design paradigm exemplified by the [CUB library](https://nvlabs.github.io/cub/) for expressing collective operations. Objects expose an interface for a problem that is then decomposed into concurrent subtasks executed by cooperating threadblocks, warps, and threads. For example, a grid-level object may be constructed with base pointers to the start of a GEMM operation, add a threadblock-dependent offset to partition the problem, and then compute a per-threadblock GEMM. This in turn performs some operations as a collection of cooperating threads, while it may partition other parts of the task into warp-level subtasks.
## <a name="S-patterns-tiles-iterators"></a> Tiles and Iterators
Efficient dense linear algebra computations emphasize data movement to match the execution of mathemtical operators to the flow of data. Consequently, CUTLASS defines a rich set of primitives for partitioning a tile of data among participating threads, warps, and threadblocks. CUTLASS applies the familiar iterator design pattern to provide an abstraction layer to (1.) access these tile objects and (2.) traverse a sequence of objects embedded in a higher level data structure. These subpartitions are typically defined by compile-time constants
specifying element type, size, and data layout. CUTLASS refers to subpartitions as _tiles_.
_Iterators_ are familiar design patterns in C++ that provide an abstraction for accessing individual
elements in memory as well as traversing over a collection. GEMM kernels in CUTLASS depend on accessing
a sequence of tiles from global memory, from shared memory, and in registers. Consequently, _tile iterators_
are prevalent throughout the CUTLASS implementation.
The canonical CUTLASS tile iterator template is defined in [cutlass/tile_iterator.h](cutlass/tile_iterator.h).
## <a name="S-patterns-host-side-params"></a> Host-side Params structure
Several CUTLASS template classes exhibit a pattern in which problem-specific internal state is known at kernel launch time and remains invariant throughout the execution of a kernel. For example, tile iterators compute several offsets based on the strides of the input tensor that is added to an internal pointer when loading the elements of a tile. These are computed from the tensor stride and never updated; the per-thread internal state consists only of the internal global memory pointer.
CUTLASS can take advantage of this CUDA grid-invariant property by constructing the object in host code and passing a composed parameters structure to the kernel. This confers two benefits: (1.) invariant state is held in constant memory, and (2.) there is no overhead to compute the initial state by each thread.
The design pattern in CUTLASS is for classes with nontrivial constructors to define `struct Params` as an inner class which contains grid-invariant state. These should define a constructor and an `initialize()` method. The `Params` structure should also include a data member corresponding to each data member in the parent class, so these too can be properly constructed in host code. The parent class should define a constructor which accepts `Params const &` as its first argument.
For example, `cutlass::gemm::Gemm<>` should define `struct cutlass::gemm::Gemm::Params`. The latter should define data members for each data member in `cutlass::gemm::Gemm<>`.
## <a name="S-patterns-composable-shared-memory"></a> Composable shared memory allocation
Shared memory requires explicit effort by the programmer to allocate and de-allocate. CUTLASS follows the paradigm introduced by [CUB](https://nvlabs.github.io/cub/) to define composed structures for storing data intended to be held in shared memory. Any object requiring shared memory storage for itself or its data members should define a child structure called SharedStorage. This holds data needed by the class and also instantiates SharedStorage objects for each data member.
To be consistent, this pattern defines a convention in which classes define internal shared memory storage requirements. Classes should consider all SharedStorage structures to be opaque other than their own child class. When the lifetimes of child objects are known to be non-overlapping, unions may be used to alias multiple SharedStorage objects to the same shared memory region and reduce overall SMEM capacity.
## <a name="S-patterns-loop-unrolling"></a> Loop Unrolling
CUTLASS requires tiles of data to be stored in registers for high-bandwidth access. Simultaneously, high-throughput math instructions
must be issued concurrently with memory instructions to hide latency with relatively few concurrent threads. These objectives are
achieved by unrolling loops whose iteration counts are known at compile time.
Consequently, most loops within the CUTLASS GEMM implementation are specified by constant values and template arguments. The CUDA compiler
is able to unroll the loop bodies, map array elements to registers, and construct an efficient instruction schedule.
## <a name="S-patterns-loop-unrolling"></a> Templates
CUDA C++ templates and modern generic programming techniques enable CUTLASS device code to span a large design space.
This design space includes:
* Mixed precision arithmetic and data storage
* Kernels specialized for layout and problem size
* Support for kernel fusion
Moreover, templates provided a structured approach to collecting compile-time constants such as tile dimensions. These
must be template arguments to target static array allocation and take advantage of loop unrolling, constant folding,
and function inlining.
# <a name="S-general-matrix-multiply"></a> 2. General Matrix Multiply
The following figure illustrates the hierarchical GEMM computation embodied by CUTLASS. Each stage depicts a nested level of tiling which corresponds to a layer of concurrency within the CUDA execution model and to a level within the memory hierarchy, becoming increasingly finer moving left to right.
![ALT](/media/images/gemm-structural-components.png "CUTLASS GEMM Structural Components")
## Threadblock-level GEMM
The CUTLASS GEMM kernel partitions the _C_ matrix into a 2D tiling of threadblocks.
Each threadblock computes a matrix product whose outer dimensions _M_ and _N_ are compile-time constants. The
GEMM's _K_ dimension is partitioned into tiles and iterated over by the GEMM _mainloop_. The shape of the matrix
multiply operation performed by each iteration of the mainloop is referred to as _OutputTile_.
The threadblock loads a sequence of tiles from global memory and stores this data to shared memory. The iterative
access and traversal of tiles in global memory are performed by a _TileLoadIterator_, and storing to a circular
buffer in shared memory is performed by a _GlobalLoadIterator_.
**[Global Load Stream](cutlass/gemm/gemm_global_stream.h)** manages loading of the threadblock-scope multiplicands to the GEMM kernel. It owns an iterator into global memory for loading tiles of data, a TensorAllocation in shared memory to hold the resulting tile, and an iterator for writing the tile into this allocation. A transformer exists to optionally transform the data as it is loaded which may of use to perform type conversion or, in the case of int8 GEMM, transpose 4x4 tiles held in registers.
The Global Load Stream template contains members defined by the following templates:
* [GemmGlobalIteratorAb](cutlass/gemm/gemm_global_tile.h)
* [Transformer](cutlass/convert.h)
* [GemmSharedStoreTileAb](cutlass/gemm/gemm_shared_tile.h)
## Warp-level GEMM
The threadblock's _OutputTile_ is partitioned among the warps, and each computes a warp-level matrix product.
Data is loaded from shared memory into registers, and math instructions are dispatched to CUDA Cores or Tensor Cores.
[**Shared Load Stream**](cutlass/gemm/gemm_shared_stream.h) manages loading of warp-level multiplicands from shared memory into registers. This owns an iterator for fetching data and the destination fragments for holding the results.
* [GemmSharedLoadTile{A,B}](cutlass/gemm/gemm_shared_tile.h)
**Matrix Multiply** computes a matrix product operation on data held in registers. Specializations exist for thread-level instructions such as single-precision fused multiply-add as well as warp-level matrix operations targeting TensorCores.
* [WMMA Multiply Add](cutlass/gemm/wmma_gemm_multiply_add.h)
## Thread-level GEMM
SGEMM, IGEMM, HGEMM, and DGEMM are computed by SIMT math instructions issued by thread-level matrix multiply
procedures.
* [ThreadMultiplyAdd](cutlass/gemm/thread_multiply_add.h)
* [IGEMM specialization](cutlass/gemm/igemm_multiply_add.h)
* [HGEMM specialization](cutlass/gemm/hgemm_multiply_add.h)
## Epilogue
The [**epilogue**](cutlass/gemm/gemm_epilogue.h) iteratively selects a subset of accumulator elements held by a warp, writes them to shared memory, and loads them by different threads such that a threadblock-scoped tile store operation will make contiguous, striped accesses to global memory. Thus, the flow of data utilizes the following components:
1. [Transformer](cutlass/convert.h) for converting the data types of accumulator elements
2. [GemmSharedStoreTileD](cutlass/gemm/gemm_shared_tile.h) to store to shared memory specialized to the accumulator layout.
3. [GemmSharedLoadTileD](cutlass/gemm/gemm_shared_tile.h) to load the data from shared memory.
4. [GemmGlobalIteratorC](cutlass/gemm/gemm_global_tile.h) to load a tile from global memory.
5. A [functor](cutlass/gemm/linear_scaling.h) to compute an element-wise operation on the matrix product and source data (such as alpha*AB+beta*C).
6. [GemmGlobalIteratorD](cutlass/gemm/gemm_global_tile.h) to write the output to global memory.
## GEMM Traits
[**cutlass::gemm::GemmTraits**](cutlass/gemm/gemm_traits.h) collects the structural properties of a complete GEMM computation into a single template class. As a result, the Traits classes encapsulate the the iterators and transformers for all supported GEMM operands and layouts. Low-level details needed by Traits (such as scalar types for operands, thread-block tile size, number of scalar elements per memory access within each phase, number of stages in shared memory, as well as other implementation-specific properties of the GEMM computation) are specified in class [**cutlass::gemm::GemmConfig**](cutlass/gemm/gemm_config.h).
# <a name="S-core-components"></a> 3. Core Components
CUTLASS GEMM kernels are implemented by a set of Core components for interacting with mathematical tensor and matrix
objects as well as constructing efficient CUDA kernels.
* [Tensor views](#S-core-tensor-views)
* [Shape](#S-core-shape)
* [Tile structure](#S-core-tile-structure)
* [Fragment](#S-core-fragment)
* [Predicate vector](#S-core-predicate-vector)
## <a name="S-core-tensor-views"></a> Tensor View
Matrices and tensors are typically represented as n-D arrays held in linear memory with a single base pointer and a stride vector. Element _i_ of the stride vector indicates the offset in linear memory between consecutive elements in dimension i. Consequently, the linear offset for an arbitrary element specified as an n-tuple may be computed as the dot product of the coordinate and the stride vector.
CUTLASS provides abstractions for interacting with multidimension tensors in device memory.
Consequently, we define a hierarchy of pointer-like types for referencing tensors.
`T *` - raw pointer to elements of type T
`cutlass::TensorRef<T, Rank>` - reference to a tensor of elements of type T and given rank. Includes a mapping function and associated stride vector for accessing elements in linear memory.
`cutlass::TensorView<T, Rank>` - extends `TensorRef<>` by adding bounds information. This is a complete mathematical object which may be used as the argument to CUTLASS functions.
The above provide an identity maping of a logical index space to linear memory. An element
at logical coordinate X has an offset computed as follows:
```
offset = dot(X, stride)
```
where `dot()` computes the inner product of X and a vector of "strides."
CUTLASS 1.1 introduces a mapping function and an additional "storage rank" to offer a flexible way to
map the logical index space of the tensor to memory. The mapping function maps a coordinate
of rank _R_ to an index space of rank _S_. The linear offset is computed as:
```
offset = dot( MapFunc(X), stride )
```
where stride is a vector of rank _S_.
CUTLASS kernels make extensive use of vectorization of memory accesses for efficiency and
correctness. Consequently, we enforce a constraint on the strides used by mapping functions
such that:
1. The "fastest-changing" stride is always 1 thereby mandating that consecutive elements in
that rank are consecutive in linear memory.
2. The fastest changing rank is always last in the stride vector and not explicitly stored.
Thus, the stride vector used by mapping functions has length of one fewer than the rank of the
storage tensor. These constraints are consistent with the BLAS interface of passing matrices as
a tuple consisting of a pointer and a "leading dimension." In fact, these are rank=2 tensors
whose fastest changing dimension is 1, and only the strided dimension is explicitly represented.
A typical mapping function might simply map the rows and columns of a matrix, a rank=2 tensor,
to linear memory such that (1.) elements in the same column are consecutive in memory
(column-major), or (2.) elements in the same row are consecutive (row-major). These can be
accomplished by two different mapping functions whose stride vector is length=2. The first
element is the "leading dimension."
The requirement that the fastest-changing stride always be of unit size need not be a limitation.
To implement "sparse" computations or matrix operations in which matrix elements have arbitrary
stride along both row and column, define a mapping function whose storage rank is 3. This permits
two elements of the stride vector to have a non-unit value.
`cutlass::TensorView<>` extends this concept by including a size vector to specify the bounds of
the index space. The value of each coordinate in the size vector defines the half-open range of
indices whose smallest value is zero.
## <a name="S-core-shape"></a> Shape
To avoid complicated template metaprogramming, CUTLASS targets fixed compile-time tile sizes specified
by a four-dimensional template `cutlass::Shape<>`. This defines the following dimensions, mirroring
the NHWC tensor format used for convolution in Deep Learning frameworks.
- `D`: depth of tensor
- `H`: first strided dimension
- `W`: contiguous sequence of tensor elements
- `C`: number of channels, usually used for vectorized access
Template specializations of `Shape` appear as arguments to numerous dependent template classes which
must specify compile-time constant tile sizes.
## <a name="S-core-tile-structure"></a> Tile Structure
Tiled structures express an arrangement of data in memory as well as a logical mapping of concurrent CUDA
threads to the problem space. For example, the CUTLASS GEMM
Tiled structures can be defined using the `cutlass::TileTraits<>` concept which defines the following
members. Collectively, these members offer a flexible way to define a 4-D subpartition of an integer
lattice, partition its elements among a collection of threads, and map each unique thread ID to a unique
offset.
- _Tile_ (concept `Shape<>`) - describes the dimensions of the tile in terms of scalar elements
- _Delta_ (concept `Shape<>`) - describes the distance along each logical dimension between items
- _Iterations_ (concept `Shape<>`) - describes the number of items along each logical dimension
- _ThreadOffset_ (concept _functor_) - implements `Coord<4> operator()() const` to determine a thread's
initial offset in the logical 4-D coordinate space
The following figure illustrates the CUTLASS tile structure. The overall shape, 16-by-16, is partitioned into
vectors of length two among 32 threads. The elements stored by thread 9 are highlighted.
<img src="/media/images/cutlass-tile-structure.png" alt="CUTLASS tile structure" width="30%" />
The `cutlass::TileTraits<>` definition that describes this arrangement may be defined as follows:
```
struct ExampleTileTraits {
/// Overall shape of tile
typedef Shape<1, 16, 16, 1> Tile;
/// Distance along each dimension of accesses
typedef Shape<1, 4, 1, 1> Delta;
/// Number of memory accesses performed by each thread
typedef Shape<1, 4, 1, 1> Iterations;
/// Offset function - maps each thread to a unique starting offset within the 4D tile
struct ThreadOffset {
CUTLASS_DEVICE Coord<4> operator()() const {
typdef Shape<1, 16, 8, 2> Vectorized;
return make_Coord(
0, // depth "D" dimension
threadIdx.x / Vectorized::kW, // horisontal "H" dimension - first strided dimension
threadIdx.x % Vectorized::kW, // vertical "W" dimension - contiguous dimension
0
);
}
};
};
```
## <a name="S-core-tile-iterator"></a> Tile Iterator
The iterator design pattern provides an abstraction for accessing the items in a collection in sequence. Basic
operators defined by iterators consist of accessing an item - either a load or store - followed by traversal to
the next item in sequence.
<img src="/media/images/cutlass-tile-iteration.png" alt="CUTLASS tile access and traversal" width="50%" />
To offer a generic solution that spans numerous data types and layouts, CUTLASS defines the _TileIterator_ concept.
This concept provides access to a sequence of _tiles_ embedded in a tensor in addressable memory.
The canonical CUTLASS tile iterator template is defined in [cutlass/tile_iterator.h](cutlass/tile_iterator.h).
## <a name="S-core-fragment"></a> Fragment
A fragment is analogous to `std::array<>` in that it is a constant-sized array of elements. Typically backed by storage in the SM's register file, CUTLASS `Fragment<>` objects are used to store tiles. For threadblock- and warp-scope operations, the contents of these tiles are distributed across the partipcipating threads. In such cases, a thread's `Fragment<>` contains the part of the tile held by that thread.
## <a name="S-core-predicate-vector"></a> Predicate Vector
SIMT architectures utilize predicated execution in place of control flow when conditional code sequences are fairly short, on the order of a few machine instructions. While CUDA C++ does not include constructs at the language level for predication, PTX makes this explicit, and compilation to SASS is assumed to aggressively utilize predication. Typical applications are to initialize a sequence of bits used to mask memory operations and use these bits as predicates guarding memory load and store instructions.
CUTLASS provides `PredicateVector` defined in [cutlass/predicate_vector.h](cutlass/predicate_vector.h) to manage a statically-sized bit vector, store them into general purpose registers, and efficiently access them in sequence. By storing four predicates per byte in hardware registers, the CUDA compiler is able to issue specialized instructions to achieve very efficient unpacking.
# <a name="S-utilities"></a> 4. Utilities
CUTLASS implements efficient matrix multiply computations on GPUs. It is accompanied by an extensive utility
framework offering features such as:
* [cutlass::half_t](tools/util/half.h) - a host-side half-precision type
* Components for allocating and initializing [host-side and device-side tensors](tools/util/host_tensor.h) usable by CUTLASS
* Reference implementations of [GEMM](tools/util/reference/host/gemm.h) and [element-wise operations](tools/util/reference/host/tensor_elementwise.h)
# <a name="S-optimization-strategies"></a>5. Optimization Strategies
This section describes several strategies taken to increase performance beyond what is achievable with
a basic implementation of the hierarchical GEMM structure.
## Threadblock Rasterization
To maximize reuse of data held in the last level cache, CUTLASS defines several functions to
affect the mapping of threadblocks to logical partitions of the GEMM problem. These map
consecutively launched threadblocks to packed two-dimensional regions of the partitioned GEMM
problem to increase the probability that these will access the same tiles of global memory at
approximately the same time.
Several functions are defined in [cutlass/gemm/threadblock_swizzle.h](cutlass/gemm/threadblock_swizzle.h).
## Parallel Reductions across GEMM _K_
Matrix product computations expose parallelism among _O(MN)_ independent inner product
computations. For sufficiently large problem sizes, a GEMM kernel in CUTLASS may approach
the theoretical maximum computational throughput. For small problems, however, there are
too few threadblocks to efficiently occupy the entire GPU.
As a recourse, parallelizing the reduction performed during the inner product computation
enables more threadblocks to execute concurrently while still taking advantage of the throughput
benefits of large threadblock-level GEMM tiles.
CUTLASS implements parallel reductions across threadblocks by partitioning the GEMM _K_ dimension
and launching an additional set of threadblocks for each partition. Consequently, we refer to
this strategy within CUTLASS as "parallel reduction splitK." The "parallel reduction splitK" in cutlass requires the execution of 2 kernels. The first one is called partitionedK GEMM. The second one is called batched reduction.
The partitionedK GEMM is very similar to one flavor of batched strided GEMM. Instead of requiring users to specify the problem size of each batch, partitionedK GEMM asks for the overall problem size and the number of partition that will be applied along K dimension for operand A and B. For example, parameters of m=128, n=128, k=4096 and partition=16 will result in 16 batched strided GEMMs with each batch of m=128, n=128, k=256. PartitionedK also allows scenario where k is not divisible by partition count. For example, parameters of m=128, n=128, k=4096 and partition=20 will result in 20 batched strided GEMMs with the first 19 batches of m=128, n=128, k=4096/20=204 and the last batch of m=128, n=128, k=220.
The batched reduction kernel will further perform reduction along the K-dimension. Thus, the input of the batched reduction kernel is the output (C) of partitionedK GEMM. An workspace memory is managed by the users to store this intermediate results.
An example of splitK usage can be found [here](examples/06_splitK_gemm/splitK_gemm.cu).
# Copyright
Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
```
Redistribution and use in source and binary forms, with or without modification, are permitted
provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this list of
conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of
conditions and the following disclaimer in the documentation and/or other materials
provided with the distribution.
* Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
to endorse or promote products derived from this software without specific prior written
permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

View File

@ -32,7 +32,7 @@ DOXYFILE_ENCODING = UTF-8
# title of most generated pages and in a few other places.
# The default value is: My Project.
PROJECT_NAME = "Cutlass"
PROJECT_NAME = "CUTLASS"
# The PROJECT_NUMBER tag can be used to enter a project or revision number. This
# could be handy for archiving the generated documentation or if some version
@ -51,7 +51,7 @@ PROJECT_BRIEF = "CUDA Templates for Linear Algebra Subroutines and Solv
# and the maximum width should not exceed 200 pixels. Doxygen will copy the logo
# to the output directory.
PROJECT_LOGO =
PROJECT_LOGO = media/images/cutlass-logo-small.png
# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path
# into which the generated documentation will be written. If a relative path is
@ -206,7 +206,7 @@ SEPARATE_MEMBER_PAGES = NO
# uses this value to replace tabs by spaces in code fragments.
# Minimum value: 1, maximum value: 16, default value: 4.
TAB_SIZE = 4
TAB_SIZE = 2
# This tag can be used to specify a number of aliases that act as commands in
# the documentation. An alias has the form:
@ -297,7 +297,7 @@ AUTOLINK_SUPPORT = YES
# diagrams that involve STL classes more complete and accurate.
# The default value is: NO.
BUILTIN_STL_SUPPORT = NO
BUILTIN_STL_SUPPORT = YES
# If you use Microsoft's C++/CLI language, you should set this option to YES to
# enable parsing support.
@ -734,7 +734,9 @@ WARN_LOGFILE =
# spaces.
# Note: If this tag is empty the current directory is searched.
INPUT = cutlass
INPUT = include/cutlass tools/util/include/cutlass/ tools/library/include/cutlass/
INPUT += media/docs/doxygen_mainpage.md
# This tag can be used to specify the character encoding of the source files
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
@ -870,7 +872,7 @@ FILTER_SOURCE_PATTERNS =
# (index.html). This can be useful if you have a project on for instance GitHub
# and want to reuse the introduction page also for the doxygen output.
USE_MDFILE_AS_MAINPAGE =
USE_MDFILE_AS_MAINPAGE = media/docs/doxygen_mainpage.md
#---------------------------------------------------------------------------
# Configuration options related to source browsing
@ -999,7 +1001,7 @@ GENERATE_HTML = YES
# The default directory is: html.
# This tag requires that the tag GENERATE_HTML is set to YES.
HTML_OUTPUT = generated-html
HTML_OUTPUT =
# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each
# generated HTML page (for example: .htm, .php, .asp).
@ -1080,7 +1082,7 @@ HTML_EXTRA_FILES =
# Minimum value: 0, maximum value: 359, default value: 220.
# This tag requires that the tag GENERATE_HTML is set to YES.
HTML_COLORSTYLE_HUE = 82
HTML_COLORSTYLE_HUE = 100
# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors
# in the HTML output. For a value of 0 the output will use grayscales only. A
@ -1088,7 +1090,7 @@ HTML_COLORSTYLE_HUE = 82
# Minimum value: 0, maximum value: 255, default value: 100.
# This tag requires that the tag GENERATE_HTML is set to YES.
HTML_COLORSTYLE_SAT = 100
HTML_COLORSTYLE_SAT = 50
# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the
# luminance component of the colors in the HTML output. Values below 100
@ -1107,7 +1109,7 @@ HTML_COLORSTYLE_GAMMA = 80
# The default value is: YES.
# This tag requires that the tag GENERATE_HTML is set to YES.
HTML_TIMESTAMP = YES
HTML_TIMESTAMP = NO
# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML
# documentation will contain sections that can be hidden and shown after the

View File

@ -1,23 +0,0 @@
Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the NVIDIA CORPORATION nor the
names of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23
LICENSE.txt Normal file
View File

@ -0,0 +1,23 @@
Copyright (c) 2017 - 2020, NVIDIA CORPORATION. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the NVIDIA CORPORATION nor the
names of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

353
README.md
View File

@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")
# CUTLASS 1.2
# CUTLASS 2.2
_CUTLASS 1.2.0 - October 2018_
_CUTLASS 2.2 - June 2020_
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
@ -16,40 +16,46 @@ and applications.
To support a wide variety of applications, CUTLASS provides extensive support for
mixed-precision computations, providing specialized data-movement and
multiply-accumulate abstractions for 8-bit integer, half-precision floating
point (FP16), single-precision floating point (FP32), and double-precision floating
point (FP64) types. Furthermore, CUTLASS demonstrates CUDA's WMMA API for targeting
the programmable, high-throughput _Tensor Cores_ provided by NVIDIA's Volta architecture
and beyond.
multiply-accumulate abstractions for half-precision floating
point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
single-precision floating point (FP32), double-precision floating
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
CUTLASS 1.2 is described in the [CUTLASS Documentation](CUTLASS.md) and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).
We describe the structure of an efficient GEMM in our talk at the
[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply operations
targeting the programmable, high-throughput _Tensor Cores_ implemented by
NVIDIA's Volta, Turing, and Ampere architectures.
# What's New in CUTLASS 1.2
_October 2018_
* [Parallelized Reductions](CUTLASS.md#parallel-reductions-across-gemm-k)
* Batched strided WMMA GEMM
See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
See the [functionality listing](media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.
# What's New in CUTLASS 1.1
_September 2018_
# What's New in CUTLASS 2.2
* [CUTLASS Documentation](CUTLASS.md)
* [Examples](examples/)
* Basic GEMM, tensor views, CUTLASS utilities, batched GEMM, WMMA GEMM
* Turing Features
* [WMMA GEMM targeting TensorCores](tools/test/unit/gemm/wmma_integer_gemm.cu) - INT8, INT4, 1-bit
* [Batched Strided GEMM](tools/test/unit/gemm/batched_strided_sgemm_128x128x8.cu)
* [Threadblock rasterization strategies](tools/test/unit/gemm/sgemm_threadblock_swizzle_nt.cu)
* Improved performance for adverse problem sizes and data layouts
* Extended CUTLASS Core components
* Tensor views support arbitrary matrix and tensor layouts
* Zip iterators for structuring multiple data streams
* Enhanced CUTLASS utilities
* [Reference implementations](tools/util/reference) for tensor operations in [host](tools/util/reference/host) and [device](tools/util/reference/device) code
* Added `HostMatrix<>` for simplified matrix creation
CUTLASS 2.2 is a significant update to CUTLASS adding:
- Coverage of [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
- Tensor Core-accelerated GEMMs targeting Tensor Float 32, BFloat16, and double-precision data types
- Deep software pipelines using asynchronous copy
- Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745)
- Intended to be compiled with [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit)
# What's New in CUTLASS 2.1
CUTLASS 2.1 is a minor update to CUTLASS 2.0 adding:
- [Planar complex GEMM kernels](/examples/10_planar_complex/planar_complex.cu) targeting Volta and Turing Tensor Cores
- BLAS-style API to launch kernels compiled into the [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
# What's New in CUTLASS 2.0
CUTLASS 2.0 is a substantial refactoring from the previous version, intended to offer:
- Better performance over 1.x, particularly for kernels targeting Turing Tensor Cores
- Robust and durable templates that reliably span the design space
- Encapsulated functionality that may be reusable in other contexts
**See the [CHANGELOG](CHANGELOG.md) for more details.**
# Performance
@ -58,13 +64,15 @@ _September 2018_
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
they exhibit performance comparable to cuBLAS for scalar GEMM
computations. The above figure shows CUTLASS performance relative to cuBLAS
for large matrix dimensions (M=10240, N=K=4096) running on an NVIDIA Titan V GPU
when compiled with CUDA 10.0.
for large matrix dimensions on an NVIDIA GeForce 2080 Ti, an NVIDIA A100, and an NVIDIA TitanV
using CUDA 11.0 Toolkit. Tensor Core operations are implemented using CUDA's
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
# Compatibility
CUTLASS performs best when compiled with the [CUDA 10.0 Toolkit](ttps://developer.nvidia.com/cuda-toolkit).
It is compatible with CUDA 9.0, 9.1, and 9.2, but these versions of the CUDA Toolkit do not support new Turing WMMA features.
CUTLASS requires a C++11 host compiler and
performs best when built with the [CUDA 11.0 Toolkit](https://developer.nvidia.com/cuda-toolkit).
It is compatible with CUDA 9.2, CUDA 10.0, CUDA 10.1, and CUDA 10.2.
We have tested the following environments.
@ -72,60 +80,79 @@ We have tested the following environments.
|-----------------|----------|
| Windows 10 | Microsoft Visual Studio 2015|
| | Microsoft Visual Studio 2017|
| Ubuntu 14.04 | GCC 4.8.2 |
| Ubuntu 16.04 | GCC 5.4.0 |
| Ubuntu 18.04 | GCC 7.3.0 |
| Ubuntu 18.04 | GCC 7.5.0 |
Additionally, CUTLASS may be built with clang.
See [these instructions](media/docs/quickstart.md#clang) for more details.
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
any Maxwell-, Pascal-, or Volta-architecture NVIDIA GPU.
any Maxwell-, Pascal-, Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.
|**GPU**|
|---|
|NVIDIA GeForce 1080|
|NVIDIA TitanXP|
|NVIDIA Tesla P100|
|NVIDIA Tesla V100|
|NVIDIA TitanV|
|NVIDIA GeForce RTX 2080 TI, 2080, 2070|
|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**|
|---|---|---|---|
|NVIDIA Tesla P100|6.0|9.2| |
|NVIDIA GeForce 1080|6.1|9.2| |
|NVIDIA TitanXP|6.1|9.2| |
|NVIDIA Tesla V100|7.0|9.2|10.1|
|NVIDIA TitanV|7.0|9.2|10.1|
|NVIDIA GeForce RTX 2080 TI, 2080, 2070|7.5|10.0|10.2|
|NVIDIA Tesla T4|7.5|10.0|10.2|
|NVIDIA A100|8.0|11.0|11.0|
# Documentation
CUTLASS 2.2 is described in the following documents and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).
- [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS
- [Functionality](/media/docs/functionality.md) - summarizes functionality available in CUTLASS
- [Efficient GEMM in CUDA](media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
- [GEMM API](media/docs/gemm_api.md) - describes the CUTLASS GEMM model and C++ template concepts
- [Code Organization](media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project
- [Terminology](media/docs/terminology.md) - describes terms used in the code
- [Programming Guidelines](media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
- [Fundamental types](media/docs/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
- [Layouts](media/docs/layout.md) - describes layouts of matrices and tensors in memory
- [Tile Iterators](media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
- [CUTLASS Profiler](media/docs/profiler.md) - command-line driven profiling application
- [CUTLASS Utilities](media/docs/utilities.md) - additional templates used to facilate rapid development
We have also described the structure of an efficient GEMM in our talk at the
[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
# Building CUTLASS
CUTLASS is a header-only template library and does not need to be built to be used by other
projects. However, we distribute extensive unit tests and utility programs to demonstrate
CUTLASS. These instructions are for building those test programs.
projects. Client applications should target CUTLASS's `include/` directory in their include
paths.
CUTLASS's unit tests depend on Google Test which exists as a git submodule. You can fetch
submodules as follows.
CUTLASS unit tests, examples, and utilities can be build with CMake starting version 3.12.
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
on your system.
```
$ git submodule update --init --recursive
$ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc
```
CUTLASS can be build with CMake starting version 3.10. By default CUTLASS will build kernels
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0 and 7.5. To reduce compile time you can specify
Create a build directory within the CUTLASS project, then run CMake. By default CUTLASS will build kernels
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, and 8.0. To reduce compile time you can specify
the architectures to build CUTLASS for by changing the CMake configuration setting
`CUTLASS_NVCC_ARCHS`.
Create a build directory within the CUTLASS project, then run CMake once.
```
$ mkdir build && cd build
$ cmake ..
$ cmake .. -DCUTLASS_NVCC_ARCHS=75 # compiles for NVIDIA's Turing GPU architecture
```
Compile the CUTLASS project by running Make. Include the -j argument to compile sources in
parallel and speed up the build process.
From the `build/` directory, compile and run the CUTLASS unit tests by building the target `test_unit` with make.
The unit tests are organized as several binaries mirroring the top-level namespaces of CUTLASS,
and they may be executed in parallel via make's `-j` command line argument.
```
$ make -j12
...
$
```
Verify CUTLASS has been built correctly by running the unit tests from the build/ directory.
```
$ ./tools/test/unit/cutlass_unit_test
$ make test_unit -j
...
...
...
@ -134,103 +161,159 @@ $ ./tools/test/unit/cutlass_unit_test
[ PASSED ] 946 tests.
```
All tests should pass, though the exact number of tests may vary over time.
All tests should pass on supported platforms, though the exact number of tests may vary over time.
# Project Structure
CUTLASS is arranged as a header-only library with several example test programs
that demonstrate instantiating a GEMM task within a CUDA kernel. The Doxygen documentation
provides a complete list of files, classes, and template concepts defined in the CUTLASS
project. A brief summary is described below.
CUTLASS is arranged as a header-only library along with Utilities, Tools, Examples, and unit tests.
[Doxygen documentation](https://nvidia.github.io/cutlass) provides a complete list of files, classes,
and template concepts defined in the CUTLASS project.
The CUTLASS library is defined in the cutlass/ directory and consists of CUDA C++ template
classes and other definitions for implementing efficient GPU GEMM kernels. A set of core
classes and templates define basic primitives that are then applied to compute GEMM via
templates in the cutlass/gemm directory.
A detailed explanation of the source code organization may be found in the
[CUTLASS documentation](media/docs/code_organization.md), but several main components are summarized below.
## CUTLASS Template Library
```
cutlass/
gemm/
util/
<core API components>
include/ # client applications should target this directory in their build's include paths
cutlass/ # CUDA Templates for Linear Algebra Subroutines and Solvers - headers only
arch/ # direct exposure of architecture features (including instruction-level GEMMs)
gemm/ # code specialized for general matrix product computations
layout/ # layout definitions for matrices, tensors, and other mathematical objects in memory
platform/ # CUDA-capable Standard Library components
reduction/ # bandwidth-limited reduction kernels that do not fit the "gemm" model
transform/ # code specialized for layout, type, and domain transformations
* # core vocabulary types, containers, and basic numeric operations
```
Several tools and test programs are also distributed with the CUTLASS library. They are
contained in the following directories.
### CUTLASS SDK Examples
[CUTLASS SDK examples](/examples) apply CUTLASS templates to implement basic computations.
```
examples/
00_basic_gemm/
01_tensor_view/
02_cutlass_utilities/
03_batched_gemm/
04_tile_iterator/
05_wmma_gemm/
tools/
test/
unit/
core/
gemm/
perf/
util/
reference/
device/
host/
<utilities>
00_basic_gemm/ # launches a basic GEMM with single precision inputs and outputs
01_cutlass_utilities/ # demonstrates CUTLASS Utilities for allocating and initializing tensors
02_dump_reg_smem/ # debugging utilities for printing register and shared memory contents
03_visualize_layout/ # utility for visualizing all layout functions in CUTLASS
04_tile_iterator/ # example demonstrating an iterator over tiles in memory
05_batched_gemm/ # example demonstrating CUTLASS's batched strided GEMM operation
06_splitK_gemm/ # exmaple demonstrating CUTLASS's Split-K parallel reduction kernel
07_volta_tensorop_gemm/ # example demonstrating mixed precision GEMM using Volta Tensor Cores
08_turing_tensorop_gemm/ # example demonstrating integer GEMM using Turing Tensor Cores
10_planar_complex/ # example demonstrating planar complex GEMM kernels
11_planar_complex_array/ # example demonstrating planar complex kernels with batch-specific problem sizes
12_gemm_bias_relu/ # example demonstrating GEMM fused with bias and relu
13_fused_two_gemms/ # example demonstrating two GEMms fused in one kernel
```
### Tools
```
tools/
library/ # CUTLASS Instance Library - contains instantiations of all supported CUTLASS templates
include/
cutlass/
library/
profiler/ # CUTLASS Profiler - command-line utility for executing operations in the
# CUTLASS Library
util/ # CUTLASS Utilities - contains numerous helper classes for
include/ # manging tensors in device memory, reference
cutlass/ # implementations for GEMM, random initialization
util/ # of tensors, and I/O.
```
### Test
The `test/unit/` directory consist of unit tests implemented with Google Test that demonstrate
basic usage of Core API components and complete tests of the CUTLASS GEMM computations.
The `tools/util` directory contains CUTLASS utilities including reference implementations of GEMM and
several element-wise tensor operations.
Instructions for building and running the Unit tests are described in the [Quickstart guide](media/docs/quickstart.md).
# Performance Profiling
The `test/perf/` directory contains a command-line utility for launching each of the GEMM kernels.
Its usage is shown below.
Program usage:
The `tools/profiler/` directory contains a command-line utility for launching each of the GEMM kernels.
It can be built as follows:
```
cutlass_perf_test [options]
--help
--append=<true|false*> If true, appends output to existing CSV file. If false, overwrites.
--alpha=<alpha> Value for alpha to be used in GEMM experiments
--beta=<beta> Value for beta to be used in GEMM experiments
--dist=<distribution> Describes the random distribution of each of the input matrix operands.
--execution_mode=<mode> Specifies execution mode: profile, verify, single
--output=<filename.csv> Writes summary of profiling to specified .csv file
--iterations=<timing iterations> maximum number of iterations to execute when profiling
--m=<height>[:max height[:step]] Height of GEMM problem (number of rows of C). May specify a range with optional step size.
--n=<width>[:max width[:step]] Width of GEMM problem (number of columns of C). May specify a range with optional step size.
--k=<depth>[:max depth[:step]] Size of inner dimension of A and B. May specify a range with optional step size.
--kernels=<{s|d|h|i|wmma_}gemm_{nn,nt,tn,tt}> Select GEMM datatype and layout to use for tests
--peak=<bool> If true, only reports peak performance per kernel after profiling specified problem space.
--save_workspace={*never,incorrect,always} Specifies when to save the GEMM inputs and results to the filesystem.
--seed=<seed> Random seed used by the random number generator in initializing input matrices.
--tags=<column:tag,...> Inserts leading columns in output table and uniform values for each column.
Example usage:
# Runs one problem size for all kernels
$ ./tools/test/perf/cutlass_perf_test --m=10240 --n=1024 --k=1024
# Varies GEMM K dimension for SGEMM and IGEMM with column-major multiplicands
$ ./tools/test/perf/cutlass_perf_test --m=10240 --n=4096 --k=1024:8192:128 --kernels=sgemm_nn,igemm_nn
$ make cutlass_profiler -j
```
To limit compilation time, only one tile size is instantiated for each data type, math instruction, and layout.
To instantiate all, set the following environment variable when running CMake from an empty `build/` directory.
```
$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=all
...
$ make cutlass_profiler -j
```
Example command line for profiling SGEMM kernels is as follows:
```
$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=3456 --n=4096 --k=4096
=============================
Problem ID: 1
Provider: CUTLASS
OperationKind: gemm
Operation: cutlass_simt_sgemm_128x128_8x2_nn_align1
Status: Success
Verification: ON
Disposition: Passed
cuBLAS: Passed
Arguments: --m=3456 --n=4096 --k=4096 --A=f32:column --B=f32:column --C=f32:column --alpha=1 --beta=0 --split_k_slices=1 \
--batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 --stages=2 --warps_m=4 \
--warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 --max_cc=1024
Bytes: 180355072 bytes
FLOPs: 115992428544 flops
Runtime: 6.73655 ms
Memory: 24.934 GiB/s
Math: 17218.4 GFLOP/s
```
[Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md)
# About
CUTLASS is released by NVIDIA Corporation as Open Source software under the
3-clause "New" BSD license.
CUTLASS is released by NVIDIA Corporation as Open Source software under the
[3-clause "New" BSD license](LICENSE.txt).
# Contributors
The official list of CUTLASS developers and contributors is available here: [CONTRIBUTORS](CONTRIBUTORS.md).
# Copyright
Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
```
Redistribution and use in source and binary forms, with or without modification, are permitted

View File

@ -13,8 +13,8 @@ function(FILE_TO_C_STRING FILENAME VARIABLE_NAME OUTPUT_STRING ZERO_TERMINATED)
set(${OUTPUT_STRING} "${HEX_OUTPUT}" PARENT_SCOPE)
endfunction()
message("Create header file for ${FILE_IN}")
message("Create header file for ${FILE_OUT}")
# message("Create header file for ${FILE_IN}")
# message("Create header file for ${FILE_OUT}")
file_to_c_string(${FILE_IN} ${VARIABLE_NAME} OUTPUT_STRING ZERO_TERMINATED)
set(RESULT "#pragma once\n")

View File

@ -0,0 +1,7 @@
get_filename_component(NvidiaCutlass_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
include(CMakeFindDependencyMacro)
if(NOT TARGET nvidia::cutlass::CUTLASS)
include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake")
endif()

View File

@ -0,0 +1,14 @@
set(CPACK_PACKAGE_NAME NvidiaCutlass)
set(CPACK_PACKAGE_VENDOR NVIDIA)
set(CPACK_PACKAGE_CONTACT info@nvidia.com)
set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "CUTLASS CUDA C++ Template Linear Algebra Library")
set(CPACK_PACKAGE_INSTALL_DIRECTORY ${CPACK_PACKAGE_NAME})
set(CPACK_PACKAGE_VERSION_MAJOR ${PROJECT_VERSION_MAJOR})
set(CPACK_PACKAGE_VERSION_MINOR ${PROJECT_VERSION_MINOR})
set(CPACK_PACKAGE_VERSION_PATCH ${PROJECT_VERSION_PATCH})
set(CPACK_VERBATIM_VARIABLES YES)
# set(CPACK_PACKAGE_DESCRIPTION_FILE ${CMAKE_CURRENT_LIST_DIR}/Description.txt)
# set(CPACK_RESOURCE_FILE_WELCOME ${CMAKE_CURRENT_LIST_DIR}/Welcome.txt)
# set(CPACK_RESOURCE_FILE_LICENSE ${CMAKE_CURRENT_LIST_DIR}/License.txt)
# set(CPACK_RESOURCE_FILE_README ${CMAKE_CURRENT_LIST_DIR}/Readme.txt)
include(CPack)

23
cmake/googletest.cmake Normal file
View File

@ -0,0 +1,23 @@
include(FetchContent)
set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against")
if(GOOGLETEST_DIR)
set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override")
endif()
FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG 0fe9660
)
FetchContent_GetProperties(googletest)
if(NOT googletest_POPULATED)
FetchContent_Populate(googletest)
if (MSVC)
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
endif()
add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL)
endif()

43
cmake/nop.cu Normal file
View File

@ -0,0 +1,43 @@
/***************************************************************************************************
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Basic CUDA file for testing compiler flags.
*/
__device__ int inner()
{
return -1;
}
__global__ void test()
{
inner();
}
int main()
{
test<<<1,1>>>();
return 0;
}

38
cmake/version.h.in Normal file
View File

@ -0,0 +1,38 @@
#include <cstdint>
#include <string>
#define CUTLASS_MAJOR @CUTLASS_VERSION_MAJOR@
#define CUTLASS_MINOR @CUTLASS_VERSION_MINOR@
#define CUTLASS_PATCH @CUTLASS_VERSION_PATCH@
#define CUTLASS_BUILD @CUTLASS_VERSION_BUILD@
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
namespace cutlass {
inline uint32_t getVersion() {
return CUTLASS_VERSION;
}
inline uint32_t getVersionMajor() {
return CUTLASS_MAJOR;
}
inline uint32_t getVersionMinor() {
return CUTLASS_MINOR;
}
inline uint32_t getVersionPatch() {
return CUTLASS_PATCH;
}
inline uint32_t getVersionBuild() {
return CUTLASS_BUILD + 0;
}
inline std::string getVersionString() {
std::string version = "@CUTLASS_VERSION@";
if (getVersionBuild()) {
version += "." + std::to_string(getVersionBuild());
}
return version;
}
inline std::string getGitRevision() {
return "@CUTLASS_REVISION@";
}
} // namespace cutlass

125
cuBLAS.cmake Normal file
View File

@ -0,0 +1,125 @@
message(STATUS "Configuring cublas ...")
if((DEFINED CUTLASS_ENABLE_CUBLAS AND NOT CUTLASS_ENABLE_CUBLAS) OR
(DEFINED CUBLAS_ENABLED AND NOT CUBLAS_ENABLED))
# Don't add cuBLAS if it's defined and false, assume it's not found.
set(CUBLAS_FOUND OFF)
message(STATUS "cuBLAS Disabled.")
elseif(NOT TARGET cublas)
find_path(
_CUBLAS_INCLUDE_DIR
NAMES cublas.h
HINTS
${CUBLAS_INCLUDE_PATH}
ENV CUBLAS_INCLUDE_PATH
${CUBLAS_PATH}
ENV CUBLAS_PATH
${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES
include
)
find_library(
_CUBLAS_LIBRARY
NAMES cublas
HINTS
${CUBLAS_LIBRARY_PATH}
ENV CUBLAS_LIBRARY_PATH
${_CUBLAS_INCLUDE_DIR}/..
${CUBLAS_PATH}
ENV CUBLAS_PATH
${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES
lib64
lib/x64
lib
)
if(_CUBLAS_INCLUDE_DIR AND _CUBLAS_LIBRARY)
message(STATUS "cuBLAS: ${_CUBLAS_LIBRARY}")
message(STATUS "cuBLAS: ${_CUBLAS_INCLUDE_DIR}")
set(CUBLAS_FOUND ON CACHE INTERNAL "cublas Library Found")
set(CUBLAS_LIBRARY ${_CUBLAS_LIBRARY})
set(CUBLAS_INCLUDE_DIR ${_CUBLAS_INCLUDE_DIR})
else()
message(STATUS "cublas not found.")
set(CUBLAS_FOUND OFF CACHE INTERNAL "cublas Library Found")
endif()
endif()
set(CUTLASS_ENABLE_CUBLAS ${CUBLAS_FOUND} CACHE BOOL "Enable CUTLASS to build with cuBLAS library.")
if(CUTLASS_ENABLE_CUBLAS AND NOT CUBLAS_FOUND)
message(FATAL_ERROR "CUTLASS_ENABLE_CUBLAS enabled but cuBLAS library could not be found.")
endif()
if(CUTLASS_ENABLE_CUBLAS AND NOT TARGET cublas)
if(WIN32)
add_library(cublas STATIC IMPORTED GLOBAL)
else()
add_library(cublas SHARED IMPORTED GLOBAL)
endif()
add_library(nvidia::cublas ALIAS cublas)
set_property(
TARGET cublas
PROPERTY IMPORTED_LOCATION
${CUBLAS_LIBRARY})
target_include_directories(
cublas
INTERFACE
$<INSTALL_INTERFACE:include>
$<BUILD_INTERFACE:${CUBLAS_INCLUDE_DIR}>)
find_library(
_CUBLASLT_LIBRARY
NAMES cublasLt
HINTS
${CUBLAS_LIBRARY_PATH}
ENV CUBLAS_LIBRARY_PATH
${_CUBLAS_INCLUDE_DIR}/..
${CUBLAS_PATH}
ENV CUBLAS_PATH
${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES
lib64
lib/x64
lib
)
if(_CUBLASLT_LIBRARY AND NOT TARGET cublasLt)
if(WIN32)
add_library(cublasLt STATIC IMPORTED GLOBAL)
else()
add_library(cublasLt SHARED IMPORTED GLOBAL)
endif()
set_property(
TARGET cublasLt
PROPERTY IMPORTED_LOCATION
${_CUBLASLT_LIBRARY})
add_library(nvidia::cublasLt ALIAS cublasLt)
target_link_libraries(cublas INTERFACE cublasLt)
endif()
endif()
message(STATUS "Configuring cuBLAS ... done.")

View File

@ -1,102 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Defines conversion operations among Fragments of different base type.
*/
#pragma once
#include "cutlass/fragment.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename InputFragment_, typename OutputFragment_>
struct Convert {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename InputScalar_, typename OutputScalar_, int kScalars_>
struct Convert<Fragment<InputScalar_, kScalars_>, Fragment<OutputScalar_, kScalars_> > {
/// The input fragment.
typedef Fragment<InputScalar_, kScalars_> InputFragment;
/// The output fragment.
typedef Fragment<OutputScalar_, kScalars_> OutputFragment;
/// Ctor.
CUTLASS_DEVICE Convert() {}
/// Transform a fragment.
CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
transform(src, 0, dst);
}
/// Transform a fragment.
template <typename Fragment_>
CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
for (int i = 0; i < kScalars_; ++i) {
dst[i] = static_cast<OutputScalar_>(src[i + offset]);
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Fragment_>
struct Copy {
/// The input fragment.
typedef Fragment_ InputFragment;
/// The output fragment.
typedef Fragment_ OutputFragment;
/// Ctor.
CUTLASS_DEVICE Copy() {}
/// Transform a fragment.
CUTLASS_DEVICE void transform(Fragment_ const& src, Fragment_& dst) { transform(src, 0, dst); }
/// Transform a fragment.
template <typename InputFragment_>
CUTLASS_DEVICE void transform(InputFragment_ const& src, int offset, Fragment_& dst) {
if (sizeof(typename Fragment_::Element) == 8) {
uint64_t const* src_ptr = reinterpret_cast<uint64_t const*>(&src[offset]);
uint64_t* dst_ptr = reinterpret_cast<uint64_t*>(&dst[0]);
for (int i = 0; i < sizeof(Fragment_) / 8; ++i) {
dst_ptr[i] = src_ptr[i];
}
} else {
uint32_t const* src_ptr = reinterpret_cast<uint32_t const*>(&src[offset]);
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&dst[0]);
for (int i = 0; i < sizeof(Fragment_) / 4; ++i) {
dst_ptr[i] = src_ptr[i];
}
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,126 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Helpers for printing cutlass/core objects
*/
#pragma once
#include <iosfwd>
#include <typeinfo>
#include "cutlass/coord.h"
#include "cutlass/vector.h"
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
template <int Rank>
std::ostream& operator<<(std::ostream& out, Coord<Rank> const& coord) {
for (int i = 0; i < Rank; ++i) {
out << (i ? ", " : "") << coord.idx[i];
}
return out;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to enable formatted printing of CUTLASS scalar types to an ostream
template <typename T>
struct ScalarIO {
/// Value to print
T value;
/// Default ctor
ScalarIO() { }
/// Constructs from a value
ScalarIO(T value): value(value) {}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Default printing to ostream
template <typename T>
inline std::ostream &operator<<(std::ostream &out, ScalarIO<T> const &scalar) {
return out << scalar.value;
}
/// Printing to ostream of int8_t as integer rather than character
template <>
inline std::ostream &operator<<(std::ostream &out, ScalarIO<int8_t> const &scalar) {
return out << int(scalar.value);
}
/// Printing to ostream of uint8_t as integer rather than character
template <>
inline std::ostream &operator<<(std::ostream &out, ScalarIO<uint8_t> const &scalar) {
return out << unsigned(scalar.value);
}
/// Printing to ostream of vector of 1b elements
template <>
inline std::ostream &operator<<(
std::ostream &out,
ScalarIO<cutlass::Vector<cutlass::bin1_t, 32> > const &scalar) {
for (int i = 0; i < 32; i++) {
out << int(scalar.value[i]);
out << ((i != 31) ? ", " : "");
}
return out;
}
/// Printing to ostream of vector of 4b signed integer elements
template <>
inline std::ostream &operator<<(
std::ostream &out,
ScalarIO<cutlass::Vector<cutlass::int4_t, 8> > const &scalar) {
for (int i = 0; i < 8; i++) {
out << int(scalar.value[i]);
out << ((i != 7) ? ", " : "");
}
return out;
}
/// Printing to ostream of vector of 4b unsigned integer elements
template <>
inline std::ostream &operator<<(
std::ostream &out,
ScalarIO<cutlass::Vector<cutlass::uint4_t, 8> > const &scalar) {
for (int i = 0; i < 8; i++) {
out << unsigned(scalar.value[i]);
out << ((i != 7) ? ", " : "");
}
return out;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,76 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Basic include for CUTLASS macros
*/
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CUTLASS_MAJOR 1
#define CUTLASS_MINOR 2
#define CUTLASS_PATCH 0
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
#ifdef __NVCC__
#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__
#define CUTLASS_DEVICE __forceinline__ __device__
#elif defined(__CUDACC_RTC__)
#define CUTLASS_HOST_DEVICE __forceinline__ __device__
#define CUTLASS_DEVICE __forceinline__ __device__
#else
#define CUTLASS_HOST_DEVICE
// CUTLASS_DEVICE is an error if not compiling device code
#endif
#define CUTLASS_ASSERT(x) assert(x)
#include "cutlass/util/performance_tuning.h"
// A small helper class to dump a type at compile time
// Usage:: DumpType<Class>::Class
template <typename T>
struct DebugType {};
template <typename T>
void DebugTypeFunc(T const& t) {
T::t;
}
// A small helper class to dump a compile time constant at compile time
// Usage: DumpValue<Class::kConstant>::kConstant
template <int Value>
struct DebugValue {};
namespace cutlass {
/// NVIDIA GPU Warp size
static const int kWarpSize = 32;
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -1,276 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines Fragment, a statically-sized array for storing parts of matrices within a
thread's registers.
*/
#pragma once
#include <assert.h>
#include "cutlass/shape.h"
#include "cutlass/util/cutlass_math.h"
#include "cutlass/vector.h"
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
/*!@defgroup fragment_concept Fragment Concept
@{
\ref fragment_concept is a statically sized array for storing parts of tiles held by individual CUDA
threads.
@par \ref fragment_concept
Types satisfying \ref fragment_concept define the following members
- <b>Element</b> - type of each access held within the fragment
- <b>kElements</b> - number of elements stored by the fragment
- <b>clear()</b> - overwrites the fragment storage with zeros
- <b>Element & operator[](int i)</b> - by-reference access of the ith element
- <b>Element const & operator[](int i) const</b> - const by-reference access of the ith element
@}
*/
///////////////////////////////////////////////////////////////////////////////////////////////////
/*!@defgroup fragment_iterator_concept Fragment Iterator Concept
@{
\ref fragment_iterator_concept provides structured access to the elements within a fragment with an
optional bitcast to the desired access type
@par \ref fragment_iterator_concept
Types satisfying \ref fragment_iterator_concept define the following members
- <b>AccessType& operator[](int i)</b> - provides access to the ith element of the fragment
- <b>AccessType& at(int d, int h, int w, int c)</b> - applies \ref layout_concept to fragment and
provides access to element at (d, h, w, c)
@}
*/
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int alignment>
struct StorageType {
typedef uint64_t Type;
};
template <>
struct StorageType<4> {
typedef uint32_t Type;
};
template <>
struct StorageType<2> {
typedef uint16_t Type;
};
template <>
struct StorageType<1> {
typedef uint8_t Type;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/**
* @brief A template defining \ref fragment_concept
* @concept{fragment_concept}
*/
template <typename Element_, int kElements_, size_t kAlignment_ = 16>
struct Fragment : public AlignedStruct<kAlignment_> {
/// Make sure the alignment makes sense wrt the size of elements.
static_assert(kAlignment_ == 16 || kAlignment_ >= sizeof(Element_), "Alignment is too small");
/// Alignment must be a power of two
static_assert(is_pow2<kAlignment_>::value, "Alignment must be a power of two");
/// This class.
typedef Fragment<Element_, kElements_> This_;
/// The element.
typedef Element_ Element;
/// The number of elements.
static int const kElements = kElements_;
/// Alignment
static int const kAlignment = kAlignment_;
/// Clear a fragment.
CUTLASS_HOST_DEVICE void clear() {
// Avoid element-wise access for sub 32b element type
if (kAlignment_ >= 8 && (kElements * sizeof(Element)) % 8 == 0) {
uint64_t* ptr = reinterpret_cast<uint64_t*>(storage);
for (int i = 0; i < (kElements * sizeof(Element)) / 8; ++i) {
ptr[i] = uint64_t(0);
}
} else if (kAlignment_ >= 4 && (kElements * sizeof(Element)) % 4 == 0) {
uint32_t* ptr = reinterpret_cast<uint32_t*>(storage);
for (int i = 0; i < (kElements * sizeof(Element)) / 4; ++i) {
ptr[i] = uint32_t(0);
}
} else if (kAlignment_ >= 2 && (kElements * sizeof(Element)) % 2 == 0) {
uint16_t* ptr = reinterpret_cast<uint16_t*>(storage);
for (int i = 0; i < (kElements * sizeof(Element)) / 2; ++i) {
ptr[i] = uint16_t(0);
}
} else {
for (int i = 0; i < kElements; ++i) {
storage[i] = 0;
}
}
}
/// The accessor.
CUTLASS_HOST_DEVICE Element& operator[](int i) { return reinterpret_cast<Element*>(storage)[i]; }
/// The accessor.
CUTLASS_HOST_DEVICE Element const& operator[](int i) const {
return reinterpret_cast<Element const*>(storage)[i];
}
private:
/// Storage type to use for Elements
typedef typename StorageType<kAlignment_>::Type StorageType;
/// Number of elements in the storage
static int const kStorageCount =
(sizeof(Element_) * kElements_ + sizeof(StorageType) - 1) / sizeof(StorageType);
/// The storage.
StorageType storage[kStorageCount];
/// Ensure that there's enough storage for all elements
static_assert(sizeof(StorageType) <= kAlignment_, "StorageType is too big for given alignment");
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/**
* @brief A template defining \ref fragment_iterator_concept
* @concept{fragment_iterator_concept}
*/
template <typename Fragment_, typename Iterations_, typename AccessType_>
struct FragmentIterator {
/// This class.
typedef FragmentIterator<Fragment_, Iterations_, AccessType_> This_;
/// The fragment.
typedef Fragment_ Fragment;
/// The number of iterations.
typedef Iterations_ Iterations;
/// The access type.
typedef AccessType_ AccessType;
/// The element.
typedef typename Fragment::Element Element;
/// The number of elements per access.
static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
/// The shape of the the fragment.
typedef typename ShapeMul<Iterations, Shape<1, 1, 1, kElementsPerAccess> >::Shape FragmentShape;
/// The linear strides for iterations.
typedef typename ShapeStrides<FragmentShape, kElementsPerAccess>::Shape Strides;
/// Ctor.
template <typename OtherFragment_>
CUTLASS_HOST_DEVICE FragmentIterator(OtherFragment_& fragment, int offset = 0)
: pointer(reinterpret_cast<Element*>(&fragment[offset])) {
static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
}
/// The accessor.
CUTLASS_HOST_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
return reinterpret_cast<AccessType const&>(pointer[imm]);
}
/// The accessor.
CUTLASS_HOST_DEVICE AccessType& at(int d, int h, int w, int c = 0) {
int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
return reinterpret_cast<AccessType&>(pointer[imm]);
}
/// The accessor.
CUTLASS_HOST_DEVICE AccessType const& operator[](int i) const {
return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
}
/// The accessor.
CUTLASS_HOST_DEVICE AccessType& operator[](int i) {
return reinterpret_cast<AccessType&>(pointer[i * kElementsPerAccess]);
}
/// Is the iterator valid?
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
/// The pointer.
Element* pointer;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Fragment_, typename Iterations_, typename AccessType_>
struct FragmentConstIterator {
/// This class.
typedef FragmentIterator<Fragment_, Iterations_, AccessType_> This_;
/// The fragment.
typedef Fragment_ Fragment;
/// The number of iterations.
typedef Iterations_ Iterations;
/// The access type.
typedef AccessType_ AccessType;
/// The element.
typedef typename Fragment::Element Element;
/// The number of elements per access.
static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
/// The shape of the the fragment.
typedef typename ShapeMul<Iterations, Shape<1, 1, 1, kElementsPerAccess> >::Shape FragmentShape;
/// The linear strides for iterations.
typedef typename ShapeStrides<FragmentShape, kElementsPerAccess>::Shape IterationsStrides;
/// Ctor.
template <typename OtherFragment_>
CUTLASS_HOST_DEVICE FragmentConstIterator(OtherFragment_& fragment, int offset = 0)
: pointer(reinterpret_cast<Element const*>(&fragment[offset])) {
static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
}
/// Create from non-constant FragmentIterator
CUTLASS_HOST_DEVICE FragmentConstIterator(
FragmentIterator<Fragment_, Iterations_, AccessType_> const& rhs_)
: pointer(reinterpret_cast<Element const*>(rhs_.offset)) {}
/// The accessor.
CUTLASS_HOST_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
int const imm = ComputeOffsetFromStrides<IterationsStrides>::get(d, h, w, c);
return reinterpret_cast<AccessType const&>(pointer[imm]);
}
/// The accessor.
CUTLASS_HOST_DEVICE AccessType const& operator[](int i) const {
return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
}
/// Is the iterator valid?
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
/// The pointer.
Element const* pointer;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,159 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines multiply-add operations on fragments within a thread.
*/
#pragma once
#include "cutlass/fragment.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template < typename ScalarAlphaBeta_,
typename ScalarAccum_,
bool fragMul2 = true /*number of element per fragment is multiple of 2*/
>
struct FragmentMultiplyAdd {
/// The shape of the instruction.
typedef Shape<1, 1, 1, 1> InstructionShape;
/// The type for alpha and beta
typedef ScalarAlphaBeta_ ScalarAlphaBeta;
/// The type for accumlator
typedef ScalarAccum_ ScalarAccum;
/// Ctor.
CUTLASS_DEVICE FragmentMultiplyAdd() {}
/// Multiply : d = a*b.
template <typename FragmentB_, typename FragmentCd_>
CUTLASS_DEVICE void multiply(ScalarAlphaBeta a, FragmentB_ const& b, FragmentCd_& d) {
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
for (int j = 0; j < FragmentCd_::kElements; ++j) {
d[j] = b[j * kReduction + 0];
for (int k = 1; k < kReduction; ++k) {
d[j] += b[j * kReduction + k];
}
d[j] = a * ScalarAlphaBeta(d[j]);
}
#endif
}
/// Multiply : d = a*b + c.
template <typename FragmentB_, typename FragmentCd_>
CUTLASS_DEVICE void multiply_add(ScalarAlphaBeta a,
FragmentB_ const& b,
FragmentCd_ const& c,
FragmentCd_& d) {
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
for (int j = 0; j < FragmentCd_::kElements; ++j) {
d[j] = b[j * kReduction + 0];
for (int k = 1; k < kReduction; ++k) {
d[j] += b[j * kReduction + k];
}
d[j] = a * ScalarAlphaBeta(d[j]) + ScalarAlphaBeta(c[j]);
}
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
template <>
struct FragmentMultiplyAdd<half, half, true> {
/// The shape of the instruction.
typedef Shape<1, 1, 1, 1> InstructionShape;
/// The type for alpha and beta
typedef half ScalarAlphaBeta;
/// The type for accumlator
typedef half ScalarAccum;
/// Ctor.
CUTLASS_DEVICE FragmentMultiplyAdd() {}
/// Multiply : d = a*b.
template <typename FragmentB_, typename FragmentCd_>
CUTLASS_DEVICE void multiply(half a, FragmentB_ const& b, FragmentCd_& d) {
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
// The input.
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
// The output.
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
// Assemble a half2 from a.
__half2 const a_half2 = __half2half2(a);
int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements);
for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
d_half2[j] = __hmul2(a_half2, b_half2[j * kReduction + 0]);
for (int k = 1; k < kReduction; ++k) {
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
}
}
#endif
}
/// Multiply : d = a*b + c.
template <typename FragmentB_, typename FragmentCd_>
CUTLASS_DEVICE void multiply_add(half a,
FragmentB_ const& b,
FragmentCd_ const& c,
FragmentCd_& d) {
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
// The inputs.
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
__half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);
// The output.
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
// Assemble a half2 from a.
__half2 const a_half2 = __half2half2(a);
int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements);
for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + 0], c_half2[j]);
for (int k = 1; k < kReduction; ++k) {
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
}
}
#endif
}
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,58 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines abstractions for efficiently clearing accumulator tiles.
*/
#pragma once
#include "cutlass/vector.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_, int kLanes_ = 1>
struct ClearAccumulators {
/// The shared storage.
struct SharedStorage {};
/// Ctor.
CUTLASS_DEVICE ClearAccumulators(SharedStorage& shared_storage) {}
/// Ctor.
CUTLASS_DEVICE ClearAccumulators() {}
/// Clear the fragment.
template <typename Fragment_>
CUTLASS_DEVICE void clear(Fragment_& fragment) {
fragment.clear();
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,67 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief device level GEMM implemented by more than one kernels.
*/
#pragma once
#if !defined(__CUDACC_RTC__)
#include <cuda.h>
#endif
#include "cutlass/coord.h"
#include "cutlass/util/platform.h"
namespace cutlass {
namespace gemm {
template<typename DeviceGemmTraits_ >
struct DeviceGemm {
/// The Traits
typedef DeviceGemmTraits_ Traits;
/// Use the params object defined in traits
typedef typename Traits::Params Params;
/// Support for NVRTC
#if !defined(__CUDACC_RTC__)
/// Launch the kernels in order
static __host__ cudaError_t launch(Params const& params) {
Traits::GemmTraits::KernelClass::launch(params.GemmParams);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
return err;
Traits::ReductionTraits::KernelClass::launch(params.ReductionParams);
return cudaGetLastError();
}
#endif
///
/// Methods
///
/// Ctor.
CUTLASS_DEVICE DeviceGemm() {}
};
} // namespace device_gemm
} // namespace cutalss

View File

@ -1,170 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <assert.h>
#include "cutlass/gemm/device_gemm.h"
#include "cutlass/matrix_traits.h"
#include "cutlass/gemm/gemm_desc.h"
#include "tools/util/type_traits.h"
#include <iostream>
namespace cutlass {
namespace gemm {
template <
/// The Tratis for the first kernel
typename GemmTraits_,
/// The Traits for the second kernel
typename ReductionTraits_
>
struct SplitkPIGemmTraits {
typedef GemmTraits_ GemmTraits;
typedef ReductionTraits_ ReductionTraits;
typedef SplitkPIGemmTraits<GemmTraits_, ReductionTraits_> This_;
typedef typename cutlass::gemm::DeviceGemm<This_> KernelClass;
///
typedef typename GemmTraits::Index Index;
///
typedef typename ReductionTraits::ScalarAlphaBeta Scalar;
///
typedef typename GemmTraits::ScalarA ScalarA;
///
typedef typename GemmTraits::ScalarB ScalarB;
///
typedef typename GemmTraits::ScalarD ScalarAccum;
///
typedef typename ReductionTraits::ScalarC ScalarC;
///
typedef typename ReductionTraits::ScalarD ScalarD;
/// The layout of A. can be deduced from the layout set in batched gemm
static MatrixLayout::Kind const kLayoutA = GemmTraits::kLayoutA;
/// The layout of B. can be deduced from the layout set in batched gemm
static MatrixLayout::Kind const kLayoutB = GemmTraits::kLayoutB;
struct Params {
/// The dimensions of the GEMM in K, N, M order
GemmCoord problem_size;
/// Check if params are init
bool problem_size_initialized;
/// The pointer to workspace memory
ScalarAccum *workspace_ptr;
///
int workspace_size;
/// The Params for the first kernel
typename GemmTraits::Params GemmParams;
/// The Params for the second kernel
typename ReductionTraits::Params ReductionParams;
/// ctor
Params() :
workspace_size(0),
problem_size_initialized(false) {}
/// ctor
Params(Index m_,
Index n_,
Index k_
):
problem_size(k_, n_, m_, 1),
workspace_size(0),
problem_size_initialized(true) {
}
/// init problem is needed if using default ctor
void init_problem(Index m_,
Index n_,
Index k_){
problem_size = GemmCoord(k_, n_, m_, 1);
problem_size_initialized = true;
}
int initialize(Scalar alpha_,
ScalarA const* d_a_,
Index lda_,
ScalarB const* d_b_,
Index ldb_,
Scalar beta_,
ScalarC const* d_c_,
Index ldc_,
ScalarD* d_d_,
Index ldd_,
ScalarAccum *workspace_ptr_) {
workspace_ptr = workspace_ptr_;
//call GemmTraits (first kernel) param
//for the first kernel A is A, B is B, C and D are workspace
//alpha is one, beta is zero, partitionK_count is reductionTraits::reductionSize
typename cutlass::gemm::GemmDesc<typename GemmTraits::ScalarA,
typename GemmTraits::ScalarB,
typename GemmTraits::ScalarC,
typename GemmTraits::ScalarD,
typename GemmTraits::Epilogue::Scalar>
desc(
problem_size,
typename cutlass::TypeTraits<typename GemmTraits::Epilogue::Scalar>::host_type(1.0f), /*alpha*/
TensorRef<typename GemmTraits::ScalarA const, 2>(d_a_, lda_),
TensorRef<typename GemmTraits::ScalarB const, 2>(d_b_, ldb_),
typename cutlass::TypeTraits<typename GemmTraits::Epilogue::Scalar>::host_type(0.0f), /*beta*/
TensorRef<typename GemmTraits::ScalarC const, 2>(workspace_ptr, problem_size.m()), /*m = ldc, workspace is not transposed and is packed*/
TensorRef<typename GemmTraits::ScalarD, 2>(workspace_ptr, problem_size.m()) /*m = ldd, workspace is not transposed and is packed*/
);
GemmParams.initialize(desc, ReductionTraits::ReductionSize);
//call batched reduction (second kernel) param
ReductionParams.initialize(problem_size.m(), /*m*/
problem_size.n(), /*n*/
alpha_, /*alpha*/
beta_, /*beta*/
problem_size.n() * problem_size.m() /*reduction_stride*/,
workspace_ptr,
problem_size.m(),
d_c_,
ldc_,
d_d_,
ldd_);
return 0;
}
// workspace will be used to store D (output) from the first gemm kernel (not D of the entire gemm)
// note typedef typename GemmTraits::ScalarD ScalarAccum;
// workspace of size of M * N * Reduction
int required_workspace_memory_in_byte(){
assert(problem_size_initialized == true);
workspace_size = problem_size.n() * problem_size.m() * ReductionTraits::ReductionSize * static_cast<int>(sizeof(ScalarAccum));
return workspace_size;
}
};
};
} // namespace device_gemm
} // namespace cutalss

View File

@ -1,134 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines structural traits of double-precision GEMM.
*/
#pragma once
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/gemm_epilogue.h"
#include "cutlass/gemm/gemm_epilogue_traits.h"
#include "cutlass/gemm/gemm_global_tile.h"
#include "cutlass/gemm/gemm_shared_tile.h"
#include "cutlass/gemm/gemm_traits.h"
#include "cutlass/gemm/thread_multiply_add.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The tile size for threadblock-level GEMM (K-by-N-by-M).
typename OutputTile_,
/// Tile size for thread-level GEMM (K-by-N-by-M)
typename ThreadGemmShape_,
/// The number of scalars per LDG for A.
int kScalarsPerLdgA_ = 1,
/// The number of scalars per LDG for B.
int kScalarsPerLdgB_ = 1>
struct DgemmConfig
: public GemmConfig<
/// The scalar type for A.
double,
/// The scalar type for B.
double,
/// The scalar type for C.
double,
/// The scalar type for D.
double,
/// The tile size for the GEMM KxNxM.
OutputTile_,
/// The functor to do the math in the main loop.
ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, double, double, double>,
/// The number of scalars per LDG for A.
kScalarsPerLdgA_,
/// The number of scalars per STS for A.
kScalarsPerLdgA_,
/// The number of scalars per LDS for A.
2,
/// The number of scalars per LDG for B.
kScalarsPerLdgB_,
/// The number of scalars per STS for B.
kScalarsPerLdgB_,
/// The number of scalars per LDS for B.
2,
/// The number of scalars per LDG for C and STG for D.
1,
/// The number of scalars per STS for D.
2,
/// The number of scalars per LDS for D.
1,
/// The number of stages in shared memory.
2,
/// kResidueSeparate
false,
/// kResidueInPrologue
false,
/// kLaunchBounds
false
>{};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The layout for A.
MatrixLayout::Kind kLayoutA_,
/// The layout for B.
MatrixLayout::Kind kLayoutB_,
/// The tile size for threadblock-level GEMM (K-by-N-by-M)
typename OutputTile_ = Shape<8, 64, 128>,
/// The functor to use in the epilogue.
typename EpilogueFunctor_ = LinearScaling<double>,
/// Tile size for thread-level GEMM (K-by-N-by-M)
typename ThreadGemmShape_ = Shape<8, 8, 8>,
/// The number of doubles loaded in one LDG for A.
int kScalarsPerLdgA_ = 1,
/// The number of doubles loaded in one LDG for B.
int kScalarsPerLdgB_ = 1,
/// The index.
typename Index_ = int,
/// The DGEMM config.
typename GemmConfig_ =
DgemmConfig<OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_>,
/// The traits class for the epilogue.
typename GemmEpilogueTraits_ =
SimplifiedGemmEpilogueTraits<GemmConfig_, EpilogueFunctor_, Index_> >
struct DgemmTraits : public SimplifiedGemmTraits<
// The layout for A.
kLayoutA_,
// The layout for B.
kLayoutB_,
// The config.
GemmConfig_,
// The epilogue.
GemmEpilogue<GemmEpilogueTraits_>,
// The index.
Index_> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,83 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template implementing matrix multiply-add operations on fragments.
*/
#pragma once
#include "cutlass/fragment.h"
#include "cutlass/gemm/thread_multiply_add.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Template performing matrix multiply-add operation within a thread
template <typename ThreadGemmShape_,
typename ThreadsPerWarp_>
struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, half, half, float> {
/// The shape of the instruction.
typedef Shape<1, 1, 1, 1> InstructionShape;
/// The shape of a thread-leveel matrix multiply accumulate.
typedef ThreadGemmShape_ ThreadGemmShape;
/// Aliased to "AccumulatorsPerThread" for compatibility. Expect to be renamed in CUTLASS v2.0
typedef ThreadGemmShape AccumulatorsPerThread;
/// The number of threads per warp.
typedef ThreadsPerWarp_ ThreadsPerWarp;
/// The number of accumulators per warp.
typedef typename ShapeMul<ThreadGemmShape, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
/// The type for A. specialized to half
typedef half ScalarA;
/// The fragment for A.
typedef Fragment<ScalarA, AccumulatorsPerThread::kW> FragmentA;
/// The type for B. specialized to half
typedef half ScalarB;
/// The fragment for B.
typedef Fragment<ScalarB, AccumulatorsPerThread::kH> FragmentB;
/// The type for C and D. specialized to float
typedef float ScalarC;
/// The accumulators.
typedef Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16> Accumulators;
/// Ctor.
CUTLASS_DEVICE ThreadMultiplyAdd() {}
/// Multiply : d = a*b + c.
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
FragmentB const& b,
Accumulators const& c,
Accumulators& d) {
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
d[j * AccumulatorsPerThread::kW + i] = static_cast<ScalarC>(a[i]) * static_cast<ScalarC>(b[j]) + c[j * AccumulatorsPerThread::kW + i];
}
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,152 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defies structural properties of single-precision GEMM where any number of the input/output
could be fp16 or fp32. The accumulator type stays in fp32
*/
#pragma once
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/gemm_epilogue.h"
#include "cutlass/gemm/gemm_epilogue_traits.h"
#include "cutlass/gemm/gemm_global_tile.h"
#include "cutlass/gemm/gemm_shared_tile.h"
#include "cutlass/gemm/gemm_traits.h"
#include "cutlass/gemm/fp16_sgemm_multiply_add.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The tile size for the GEMM KxNxM.
typename OutputTile_,
/// Tile size for thread-level GEMM (K-by-N-by-M)
typename ThreadGemmShape_,
/// The type for A
typename ScalarA_,
/// The type for B
typename ScalarB_,
/// The type for C
typename ScalarC_,
/// The type for D
typename ScalarD_,
/// The number of scalars per LDG for A.
int kScalarsPerLdgA_ = 1,
/// The number of scalars per LDG for B.
int kScalarsPerLdgB_ = 1>
struct Fp16SgemmConfig : public GemmConfig<
/// The scalar type for A.
ScalarA_,
/// The scalar type for B.
ScalarB_,
/// The scalar type for C.
ScalarC_,
/// The scalar type for D.
ScalarD_,
/// The tile size for the GEMM KxNxM.
OutputTile_,
/// The functor to do the math in the main loop.
ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, ScalarA_, ScalarB_, float /*for sgemm accum is float*/>,
/// The number of scalars per LDG for A.
kScalarsPerLdgA_,
/// The number of scalars per STS for A.
kScalarsPerLdgA_,
/// The number of scalars per LDS for A.
4,
/// The number of scalars per LDG for B.
kScalarsPerLdgB_,
/// The number of scalars per STS for B.
kScalarsPerLdgB_,
/// The number of scalars per LDS for B.
4,
/// The number of scalars per LDG for C and STG for D.
1,
/// The number of scalars per STS for D.
4,
/// The number of scalars per LDS for D.
1,
/// The number of stages in shared memory.
2> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The layout for A.
MatrixLayout::Kind kLayoutA_,
/// The layout for B.
MatrixLayout::Kind kLayoutB_,
/// The output tile.
typename OutputTile_ = Shape<8, 128, 128>,
/// The type for A
typename ScalarA_ = half,
/// The type for B
typename ScalarB_ = half,
/// The type for C
typename ScalarC_ = half,
/// The type for D
typename ScalarD_ = half,
/// the Type for alpha and beta,
typename Scalar_ = half,
/// The functor to use in the epilogue.
typename EpilogueFunctor_ = LinearScaling<Scalar_, FragmentMultiplyAdd<Scalar_, float/*accumulator type*/> >,
/// Tile size for thread-level GEMM (K-by-N-by-M)
typename ThreadGemmShape_ = Shape<8, 8, 8>,
/// The number of floats loaded in one LDG for A.
int kScalarsPerLdgA_ = 1,
/// The number of floats loaded in one LDG for B.
int kScalarsPerLdgB_ = 1,
/// The index.
typename Index_ = int,
/// The SGEMM config.
typename GemmConfig_ =
Fp16SgemmConfig<OutputTile_,
ThreadGemmShape_,
ScalarA_,
ScalarB_,
ScalarC_,
ScalarD_,
kScalarsPerLdgA_,
kScalarsPerLdgB_>,
/// The traits class for the epilogue.
typename GemmEpilogueTraits_ =
SimplifiedGemmEpilogueTraits<GemmConfig_, EpilogueFunctor_, Index_> >
struct Fp16SgemmSgemmTraits : public SimplifiedGemmTraits<
// The layout for A.
kLayoutA_,
// The layout for B.
kLayoutB_,
// The config.
GemmConfig_,
// The epilogue.
GemmEpilogue<GemmEpilogueTraits_>,
// The index.
Index_> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,355 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements a software-pipelined efficient GEMM.
*/
#pragma once
#if !defined(__CUDACC_RTC__)
#include <cuda.h>
#endif
#include "cutlass/coord.h"
#include "cutlass/util/platform.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel with launch bounds specified
template <typename Gemm_>
__global__ __launch_bounds__(Gemm_::kThreads)
void gemm_kernel(typename Gemm_::Params params) {
// Declare shared memory.
__shared__ typename Gemm_::SharedStorage shared_storage;
// Construct the GEMM object.
Gemm_ gemm(params, shared_storage);
// Run GEMM.
gemm.multiply_add();
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel without launch bounds specified
template <typename Gemm_>
__global__ /* __launch_bounds__(Gemm_::kThreads) */
void gemm_kernel_nolb(typename Gemm_::Params params) {
// Declare shared memory.
__shared__ typename Gemm_::SharedStorage shared_storage;
// Construct the GEMM object.
Gemm_ gemm(params, shared_storage);
// Run GEMM.
gemm.multiply_add();
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for launching the GEMM kernel with or without launch bounds
template <typename Gemm, bool WithLaunchBounds>
struct Launch {
Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
gemm_kernel<Gemm><<< grid, block, 0, stream >>>(params);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for launching the GEMM kernel with or without launch bounds
template <typename Gemm>
struct Launch<Gemm, false> {
Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
gemm_kernel_nolb<Gemm><<< grid, block, 0, stream >>>(params);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmTraits_>
struct Gemm {
/// This class.
typedef Gemm<GemmTraits_> This_;
/// The traits.
typedef GemmTraits_ Traits;
/// The shared storage.
typedef typename Traits::SharedStorage SharedStorage;
/// The scalar for A.
typedef typename Traits::ScalarA ScalarA;
/// The scalar for B.
typedef typename Traits::ScalarB ScalarB;
/// The scalar in the epilogue.
typedef typename Traits::Epilogue::Scalar ScalarEpilogue;
/// The scalar for C.
typedef typename Traits::Epilogue::ScalarC ScalarC;
/// The scalar for D.
typedef typename Traits::Epilogue::ScalarD ScalarD;
/// The index.
typedef typename Traits::Index Index;
/// Define the mainloop iteration size
typedef typename Traits::MultiplyAdd MultiplyAdd;
/// The number of threads.
static int const kThreads = Traits::GemmConfig::kThreads;
// Number of warp-level multiply-accumulate steps executed by each warp.
static Index const kWarpGemmSteps =
Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;
// Make sure we have at least 2 unrolling steps or our pipeling is not going to work.
static_assert(kWarpGemmSteps >= 2, "The pipelining assumes at least two steps");
/// Use the params object defined in traits
typedef typename Traits::Params Params;
//
// Static function members
//
/// Support for NVRTC
#if !defined(__CUDACC_RTC__)
/// Launch the kernel.
static __host__ cudaError_t launch(Params const& params,
cudaStream_t stream = cudaStreamDefault) {
// Launch the kernel.
Launch<This_, GemmTraits_::GemmConfig::kLaunchBounds>(
params, params.grid, params.block, stream);
return cudaGetLastError();
}
/// Launch the kernel.
static __host__ cudaError_t launch(CUfunction kernel,
Params const& params,
CUstream stream = CU_STREAM_LEGACY) {
// Launch the kernel.
void* params_[] = {const_cast<void*>(reinterpret_cast<void const*>(&params))};
CUresult result = cuLaunchKernel(
kernel,
params.grid.x, params.grid.y, params.grid.z,
params.block.x, params.block.y, params.block.z,
0, stream, params_, 0);
if (result != CUDA_SUCCESS) {
return cudaErrorLaunchFailure;
}
return cudaSuccess;
}
#endif
//
// Methods
//
/// Ctor.
CUTLASS_DEVICE Gemm(Params const& params_, SharedStorage& shared_storage_)
: params(params_), shared_storage(shared_storage_) {}
/// Computes a warp-level GEMM on data held in shared memory
template <bool Residue, bool LastIteration>
CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream,
typename Traits::SharedStream& shared_load_stream,
typename MultiplyAdd::Accumulators& accumulators,
Index outer_k) {
// If residue portion and not calculating residue in prolog, update residue predicates now.
if (Residue && outer_k <= Traits::OutputTile::kD) {
global_to_shared_stream.residue(outer_k);
}
// Load data for the next iteration of the main loop (unless it's the last iteration).
if (!LastIteration) {
global_to_shared_stream.copy();
}
CUTLASS_PRAGMA_UNROLL
for (int step = 0; step < kWarpGemmSteps - 1; ++step) {
// Trigger the copy from shared memory for the next A/B values.
shared_load_stream.copy(step + 1);
// Make sure the values are available for the current iteration to do the multiply-add.
shared_load_stream.commit(step);
MultiplyAdd multiply_add;
// Do the math on the fragments of the current iteration.
multiply_add.multiply_add(shared_load_stream.fragment_a(step),
shared_load_stream.fragment_b(step),
accumulators,
accumulators);
}
// Make sure the data from shared memory has been entirely consumed.
Traits::shared_load_fence(true);
// Commit the data in shared memory for A/B.
if (!LastIteration) {
global_to_shared_stream.commit();
}
// Make sure the data is in shared memory.
Traits::shared_store_fence(true);
if (!LastIteration) {
// Move to the next stage for the load (if it makes sense).
shared_load_stream.inc_stage();
// Trigger the copy from shared memory for the next loop iteration.
shared_load_stream.copy(0);
}
// Make sure the values are available for the current iteration to do the multiply-add.
shared_load_stream.commit(kWarpGemmSteps - 1);
// Do the math on the fragments of the current iteration.
MultiplyAdd multiply_add;
multiply_add.multiply_add(shared_load_stream.fragment_a(kWarpGemmSteps - 1),
shared_load_stream.fragment_b(kWarpGemmSteps - 1),
accumulators,
accumulators);
}
/// Do the GEMM.
CUTLASS_DEVICE void multiply_add() {
// Swizzle the IDs of the block (to enable better cache behavior).
typename Traits::BlockSwizzle block_swizzle;
Coord<3> threadblock_offset =
block_swizzle.get_threadblock_offset(make_Coord_from_shape<Traits::OutputTile>());
// We may want to use shared memory to clear the registers.
typedef typename Traits::ClearAccumulators ClearAccumulators;
// Get the bounds for each thread, it maybe different than problem_size
Coord<3> bounds = block_swizzle.get_threadblock_bounds(params.problem_size,
params.partitionK_range);
// The streams to read A/B from global memory to shared memory.
typename Traits::GlobalLoadStream global_to_shared_stream(
params.global_to_shared_stream,
shared_storage.main_loop.global_to_shared_stream,
shared_storage.main_loop.threadblock_tile.reference(),
bounds,
threadblock_offset);
// update A and B pointer offset based on batch_id and batch_stride_offset
global_to_shared_stream.add_batch_offset(block_swizzle.get_batch_id());
// Create the accumulator clear.
ClearAccumulators clear;
// Deal with residue in prolog.
// global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD);
global_to_shared_stream.move_to_residue(bounds[0], Traits::OutputTile::kD);
// Fetch the fragments for A and B from global memory.
global_to_shared_stream.copy();
// Copy the elements to shared memory (after transformation if needed).
global_to_shared_stream.commit();
// Make sure the data is in shared memory.
Traits::shared_store_fence(false);
// Rollback to the beginning of the first tile (if residue exists).
// global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD);
global_to_shared_stream.rollback(bounds[0] % Traits::OutputTile::kD);
// The stream of data from shared memory to fragments.
typename Traits::SharedStream shared_load_stream(
params.shared_stream,
shared_storage.main_loop.threadblock_tile.reference());
// Trigger the copy from shared memory for the 1st stream.
shared_load_stream.copy(0);
// Allocate the accumulators.
typename MultiplyAdd::Accumulators accumulators;
// Clear the accumulators.
clear.clear(accumulators);
// Initial index
// Index outer_k = params.problem_size[0] - Traits::OutputTile::kD;
// problem_size[0] might be bigger than bounds[0]
Index outer_k = bounds[0] - Traits::OutputTile::kD;
// Check if we are computing residue in prolog or not.
if (Traits::GemmConfig::kResidueInProlog) {
// Execute all mainloop iterations but the last one.
CUTLASS_GEMM_LOOP
for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) {
consume_tile<false, false>(
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
}
// Don't load data for the last "residue" portion since we've already computed the residue.
CUTLASS_GEMM_LOOP
for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
consume_tile<false, true>(
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
}
} else {
// When kResidueSeparate = true, execute all mainloop iterations but the last two without any
// consideration for K-residue or predicate updates. This improves the steady state of some
// kernels.
if (Traits::GemmConfig::kResidueSeparate) {
CUTLASS_GEMM_LOOP
for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
consume_tile<false, false>(
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
}
}
// Execute remaining tiles with K-residue predicate updates enabled.
CUTLASS_GEMM_LOOP
for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
consume_tile<true, false>(
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
}
}
// Epilogue.
typedef typename Traits::Epilogue Epilogue;
Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.problem_size.knm());
epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id());
}
//
// Data members
//
/// The params.
Params const& params;
/// The shared storage.
SharedStorage& shared_storage;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,145 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines properties of GEMM computation that impose some constraints on caller.
*/
#pragma once
#include "cutlass/shape.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The scalar type for A.
typename ScalarA_,
/// The scalar type for B.
typename ScalarB_,
/// The scalar type for C.
typename ScalarC_,
/// The scalar type for D.
typename ScalarD_,
/// The threadblock tile size for the GEMM KxNxM.
typename OutputTile_,
/// The functor to do the math.
typename MultiplyAdd_,
/// The number of scalars per LDG for A.
int kScalarsPerLdgA_,
/// The number of scalars per STS for A.
int kScalarsPerStsA_,
/// The number of scalars per LDG for A.
int kScalarsPerLdsA_,
/// The number of scalars per LDG for B.
int kScalarsPerLdgB_,
/// The number of scalars per STS for B.
int kScalarsPerStsB_,
/// The number of scalars per LDS for B.
int kScalarsPerLdsB_,
/// The number of scalars per LDG for C and STG for D.
int kScalarsPerLdgCAndStgD_,
/// The number of scalars per STS for D.
int kScalarsPerStsD_,
/// The number of scalars per LDS for D.
int kScalarsPerLdsD_,
/// The number of stages in shared memory to do single/double/triple-buffering.
int kStages_,
/// If true, residue is computed in mainloop. If false, separate loops are instantiated.
bool kResidueSeparate_ = false,
/// Is residue performed in prologue?
bool kResidueInProlog_ = false,
/// If true, kernel is launched with CUDA launch bounds specified
bool kLaunchBounds_ = true>
struct GemmConfig {
//
/// The scalar for A.
typedef ScalarA_ ScalarA;
/// The scalar for B.
typedef ScalarB_ ScalarB;
/// The scalar for C.
typedef ScalarC_ ScalarC;
/// The scalar for D.
typedef ScalarD_ ScalarD;
/// The tile.
typedef OutputTile_ OutputTile;
/// The functor to do D = A*B + C.
typedef MultiplyAdd_ MultiplyAdd;
/// The shape of the instruction.
typedef typename MultiplyAdd::InstructionShape InstructionShape;
/// The shape of warp-level GEMM
typedef typename MultiplyAdd::AccumulatorsPerWarp AccumulatorsPerWarp;
/// The accumulators.
typedef typename MultiplyAdd::Accumulators Accumulators;
/// The number of warps.
typedef typename ShapeDiv<OutputTile, AccumulatorsPerWarp>::Shape Warps;
/// The default warp size (32 threads per warp).
static int const kWarpSize = cutlass::kWarpSize;
/// The numnber of threads.
static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;
/// The number of scalars per LDG/STS/LDS for A.
static int const kScalarsPerLdgA = kScalarsPerLdgA_;
static int const kScalarsPerStsA = kScalarsPerStsA_;
static int const kScalarsPerLdsA = kScalarsPerLdsA_;
/// The number of scalars per LDG/STS/LDS for B.
static int const kScalarsPerLdgB = kScalarsPerLdgB_;
static int const kScalarsPerStsB = kScalarsPerStsB_;
static int const kScalarsPerLdsB = kScalarsPerLdsB_;
/// The number of scalars per LDG for C.
static int const kScalarsPerLdgC = kScalarsPerLdgCAndStgD_;
/// The number of scalars per STS/LDS/STG for D.
static int const kScalarsPerStgD = kScalarsPerLdgCAndStgD_;
static int const kScalarsPerStsD = kScalarsPerStsD_;
static int const kScalarsPerLdsD = kScalarsPerLdsD_;
/// The number of accumulators that are going to be fed from one LDS A/B.
static int const kAccumulatorsPerLdsA = kScalarsPerLdsA / InstructionShape::kD;
static int const kAccumulatorsPerLdsB = kScalarsPerLdsB / InstructionShape::kD;
/// The number of stages in shared memory to implement double, triple, more-buffering.
static int const kStages = kStages_;
/// If true, mainloop is instantiated twice. The first instantiation contains no predicate
// updates and is more efficient for some kernels. If false, only a single mainloop is
// instantaited.
static bool const kResidueSeparate = kResidueSeparate_;
/// If true, residue is computed in the prologue.
static bool const kResidueInProlog = kResidueInProlog_;
/// If true, kernel is launched with launch bounds specified
static bool const kLaunchBounds = kLaunchBounds_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,209 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief GemmCoord is a structure derived from Coord<4> that specifies a location within the
coordinate system of a GEMM problem.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/coord.h"
#include "cutlass/util/platform.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// GemmCoord is a structure derived from Coord<4> that specifies a location within the
/// coordinate space of a GEMM problem.
struct GemmCoord : public Coord<4, int> {
/// Integer-valued index
typedef int Index;
/// Base type is a Coord of rank=4
typedef Coord<4, Index> Base;
/// GEMM K dimension - inner dimension of the GEMM problem
static int const kK = 0;
/// GEMM N dimension - columns of the output C matrix
static int const kN = 1;
/// GEMM M dimension - rows of the output C matrix
static int const kM = 2;
/// Batch dimension - for generalizing to larger problems
static int const kBatch = 3;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
GemmCoord() { }
/// Constructs from Coord<3> and a batch
CUTLASS_HOST_DEVICE
GemmCoord(Coord<3, Index> const &coord, Index _batch = 0): Base(make_Coord(coord[0], coord[1], coord[2], _batch)) { }
/// Constructs from Coord<4>
CUTLASS_HOST_DEVICE
GemmCoord(Coord<4, Index> const &coord): Base(coord) { }
/// Constructs from an array of coordinate elements
CUTLASS_HOST_DEVICE
GemmCoord(Index coord[4]): Base(coord) { }
/// Helper to construct from a K, N, M, batch variables
CUTLASS_HOST_DEVICE
GemmCoord(Index k, Index n, Index m, Index batch = 0): Base(make_Coord(k, n, m, batch)) { }
/// Returns the GEMM M coordinate
CUTLASS_HOST_DEVICE
Index const & m() const { return this->at(kM); }
/// Returns reference to the GEMM M coordinate
CUTLASS_HOST_DEVICE
Index & m() { return this->at(kM); }
/// Returns the GEMM N coordinate
CUTLASS_HOST_DEVICE
Index const & n() const { return this->at(kN); }
/// Returns reference to the GEMM N coordinate
CUTLASS_HOST_DEVICE
Index & n() { return this->at(kN); }
/// Returns the GEMM K coordinate
CUTLASS_HOST_DEVICE
Index const & k() const { return this->at(kK); }
/// Returns reference to the GEMM K coordinate
CUTLASS_HOST_DEVICE
Index & k() { return this->at(kK); }
/// Returns the GEMM batch coordinate
CUTLASS_HOST_DEVICE
Index const & batch() const { return this->at(kBatch); }
/// Returns reference to the GEMM batch coordinate
CUTLASS_HOST_DEVICE
Index & batch() { return this->at(kBatch); }
/// Obtains a Coord<3> from GemmCoord
CUTLASS_HOST_DEVICE
Coord<3> knm() const {
return make_Coord(k(), n(), m());
}
/// Obtains a Coord<2> from GemmCoord
CUTLASS_HOST_DEVICE
Coord<2> nm() const {
return make_Coord(n(), m());
}
/// Obtains a Coord<2> from GemmCoord
CUTLASS_HOST_DEVICE
Coord<2> mn() const {
return make_Coord(m(), n());
}
/// Obtains a Coord<2> from GemmCoord
CUTLASS_HOST_DEVICE
Coord<2> km() const {
return make_Coord(k(), m());
}
/// Obtains a Coord<2> from GemmCoord
CUTLASS_HOST_DEVICE
Coord<2> kn() const {
return make_Coord(k(), n());
}
//
// Coord operators
//
/// Element-wise addition
CUTLASS_HOST_DEVICE
GemmCoord operator+(Base const& b) const {
return GemmCoord(Base::operator+(b));
}
/// Element-wise subtraction
CUTLASS_HOST_DEVICE
GemmCoord operator-(Base const& b) const {
return GemmCoord(Base::operator-(b));
}
/// Element-wise multiplication
CUTLASS_HOST_DEVICE
GemmCoord operator*(Base const& b) const {
return GemmCoord(Base::operator*(b));
}
/// Element-wise division
CUTLASS_HOST_DEVICE
GemmCoord operator/(Base const& b) const {
return GemmCoord(Base::operator/(b));
}
/// In-place addition
CUTLASS_HOST_DEVICE
GemmCoord& operator+=(Base const& b) {
Base::operator+=(b);
return *this;
}
/// In-place subtraction
CUTLASS_HOST_DEVICE
GemmCoord& operator-=(Base const& b) {
Base::operator-=(b);
return *this;
}
/// In-place multiplication
CUTLASS_HOST_DEVICE
GemmCoord& operator*=(Base const& b) {
Base::operator*=(b);
return *this;
}
/// In-place division
CUTLASS_HOST_DEVICE
GemmCoord& operator/=(Base const& b) {
Base::operator/=(b);
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,205 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements a software-pipelined efficient GEMM.
*/
#pragma once
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/gemm_coord.h"
namespace cutlass {
namespace gemm {
/// GEMM problem description
template <
/// Source accumulator matrix type
typename AType_,
/// Destination accumulator type
typename BType_,
/// Source accumulator matrix type
typename CType_,
/// Destination accumulator type
typename DType_,
/// Scalar type for alpha and beta
typename SType_,
/// Index type for dimensions and strides
typename Index_ = int
> struct GemmDesc {
//
// Type definitions
//
/// Index type for dimensions and strides
typedef Index_ Index;
/// Source accumulator matrix type
typedef AType_ AType;
/// Tensor reference to A operand
typedef TensorRef<AType const, 2> TensorRefA;
/// Destination accumulator type
typedef BType_ BType;
/// Tensor reference to B operand
typedef TensorRef<BType const, 2> TensorRefB;
/// Source accumulator matrix type
typedef CType_ CType;
/// Tensor reference to C operand
typedef TensorRef<CType const, 2> TensorRefC;
/// Destination accumulator type
typedef DType_ DType;
/// Tensor reference to D operand
typedef TensorRef<DType, 2> TensorRefD;
/// Scalar type for alpha and beta
typedef SType_ SType;
//
// Data members
//
/// The dimensions of the GEMM.
GemmCoord problem_size;
/// The alpha scaling values.
SType alpha;
/// The source matrix A.
TensorRefA A;
/// batch stride for A operand
long long batch_stride_A;
/// The source matrix B.
TensorRefB B;
/// batch stride for B operand
long long batch_stride_B;
/// The beta scaling values.
SType beta;
/// The source matrix C.
TensorRefC C;
/// batch stride for C operand
long long batch_stride_C;
/// The destination matrix D.
TensorRefD D;
/// batch stride for D operand
long long batch_stride_D;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
GemmDesc(): problem_size(0, 0, 0, 1), alpha(1), beta(0) {}
/// Constructor for basic GEMM with batch count = 1
CUTLASS_HOST_DEVICE
GemmDesc(Coord<3> _problem_size,
SType _alpha,
TensorRefA const &_A,
TensorRefB const &_B,
SType _beta,
TensorRefC const &_C,
TensorRefD const &_D
):
problem_size(_problem_size[0], _problem_size[1], _problem_size[2], 1),
alpha(_alpha),
A(_A),
batch_stride_A(0),
B(_B),
batch_stride_B(0),
beta(_beta),
C(_C),
batch_stride_C(0),
D(_D),
batch_stride_D(0) {}
/// Constructor for basic GEMM with batch count = 1
CUTLASS_HOST_DEVICE
GemmDesc(GemmCoord _problem_size,
SType _alpha,
TensorRefA const &_A,
TensorRefB const &_B,
SType _beta,
TensorRefC const &_C,
TensorRefD const &_D
):
problem_size(_problem_size.k(), _problem_size.n(), _problem_size.m(), 1),
alpha(_alpha),
A(_A),
batch_stride_A(0),
B(_B),
batch_stride_B(0),
beta(_beta),
C(_C),
batch_stride_C(0),
D(_D),
batch_stride_D(0) {
assert(_problem_size.batch() == 1);
}
/// Constructor for strided batch GEMM GEMM
CUTLASS_HOST_DEVICE
GemmDesc(GemmCoord _problem_size,
SType _alpha,
TensorRefA const &_A,
long long _batch_stride_A,
TensorRefB const &_B,
long long _batch_stride_B,
SType _beta,
TensorRefC const &_C,
long long _batch_stride_C,
TensorRefD const &_D,
long long _batch_stride_D
):
problem_size(_problem_size),
alpha(_alpha),
A(_A),
batch_stride_A(_batch_stride_A),
B(_B),
batch_stride_B(_batch_stride_B),
beta(_beta),
C(_C),
batch_stride_C(_batch_stride_C),
D(_D),
batch_stride_D(_batch_stride_D) {}
};
} // namespace gemm
} // namespace cutlass

View File

@ -1,223 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory
with
the computed matrix product.
*/
#pragma once
#include "cutlass/convert.h"
#include "cutlass/coord.h"
#include "cutlass/fragment.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmEpilogueTraits_>
struct GemmEpilogue {
/// The traits class.
typedef GemmEpilogueTraits_ Traits;
/// The params.
typedef typename Traits::Params Params;
/// The shared storage.
typedef typename Traits::SharedStorage SharedStorage;
/// The output tile.
typedef typename Traits::OutputTile OutputTile;
/// The number of iterations.
typedef typename Traits::Iterations Iterations;
/// The accumulators.
typedef typename Traits::Accumulators Accumulators;
/// The scalar.
typedef typename Traits::Scalar Scalar;
/// The functor in charge of the math.
typedef typename Traits::Functor Functor;
/// We do not support 3D or 4D shapes.
static_assert(Iterations::kD == 1 && Iterations::kC == 1, "Unsupported 3D/4D shapes");
/// The iterator for C in global memory.
typedef typename Traits::GlobalLoadIteratorC GlobalLoadIteratorC;
/// The transformer for C.
typedef typename Traits::GlobalTransformerC GlobalTransformerC;
/// The transformer for D.
typedef typename Traits::GlobalTransformerD GlobalTransformerD;
/// The iterator for D in global memory.
typedef typename Traits::GlobalStoreIteratorD GlobalStoreIteratorD;
/// The iterator to store D in shared memory.
typedef typename Traits::SharedStoreIteratorD SharedStoreIteratorD;
/// The shared store transformer for D.
typedef typename Traits::SharedStoreTransformerD SharedStoreTransformerD;
/// The iterator to load D in shared memory.
typedef typename Traits::SharedLoadStreamD SharedLoadStreamD;
/// The index.
typedef typename Traits::Index Index;
/// The scalar for C.
typedef typename GlobalLoadIteratorC::Scalar ScalarC;
/// The scalar for D.
typedef typename GlobalStoreIteratorD::Scalar ScalarD;
/// Ctor.
CUTLASS_DEVICE GemmEpilogue(Params const& params_,
SharedStorage& shared_storage_,
Coord<3> const& _problem_size)
: params(params_), shared_storage(shared_storage_), problem_size(_problem_size), functor(params_.functor) {}
/// Execute the epilogue.
CUTLASS_DEVICE void epilogue(Accumulators& accumulators,
Coord<3> const& block = make_Coord(0, 0, 0),
int batch_id = 0) {
if (functor.source_required()) {
epilogue_with_or_without_beta<true>(accumulators, block, batch_id);
} else {
epilogue_with_or_without_beta<false>(accumulators, block, batch_id);
}
}
template <bool kSourceRequired>
CUTLASS_DEVICE void epilogue_with_or_without_beta(Accumulators& accumulators,
Coord<3> const& block,
int batch_id) {
// The C fragment.
typename GlobalLoadIteratorC::Fragment fragment_c;
// The transformed C fragment.
typename GlobalTransformerC::OutputFragment transformed_c;
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
// Compute pointer and predicate offsets for C and D global iterators.
int const pointer_offset =
((params.iterator_d.inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
params.iterator_d.inc_advance) *
Iterations::kW +
params.stride_h) *
h;
int const predicate_offset =
((params.iterator_d.predicate_inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
params.iterator_d.predicate_inc_advance) *
Iterations::kW +
Traits::Delta::kH) *
h;
// The iterator to load the elements of the C matrix.
GlobalLoadIteratorC global_load_iterator(
params.iterator_c, problem_size, block, pointer_offset, predicate_offset);
// update C pointer offset based on batch_id and batch_stride_offset
global_load_iterator.add_pointer_offset(batch_id * params.batch_stride_C);
// The transformer for C.
GlobalTransformerC transformer_c;
// The transformer for D.
GlobalTransformerD transformer_d;
// The iterator to store into the D matrix.
GlobalStoreIteratorD global_store_iterator(
params.iterator_d, problem_size, block, pointer_offset, predicate_offset);
// update D pointer offset based on batch_id and batch_stride_offset
global_store_iterator.add_pointer_offset(batch_id * params.batch_stride_D);
SharedStoreTransformerD shared_store_transformer;
typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d;
SharedStoreIteratorD shared_store_iterator(
params.shared_store_iterator_d,
reinterpret_cast<typename SharedStoreIteratorD::Scalar*>(shared_storage.data()));
SharedLoadStreamD shared_load_stream(
params.shared_load_stream_d,
reinterpret_cast<typename SharedLoadStreamD::Scalar*>(shared_storage.data()));
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
// Load the C matrix into fragment.
if (kSourceRequired) {
global_load_iterator.load_post_increment(fragment_c);
}
// Make sure we can write to shared memory.
shared_load_fence();
// Copy the accumulators to shared memory.
int const offset = (h * Iterations::kW + w) * SharedStoreIteratorD::Fragment::kElements;
shared_store_transformer.transform(accumulators, offset, shared_store_transformed_d);
shared_store_iterator.store_post_increment(shared_store_transformed_d);
// Make sure the data is in shared memory.
shared_store_fence();
// Copy the accumulators back to registers from shared memory.
shared_load_stream.copy();
shared_load_stream.commit();
// Do the math.
typename GlobalTransformerD::InputFragment fragment_d;
if (kSourceRequired) {
// Transform C fragment.
transformer_c.transform(fragment_c, transformed_c);
// Do the math.
functor.evaluate(shared_load_stream.fragment(), transformed_c, fragment_d);
} else {
functor.evaluate(shared_load_stream.fragment(), fragment_d);
}
// Transform D fragment.
typename GlobalTransformerD::OutputFragment global_transformed_d;
transformer_d.transform(fragment_d, global_transformed_d);
// Copy the results to global memory.
global_store_iterator.store_post_increment(global_transformed_d);
}
}
}
/// The memory fence for shared loads.
CUTLASS_DEVICE void shared_load_fence() { __syncthreads(); }
/// The memory fence for shared stores.
CUTLASS_DEVICE void shared_store_fence() { __syncthreads(); }
/// The params.
Params const& params;
/// The shared storage.
SharedStorage& shared_storage;
/// The dimensions of the GEMM.
Coord<3> problem_size;
// The functor.
Functor functor;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,371 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines structural properties of the GEMM epilogue.
*/
#pragma once
#include "cutlass/convert.h"
#include "cutlass/coord.h"
#include "cutlass/gemm/gemm_global_stream.h"
#include "cutlass/gemm/gemm_shared_stream.h"
#include "cutlass/gemm/linear_scaling.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/tile_iterator.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The output tile.
typename OutputTile_,
/// The accumulators.
typename Accumulators_,
/// The iterator to load C from global memory.
typename GlobalLoadIteratorC_,
/// The transformer for C.
typename GlobalTransformerC_,
/// The transformer for D.
typename GlobalTransformerD_,
/// The iterator to store D to global memory.
typename GlobalStoreIteratorD_,
/// The iterator to store D to shared memory.
typename SharedStoreIteratorD_,
/// The shared store transformer for D.
typename SharedStoreTransformerD_,
/// The stream to load D from shared memory.
typename SharedLoadStreamD_,
/// The number of iterations in the epilogue.
typename Iterations_,
/// The iterations strides.
typename Delta_,
/// The functor to be used in the epilogue.
typename Functor_,
/// The index.
typename Index_ = int>
struct GemmEpilogueTraits {
//
/// The output tile.
typedef OutputTile_ OutputTile;
/// The number of iterations.
/// The accumulators.
typedef Accumulators_ Accumulators;
/// The iterator for C in global memory.
typedef GlobalLoadIteratorC_ GlobalLoadIteratorC;
/// The transformer for C.
typedef GlobalTransformerC_ GlobalTransformerC;
/// The transformer for D.
typedef GlobalTransformerD_ GlobalTransformerD;
/// The iterator for D in global memory.
typedef GlobalStoreIteratorD_ GlobalStoreIteratorD;
/// The iterator to store D in shared memory.
typedef SharedStoreIteratorD_ SharedStoreIteratorD;
/// The shared store transformer for D.
typedef SharedStoreTransformerD_ SharedStoreTransformerD;
/// The stream to store D in shared memory.
typedef SharedLoadStreamD_ SharedLoadStreamD;
/// typedef typename GemmConfig::EpilogueIterations Iterations;
typedef Iterations_ Iterations;
/// The iterations strides.
typedef Delta_ Delta;
/// The functor in charge of the math.
typedef Functor_ Functor;
/// The index.
typedef Index_ Index;
/// The long index
typedef long long LongIndex;
/// We do not support 3D or 4D shapes.
static_assert(Iterations::kD == 1 && Iterations::kC == 1, "Unsupported 3D/4D shapes");
/// The scalar.
typedef typename Functor::Scalar Scalar;
/// The scalar for C.
typedef typename GlobalLoadIteratorC::Scalar ScalarC;
/// The scalar for D.
typedef typename GlobalStoreIteratorD::Scalar ScalarD;
/// The params.
struct Params {
/// The strides for H and W in the different iterations of the epilogue.
Index stride_h, stride_w;
/// The params for the C iterator.
typename GlobalLoadIteratorC::Params iterator_c;
/// Batch stride for C matrix
LongIndex batch_stride_C;
/// The params for the D global iterator.
typename GlobalStoreIteratorD::Params iterator_d;
/// Batch stride for C matrix
LongIndex batch_stride_D;
/// The params for the D shared store iterator.
typename SharedStoreIteratorD::Params shared_store_iterator_d;
/// The params for the D shared load stream.
typename SharedLoadStreamD::Params shared_load_stream_d;
/// The functor params.
typename Functor::Params functor;
/// Setup the params.
template <typename GemmDesc_>
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
// The parameters for the functor.
int error_code = functor.initialize(desc);
if (error_code) {
return error_code;
}
// At the end of the H iteration, we jump over a number of columns.
this->stride_h = desc.D.leading_dim() * Delta::kH;
// Nothing to do here.
this->stride_w = 0;
// Setup the params for the global memory iterator for C.
error_code = iterator_c.initialize(desc.C.data(),
desc.C.leading_dim(),
desc.C.leading_dim(),
desc.problem_size[1],
stride_w,
Delta::kW);
batch_stride_C = desc.batch_stride_C;
if (error_code) {
return error_code;
}
// Setup the params for the global memory iterator for D.
error_code = iterator_d.initialize(desc.D.data(),
desc.D.leading_dim(),
desc.D.leading_dim(),
desc.problem_size[1],
stride_w,
Delta::kW);
batch_stride_D = desc.batch_stride_D;
return error_code;
}
};
/// The shared memory storage to exchange data.
union StreamSharedStorage {
// The storage for the store iterator.
typename SharedStoreIteratorD::SharedStorage store;
// The storage for the store iterator.
typename SharedLoadStreamD::SharedStorage load;
};
/// The shared memory to swizzle the data in the epilogue.
struct SharedStorage {
// The storage for the shared stream D.
StreamSharedStorage shared_stream;
//
//
//
CUTLASS_DEVICE
ScalarD* data() { return reinterpret_cast<ScalarD*>(&shared_stream.load); }
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmConfig_, typename EpilogueFunctor_, typename Index_ = int>
struct GemmEpilogueTraitsHelper {
/// The scalar.
typedef typename EpilogueFunctor_::Scalar Scalar;
/// The output tile.
typedef typename GemmConfig_::OutputTile OutputTile;
/// The number of iterations in the epilogue.
typedef Shape<1,
GemmConfig_::MultiplyAdd::AccumulatorsPerThread::kH /
GemmConfig_::kAccumulatorsPerLdsB,
GemmConfig_::kAccumulatorsPerLdsB>
Iterations;
// The iteration strides in the H/W dimension.
typedef Shape<0,
GemmConfig_::kAccumulatorsPerLdsB*(
GemmConfig_::Warps::kH* GemmConfig_::MultiplyAdd::ThreadsPerWarp::kH - 1),
0>
Delta;
/// The functor to do the math in the epilogue.
typedef EpilogueFunctor_ Functor;
/// The traits class to build the iterator to store to shared memory for D.
typedef GemmSharedStoreTileDTraits<
// The pointer is float.
// typename Functor::Scalar,
// Functor::Scalar is alpha, beta type, in mixed precision, alpha and beta may not be the same with accumulation.
// In this case Functor::ScalarAccum is needed
typename Functor::ScalarAccum,
// The output tile size.
typename GemmConfig_::OutputTile,
// The number of warps.
typename GemmConfig_::Warps,
// The number of threads per warp.
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
// The number of scalars per STS.
GemmConfig_::kScalarsPerStsD,
// The skew -- 128 / sizeof(ScalarD) / kScalarsPerStsD is the number of threads involved in
// a single STS. We divide by 2 as our objective is to add a skew to the odd threads to
// avoid bank conflicts between odd and even threads.
128 / sizeof(typename GemmConfig_::ScalarD) / GemmConfig_::kScalarsPerStsD / 2 *
GemmConfig_::kScalarsPerStsD>
SharedStoreTileTraits;
/// The iterator to store D to shared memory.
typedef TileStoreIterator<SharedStoreTileTraits,
typename SharedStoreTileTraits::Scalar,
IteratorAdvance::kH,
MemorySpace::kShared>
SharedStoreIteratorD;
/// The shared store transformer for D.
typedef Copy<typename SharedStoreIteratorD::Fragment> SharedStoreTransformerD;
/// The traits class to build the iterator to load from shared memory for D.
typedef GemmSharedLoadTileDTraits<
// The pointer is float.
// typename Functor::Scalar,
// Functor::Scalar is alpha, beta type, in mixed precision, alpha and beta may not be the same with accumulation.
// In this case Functor::ScalarAccum is needed
typename Functor::ScalarAccum,
// The output tile size.
typename GemmConfig_::OutputTile,
// The number of warps.
typename GemmConfig_::Warps,
// The number of threads per warp.
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
// The number of columns of the output tile written by iteration.
GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
// The number of scalars per LDS.
GemmConfig_::kScalarsPerLdsD,
// The skew.
SharedStoreTileTraits::kSkew>
SharedLoadTileTraits;
/// The iterator to load D from shared memory.
typedef TileLoadIterator<SharedLoadTileTraits,
typename SharedLoadTileTraits::Scalar,
IteratorAdvance::kH,
MemorySpace::kShared>
SharedLoadIteratorD;
/// The stream to load D.
typedef SharedLoadStream<SharedLoadIteratorD> SharedLoadStreamD;
/// The traits class to build the iterator to load data from global memory for C^N.
typedef GemmGlobalTileCdTraits<
// The pointer is float const.
typename GemmConfig_::ScalarC const,
// The tile has size (N / Iterations)xM in GEMM's terminology.
Shape<1,
GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
GemmConfig_::OutputTile::kW>,
// The threads are distributed as warps x 32 (the traits may reorganize).
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
// How many elements do we jump over at each iteration?
Iterations::kW,
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
GemmConfig_::kScalarsPerLdgC>
GlobalLoadTileTraits;
/// The iterator to load C.
typedef GemmGlobalIteratorCd<GlobalLoadTileTraits, Index_> GlobalLoadIteratorC;
/// The transformer for C.
typedef Copy<typename GlobalLoadIteratorC::Fragment> GlobalTransformerC;
/// The traits class to build the iterator to store data to global memory for D^N.
typedef GemmGlobalTileCdTraits<
// The pointer is float.
typename GemmConfig_::ScalarD,
// The tile has size (N / Iterations)xM in GEMM's terminology.
Shape<1,
GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
GemmConfig_::OutputTile::kW>,
// The threads are distributed as warps x 32 (the traits may reorganize).
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
// How many elements do we jump over at each iteration?
Iterations::kW,
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
GemmConfig_::kScalarsPerStgD>
GlobalStoreTileTraits;
/// The iterator to store D.
typedef GemmGlobalIteratorCd<GlobalStoreTileTraits, Index_> GlobalStoreIteratorD;
/// The transformer for D.
typedef Copy<typename GlobalStoreIteratorD::Fragment> GlobalTransformerD;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The GEMM config.
typename GemmConfig_,
/// The epilogue functor to do the math in the epilogue.
typename EpilogueFunctor_,
/// The index.
typename Index_ = int,
/// The helper to create the traits class.
typename Helper_ = GemmEpilogueTraitsHelper<GemmConfig_, EpilogueFunctor_, Index_> >
struct SimplifiedGemmEpilogueTraits : public GemmEpilogueTraits<
// The output tile.
typename GemmConfig_::OutputTile,
// The accumulators.
typename GemmConfig_::Accumulators,
// The global iterator for C.
typename Helper_::GlobalLoadIteratorC,
// The transformer for C.
typename Helper_::GlobalTransformerC,
// The transformer for D.
typename Helper_::GlobalTransformerD,
// The global iterator for D.
typename Helper_::GlobalStoreIteratorD,
// The iterator to store D to shared memory.
typename Helper_::SharedStoreIteratorD,
// The shared store transformer for D.
typename Helper_::SharedStoreTransformerD,
// The stream to load D from shared memory.
typename Helper_::SharedLoadStreamD,
// The number of iterations.
typename Helper_::Iterations,
// The strides between iterations.
typename Helper_::Delta,
// The functor to be used in the epilogue.
EpilogueFunctor_,
// The index.
Index_> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,255 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements efficient loading of the thread block-level tile from global memory and
storing
to shared memory.
*/
#pragma once
#include "cutlass/coord.h"
#include "cutlass/convert.h"
#include "cutlass/gemm/gemm_global_tile.h"
#include "cutlass/tile_allocation.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Identifies multiplicand
GemmOperand::Kind Operand,
/// The load iterator.
typename LoadIterator_,
/// The store iterator to copy to shared memory.
typename StoreIterator_,
/// The transformer to be applied after the data has been copied from global memory.
typename Transformer_>
struct GlobalLoadStream {
/// Indicates the type of GEMM operand
static GemmOperand::Kind const kOperand = Operand;
/// The load iterator.
typedef LoadIterator_ LoadIterator;
/// The transformer.
typedef Transformer_ Transformer;
/// The store iterator to write to shared memory.
typedef StoreIterator_ StoreIterator;
/// The fragment that is copied from shared memory.
typedef typename LoadIterator::Fragment FetchedFragment;
/// The fragment that is obtained after the transformation by the transformer.
typedef typename Transformer::OutputFragment TransformedFragment;
/// Make sure the fragments match.
static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
"");
/// The output fragment.
typedef TransformedFragment Fragment;
/// Make sure the transformed fragment is the same as the store fragment.
static_assert((platform::is_same<TransformedFragment, typename StoreIterator::Fragment>::value),
"");
/// The layout.
static MatrixLayout::Kind const kLayout = LoadIterator::kLayout;
/// The scalar type of the iterator.
typedef typename LoadIterator::Scalar Scalar;
/// The pointer.
typedef typename LoadIterator::Pointer Pointer;
/// The index.
typedef typename LoadIterator::Index Index;
/// The index.
typedef typename LoadIterator::LongIndex LongIndex;
/// The tile
typedef typename LoadIterator::Tile Tile;
/// Shared memory allocation for the tile
typedef TileAllocation<typename StoreIterator::Scalar, typename StoreIterator::Tile>
ThreadblockTileStorage;
/// Tensor reference to threadblock tile
typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef;
/// The params.
struct Params {
// The load iterator.
typename LoadIterator::Params load_iterator;
/// Batch stride in global memory
LongIndex batch_stride;
// The store iterator.
typename StoreIterator::Params store_iterator;
// Offset to residue.
Index offset_to_residue;
// Offset to residue for the last partition
Index offset_to_residue_last_partition;
/// Setup the params.
CUTLASS_HOST_DEVICE int initialize(Pointer pointer,
LongIndex batch_stride_,
Index ldm,
Index offset_to_residue_,
Index offset_to_residue_last_partition_) {
int error_code = load_iterator.initialize(pointer, ldm, ldm);
if (error_code) {
return error_code;
}
batch_stride = batch_stride_;
offset_to_residue = offset_to_residue_;
offset_to_residue_last_partition = offset_to_residue_last_partition_;
return store_iterator.initialize();
}
CUTLASS_DEVICE Index get_offset_to_residue() {
if (blockIdx.z == gridDim.z - 1) { //last partition
return offset_to_residue_last_partition;
}
else {
return offset_to_residue;
}
}
};
/// Contains private storage in shared memory needed by the objects within this class. Note,
/// this is *NOT* the shared memory allocation for the GEMM threadblock tile. That necessarily
/// exists outside this class, as it is also needed by the warp-level shared=>RF stream.
struct SharedStorage {};
//
// Static member functions
//
/// Maps a coordinate in the GEMM's (K, N, M) coordinate system to global memory
CUTLASS_HOST_DEVICE static Coord<3> project_coordinate(Coord<3> const& coord, Index d_offset = 0) {
bool const kKstrided =
GemmMultiplicandTraits<typename LoadIterator::Tile, kOperand, kLayout>::kKstrided;
Coord<3> tile_coord = ProjectOperand<kOperand, kKstrided>::project(coord);
return make_Coord(
tile_coord[0] + d_offset, tile_coord[1], tile_coord[2] / LoadIterator::Tile::kC);
}
/// Ctor.
CUTLASS_DEVICE GlobalLoadStream(
Params const& _params,
SharedStorage& shared_storage,
ThreadblockTileRef const& threadblock_tile_ref,
Coord<3> const bounds,
Coord<3> const& _threadblock_offset)
: params(_params),
threadblock_offset(project_coordinate(_threadblock_offset)),
multiplicand_bounds(project_coordinate(bounds, 1)),
load_iterator(params.load_iterator, threadblock_offset),
transformer(),
store_iterator(params.store_iterator, threadblock_tile_ref.data()) {
load_iterator.initialize_predicates(multiplicand_bounds, threadblock_offset);
fetched_fragment.clear();
}
/// Load the data from shared memory to the fetch fragment.
CUTLASS_DEVICE void copy() {
load_iterator.load_post_increment(fetched_fragment);
}
/// Commit the data.
CUTLASS_DEVICE void commit() {
transformer.transform(fetched_fragment, transformed_fragment);
store_iterator.store_post_increment(transformed_fragment);
store_iterator.inc_stage();
}
/// Execute the residue code.
CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
load_iterator.residue(k);
if (!skip_clear) {
fetched_fragment.clear();
}
}
/// Move to the residue portion.
CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {
Index kResidue = k % kTileK;
if (kResidue) {
residue(kResidue);
Index this_offset_residue = params.get_offset_to_residue();
load_iterator.add_pointer_offset(this_offset_residue * load_iterator.stride_advance());
}
}
/// Rollback to the beginning of the first tile
CUTLASS_DEVICE void rollback(void) {
load_iterator.initialize_predicates(multiplicand_bounds, threadblock_offset);
int const kBlock = kOperand == GemmOperand::kA
? (kLayout == MatrixLayout::kColumnMajor ? Tile::kH : Tile::kW)
: (kLayout == MatrixLayout::kRowMajor ? Tile::kH : Tile::kW);
Index this_offset_residue = params.get_offset_to_residue();
load_iterator.add_pointer_offset(-(this_offset_residue + kBlock) *
load_iterator.stride_advance());
}
/// Adds a Coord<3> to the underlying global load iterator
CUTLASS_DEVICE GlobalLoadStream &operator+=(Coord<3> const &offset) {
load_iterator += offset;
return *this;
}
/// Adds an offset based on batch stride
CUTLASS_DEVICE GlobalLoadStream &add_batch_offset(int batch_id) {
load_iterator.add_pointer_offset(batch_id * params.batch_stride);
return *this;
}
//
// Data members
//
/// Parameters
Params params;
/// Threadblock offset
Coord<3> threadblock_offset;
/// Multiplicand bounds
Coord<3> multiplicand_bounds;
/// The iterator.
LoadIterator load_iterator;
/// The fragment to fetch from shared memory.
FetchedFragment fetched_fragment;
/// The transformer.
Transformer transformer;
/// The fragment to convert the data after it has been fetched from shared memory.
TransformedFragment transformed_fragment;
/// The store iterator.
StoreIterator store_iterator;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,614 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines iterators for efficiently loading and storing to global memory.
*/
#pragma once
#include "cutlass/coord.h"
#include "cutlass/util/platform.h"
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/matrix_traits.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/tile_iterator.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
// The following functor reshapes a tile of threads to match a tile of data. The idea is that when
// the user wants to build the iterator traits, he/she may want to specify the tile independently
// from the number of scalars loaded/stored per instruction. For example, in the row-major version
// with a tile of size 128x8 - the user may want to that the iterator works with 32x8 threads if
// each thread loads 1 scalar per LDG. If the user changes to 4 scalars per LDG, then the tile of
// threads has to change. The code below detects that and correct the code automatically - it is
// a helper when the user does not specify the right configuration.
template <typename Tile_, typename Threads_, bool = (Tile_::kW < Threads_::kW)>
struct ReshapeThreads {
typedef Threads_ Threads;
};
template <typename Tile_, typename Threads_>
struct ReshapeThreads<Tile_, Threads_, true> {
typedef Shape<Threads_::kD, Threads_::kH * Threads_::kW / Tile_::kW, Tile_::kW, 1> Threads;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GemmOperand::Kind kOperand_,
MatrixLayout::Kind kLayout_,
typename Scalar_,
typename Tile_,
typename Threads_,
int kAccessSize_>
struct GemmGlobalTileTraits {
/// Identity of the operand
static GemmOperand::Kind const kOperand = kOperand_;
/// The layout.
static MatrixLayout::Kind const kLayout = kLayout_;
/// The scalar.
typedef typename platform::remove_const<Scalar_>::type Scalar;
/// The pointer.
typedef Scalar_* Pointer;
/// The number of scalars per LDG/STG.
static int const kAccessSize = kAccessSize_;
/// The memory space.
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGlobal;
/// The tile shape
typedef Tile_ Tile;
/// The vectorized tile shape
typedef typename ReshapeTile<Tile_, kAccessSize_>::Tile VectorizedTile;
/// The threads shape
typedef typename ReshapeThreads<VectorizedTile, Threads_>::Threads Threads;
/// The relative offset between two elements in the H/W dimension in adjacent threads.
typedef Shape<1, 1, VectorizedTile::kC> ThreadsDelta;
/// The strides in each dimension between different loads/stores.
typedef Shape<0, Threads::kH, Threads::kW * kAccessSize> Delta;
/// Strides for immediate offset computation
typedef Shape<0, 0, Threads::kW * ThreadsDelta::kW, kAccessSize> ImmediateOffsetStrides;
/// The number of iterations needed to load/store the tile.
typedef Shape<1,
VectorizedTile::kH / Threads::kH,
VectorizedTile::kW / Threads::kW,
VectorizedTile::kC / kAccessSize>
Iterations;
typedef GemmMultiplicandTraits<Tile, kOperand, kLayout> MultiplicandTraits;
/// Computes the thread offset in (H, W) based on thread ID
struct ThreadOffset {
CUTLASS_HOST_DEVICE
Coord<4> operator()() const {
int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
}
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_, typename Tile_, typename Threads_, int kStrideH_, int kAccessSize_>
struct GemmGlobalTileCdTraits : public GemmGlobalTileTraits<GemmOperand::kC,
MatrixLayout::kColumnMajor,
Scalar_,
Tile_,
Threads_,
kAccessSize_> {
/// The base class.
typedef GemmGlobalTileTraits<GemmOperand::kC,
MatrixLayout::kColumnMajor,
Scalar_,
Tile_,
Threads_,
kAccessSize_>
Base;
/// The stride in the H dimension.
static int const kStrideH = kStrideH_;
/// Override the strides in each dimension between different loads/stores.
typedef Shape<0, 0, Base::Delta::kW, Base::Delta::kC> Delta;
typedef typename Base::Iterations Iterations;
typedef typename Base::Threads Threads;
typedef typename Base::ThreadsDelta ThreadsDelta;
typedef typename Base::ImmediateOffsetStrides ImmediateOffsetStrides;
/// Computes the thread offset in (H, W) based on thread ID
struct ThreadOffset {
CUTLASS_HOST_DEVICE
Coord<4> operator()() const {
int thread_offset_h = threadIdx.x / Threads::kW * kStrideH * Iterations::kH;
int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
}
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename TileTraits_, typename Index_ = int>
struct GemmGlobalIteratorAb
: public TileLoadIterator<TileTraits_,
typename TileTraits_::Scalar,
TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
: IteratorAdvance::kW,
MemorySpace::kGlobal,
Index_> {
/// This class.
typedef GemmGlobalIteratorAb<TileTraits_, Index_> This_; /// The base class.
typedef TileLoadIterator<TileTraits_,
typename TileTraits_::Scalar,
TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
: IteratorAdvance::kW,
MemorySpace::kGlobal,
Index_>
Base;
/// The layout.
static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
/// The tile
typedef typename TileTraits_::Tile Tile;
/// Fragment type loaded by the iterator
typedef typename Base::Fragment Fragment;
/// The scalar.
typedef typename TileTraits_::Scalar Scalar;
/// The threads.
typedef typename TileTraits_::Threads Threads;
/// The index.
typedef Index_ Index;
/// Long index
typedef long long LongIndex;
/// The thread offset
typedef typename TileTraits_::ThreadOffset ThreadOffset;
/// Specifies in which dimension post-increment accesses advance.
static IteratorAdvance::Kind const kAdvance = Base::kAdvance;
typedef cutlass::PredicateVector<ShapeCount<typename Base::Iterations>::kCount> PredicateVector;
/// Iterator parameters type
typedef typename Base::Params BaseParams;
struct Params : public BaseParams {
/// Initializes params to load a strip-mined tile, given pointer and stride_h.
CUTLASS_HOST_DEVICE int initialize(Scalar const* ptr,
Index stride_d,
Index stride_h) {
return BaseParams::initialize(ptr, stride_d, stride_h, kAdvance == IteratorAdvance::kH ? 0 : 1);
}
};
/// Offset of an individual lane from the start of the tile
Coord<4> thread_offset;
/// The parameters
Params params;
/// The predicates.
PredicateVector predicates;
CUTLASS_HOST_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& block_offset) {
// Setup the masks to control loads.
predicates.fill(0);
// Fill in the bits of the predicate vector.
for (int d = 0; d < Base::Iterations::kD; ++d) {
for (int h = 0; h < Base::Iterations::kH; ++h) {
for (int w = 0; w < Base::Iterations::kW; ++w) {
for (int c = 0; c < Base::Iterations::kC; ++c) {
bool flag = w * Base::Delta::kW + thread_offset[2] + block_offset[2] < bounds[2];
if (kAdvance == IteratorAdvance::kH) {
flag =
flag &&
(h * Base::Delta::kH + d * Base::Delta::kD) + thread_offset[1] + block_offset[1] <
bounds[1];
} else {
flag = flag && (h * Base::Delta::kH) + thread_offset[1] + block_offset[1] < bounds[1];
}
int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
predicates.set(bit, flag);
}
}
}
}
}
/// Ctor.
CUTLASS_HOST_DEVICE GemmGlobalIteratorAb(Params const& _params,
const Coord<3>& threadblock_offset,
ThreadOffset thread_offset_func = ThreadOffset())
: params(_params) {
thread_offset = thread_offset_func();
// Setup the pointer.
params.pointer += ((threadblock_offset[1] + thread_offset[1]) * params.stride_h +
(threadblock_offset[2] + thread_offset[2]));
}
/// Increment the pointer in the W dimension.
CUTLASS_HOST_DEVICE void inc_w() { Base::inc_w(); }
/// Increment the pointer in the H dimension.
CUTLASS_HOST_DEVICE void inc_h() { params.pointer += params.inc_h; }
/// Increment the pointer in the D dimension.
CUTLASS_HOST_DEVICE void inc_d() { params.pointer += params.inc_d; }
/// Increment the pointer to move to the next iteration.
CUTLASS_HOST_DEVICE void inc_advance() { params.pointer += params.inc_advance; }
/// Loads a single fragment element from memory
CUTLASS_HOST_DEVICE void load_element(
typename Base::AccessType& value, int d, int h, int w, int c) const {
int const offset =
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
Load<Scalar,
Base::kAccessSize,
Base::kMemorySpace,
Base::kFragmentElementType,
typename Base::FragmentElement,
Base::Tile::kW,
Base::kAccessSize * sizeof(Scalar)>::load(value, params.pointer, offset);
}
/// That's the residue! Update the predicates.
CUTLASS_HOST_DEVICE void residue(Index k) {
// Update the predicate vector.
for (int d = 0; d < Base::Iterations::kD; ++d) {
for (int h = 0; h < Base::Iterations::kH; ++h) {
for (int w = 0; w < Base::Iterations::kW; ++w) {
for (int c = 0; c < Base::Iterations::kC; ++c) {
Index offset = 0;
if (kAdvance == IteratorAdvance::kH) {
offset += thread_offset[1] + h * Base::Delta::kH + d * Base::Delta::kD;
} else {
offset += thread_offset[2] + w * Base::Delta::kW;
}
int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
if (offset >= k) {
predicates.set(bit, false);
}
}
}
}
}
}
/// Is the valid?
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const {
int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
return predicates[bit];
}
/// Adds a vector offset to the iterator
CUTLASS_HOST_DEVICE GemmGlobalIteratorAb & operator+=(Coord<3> const &offset) {
LongIndex _offset = offset.template dot<LongIndex>(
make_Coord(params.stride_d, params.stride_h, params.stride_w)
);
params.pointer += _offset;
return *this;
}
CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset) { params.pointer += offset; }
CUTLASS_HOST_DEVICE Index stride_advance(void) {
Index stride = params.stride_h;
if (kAdvance == IteratorAdvance::kW) {
stride = params.stride_w;
}
return stride;
}
template <typename Fragment>
CUTLASS_HOST_DEVICE void load_post_increment(Fragment& fragment) {
typename Base::FragmentIterator frag_iterator(fragment);
for (int d = 0; d < Base::Iterations::kD; ++d) {
for (int h = 0; h < Base::Iterations::kH; ++h) {
for (int w = 0; w < Base::Iterations::kW; ++w) {
for (int c = 0; c < Base::Iterations::kC; ++c) {
if (valid(d, h, w, c)) {
load_element(
reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
d,
h,
w,
c);
}
}
if (w < Base::Iterations::kW - 1) {
inc_w();
}
}
if (h < Base::Iterations::kH - 1) {
inc_h();
}
}
if (d < Base::Iterations::kD - 1) {
inc_d();
}
}
inc_advance();
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename TileTraits_, typename Index_ = int>
struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
typename TileTraits_::Scalar,
IteratorAdvance::kH,
MemorySpace::kGlobal,
Index_> {
/// This class.
typedef GemmGlobalIteratorCd<TileTraits_, Index_> This_;
/// The base class.
typedef TileIteratorBase<TileTraits_,
typename TileTraits_::Scalar,
IteratorAdvance::kH,
MemorySpace::kGlobal,
Index_>
Base;
/// The layout.
static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
/// The scalar.
typedef typename TileTraits_::Scalar Scalar;
/// The pointer.
typedef typename TileTraits_::Pointer Pointer;
/// The threads.
typedef typename TileTraits_::Threads Threads;
/// The index.
typedef Index_ Index;
/// The index.
typedef long long LongIndex;
/// The thread offset
typedef typename TileTraits_::ThreadOffset ThreadOffset;
/// The params.
struct Params {
/// The pointer.
Pointer pointer;
/// The stride in the D dimension
long long stride_d;
/// The stride in the H dimension to setup the thread in the block.
Index stride_h;
/// The strides to increment the pointer.
Index inc_advance, inc_h;
/// The strides to increment the predicate offset
Index predicate_inc_advance, predicate_inc_h;
/// The column offset to compute the predicate for the columns.
Index predicate_offset;
/// Setup the params.
CUTLASS_HOST_DEVICE int initialize(Pointer pointer,
int stride_d_,
Index ldm,
Index bound,
Index epilogue_stride_w,
Index epilogue_delta_w) {
// The pointer.
this->pointer = pointer;
// Stride per batch
stride_d = stride_d_;
// Each column of the matrix.
stride_h = TileTraits_::ThreadsDelta::kH * ldm;
// Each thread output 1 column per iteration. The stride between columns is given by the
// number of scalars that are loaded per LDS for B.
inc_h = ldm * TileTraits_::kStrideH;
inc_advance =
(ldm - ldm * TileTraits_::kStrideH * (Base::Iterations::kH - 1)) + epilogue_stride_w;
predicate_offset = bound;
predicate_inc_h = TileTraits_::kStrideH;
predicate_inc_advance =
-((TileTraits_::kStrideH * (Base::Iterations::kH - 1) - 1) + epilogue_delta_w);
return 0;
}
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, long long _stride_d, Index _stride_h,
Index _inc_advance, Index _inc_h, Index _predicate_inc_advance, Index _predicate_inc_h,
Index _predicate_offset) {
this->pointer = pointer;
stride_d = _stride_d;
stride_h = _stride_h;
inc_advance = _inc_advance;
inc_h = _inc_h;
predicate_inc_advance = _predicate_inc_advance;
predicate_inc_h = _predicate_inc_h;
predicate_offset = _predicate_offset;
return 0;
}
};
/// Parameters.
Params params;
/// Offset of an individual lane from the start of the tile
Coord<4> thread_offset;
/// The predicates for the row.
cutlass::PredicateVector<Base::Iterations::kW> predicates;
/// Ctor.
CUTLASS_HOST_DEVICE GemmGlobalIteratorCd(Params const& _params,
const Coord<3>& bounds,
const Coord<3>& block,
int offset = 0,
int pred_offset = 0,
ThreadOffset thread_offset_func = ThreadOffset())
: params(_params) {
thread_offset = thread_offset_func();
// Each warp works on a different column of the tile.
int const h = thread_offset[1] + block[1];
// Each lane writes a different element.
int const w = thread_offset[2] + block[2];
// Setup the pointer.
params.pointer += ((h * params.stride_h + w) + offset);
// Prepare the vector of predicates.
for (int i = 0; i < Base::Iterations::kW; ++i) {
predicates.set(i, w + i * Base::Delta::kW < bounds[2]);
}
params.predicate_offset -= (h + pred_offset);
}
/// Increment the pointer in the C dimension.
CUTLASS_HOST_DEVICE void inc_c() {}
/// Increment the pointer in the W dimension.
CUTLASS_HOST_DEVICE void inc_w() {}
/// Increment the pointer in the H dimension.
CUTLASS_HOST_DEVICE void inc_h() {
params.pointer += params.inc_h;
params.predicate_offset -= params.predicate_inc_h;
}
/// Increment the pointer in the D dimension.
CUTLASS_HOST_DEVICE void inc_d() {}
/// Increment the pointer to move to the next iteration.
CUTLASS_HOST_DEVICE void inc_advance() {
params.pointer += params.inc_advance;
params.predicate_offset -= params.predicate_inc_advance;
}
/// Adds a vector offset to the iterator
CUTLASS_HOST_DEVICE GemmGlobalIteratorCd & operator+=(Coord<3> const &offset) {
LongIndex _offset = offset.template dot<LongIndex>(
make_Coord(params.stride_d, params.stride_h, 1)
);
params.pointer += _offset;
return *this;
}
/// Loads a single fragment element from memory.
CUTLASS_HOST_DEVICE void load_element(
typename Base::AccessType& value, int d, int h, int w, int c) const {
int const offset =
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
Load<Scalar,
Base::kAccessSize,
Base::kMemorySpace,
Base::kFragmentElementType,
typename Base::FragmentElement,
Base::Tile::kW,
Base::kAccessSize * sizeof(Scalar)>::load(value, params.pointer, offset);
}
/// Stores a single fragment element into memory.
CUTLASS_HOST_DEVICE void store_element(
typename Base::AccessType const& value, int d, int h, int w, int c) {
int const offset =
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
Store<Scalar,
Base::kAccessSize,
Base::kMemorySpace,
Base::kFragmentElementType,
typename Base::FragmentElement,
Base::Tile::kW,
Base::kAccessSize * sizeof(Scalar)>::store(value, params.pointer, offset);
}
/// Test the validity of the
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const {
return predicates.at(w) && params.predicate_offset > 0;
}
/// add pointer offset
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex offset) { params.pointer += offset; }
/// Loads and increments iterator
template <typename Fragment>
CUTLASS_HOST_DEVICE void load_post_increment(Fragment& fragment) {
typename Base::FragmentIterator frag_iterator(fragment);
for (int d = 0; d < Base::Iterations::kD; ++d) {
for (int h = 0; h < Base::Iterations::kH; ++h) {
for (int w = 0; w < Base::Iterations::kW; ++w) {
for (int c = 0; c < Base::Iterations::kC; ++c) {
if (valid(d, h, w, c)) {
load_element(
reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
d,
h,
w,
c);
}
}
if (w < Base::Iterations::kW - 1) {
inc_w();
}
}
if (h < Base::Iterations::kH - 1) {
inc_h();
}
}
if (d < Base::Iterations::kD - 1) {
inc_d();
}
}
inc_advance();
}
template <typename Fragment>
CUTLASS_HOST_DEVICE void store_post_increment(Fragment& fragment) {
typename Base::FragmentIterator frag_iterator(fragment);
for (int d = 0; d < Base::Iterations::kD; ++d) {
for (int h = 0; h < Base::Iterations::kH; ++h) {
for (int w = 0; w < Base::Iterations::kW; ++w) {
for (int c = 0; c < Base::Iterations::kC; ++c) {
if (valid(d, h, w, c)) {
store_element(
reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
d,
h,
w,
c);
}
}
if (w < Base::Iterations::kW - 1) {
inc_w();
}
}
if (h < Base::Iterations::kH - 1) {
inc_h();
}
}
if (d < Base::Iterations::kD - 1) {
inc_d();
}
}
inc_advance();
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,141 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear
memory.
*/
#pragma once
#include "cutlass/matrix_traits.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/util/platform.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to describe attributes of GEMM matrix operands
template <GemmOperand::Kind kOperand_, MatrixLayout::Kind kLayout_>
struct GemmOperandTraitsAb {
static const bool Congruous =
(kOperand_ == GemmOperand::kA ^ kLayout_ == MatrixLayout::kRowMajor);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmOperand::Kind kOperand_, typename Tile_>
struct GetExtent;
template <typename Tile_>
struct GetExtent<GemmOperand::kA, Tile_> {
static const int kExtent = Tile_::kW;
};
template <typename Tile_>
struct GetExtent<GemmOperand::kB, Tile_> {
static const int kExtent = Tile_::kH;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Determines the shape of a multiplicand tile in terms of strided (H) and contiguous (W)
/// dimensions
template <typename ThreadBlockTile_, GemmOperand::Kind Usage, MatrixLayout::Kind Layout>
struct GemmMultiplicandTraits {
// Only defined for A or B
static_assert(Usage == GemmOperand::kA || Usage == GemmOperand::kB,
"MultiplicandTileShape defined only for A or B operands.");
/// Shape of GEMM thread block tile (K, N, M)
typedef ThreadBlockTile_ ThreadBlockTile;
/// Identifies multiplicand
static GemmOperand::Kind const kUsage = Usage;
/// Layout of tile
static MatrixLayout::Kind const kLayout = Layout;
// True if K is the strided dimension
static bool const kKstrided = (kUsage == GemmOperand::kA ^ kLayout == MatrixLayout::kRowMajor);
/// Map the ThreadBlockShape onto (kH, kW) dimensions for A and B operand
typedef typename platform::conditional<
kKstrided,
Shape<1, ThreadBlockTile::kD, GetExtent<Usage, ThreadBlockTile>::kExtent>,
Shape<1, GetExtent<Usage, ThreadBlockTile>::kExtent, ThreadBlockTile::kD> >::type Shape;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Project's a coordinate (K, N, M) onto inner and outer dimensions defined for each
/// operand.
template <GemmOperand::Kind operand, bool Kstrided = true>
struct ProjectOperand;
/// Project A operand - (0, K, M)
template <bool Kstrided>
struct ProjectOperand<GemmOperand::kA, Kstrided> {
CUTLASS_HOST_DEVICE
static Coord<3> project(Coord<3> const &coord) {
if (Kstrided) {
return make_Coord(0, coord[0], coord[2]);
} else {
return make_Coord(0, coord[2], coord[0]);
}
}
};
/// Project B operand - (0, K, N)
template <bool Kstrided>
struct ProjectOperand<GemmOperand::kB, Kstrided> {
CUTLASS_HOST_DEVICE
static Coord<3> project(Coord<3> const &coord) {
if (Kstrided) {
return make_Coord(0, coord[0], coord[1]);
} else {
return make_Coord(0, coord[1], coord[0]);
}
}
};
/// Project C operand - (0, N, M)
template <>
struct ProjectOperand<GemmOperand::kC, true> {
CUTLASS_HOST_DEVICE
static Coord<3> project(Coord<3> const &coord) { return make_Coord(0, coord[1], coord[2]); }
};
/// Project D operand - (0, N, M)
template <>
struct ProjectOperand<GemmOperand::kD, true> {
CUTLASS_HOST_DEVICE
static Coord<3> project(Coord<3> const &coord) { return make_Coord(0, coord[1], coord[2]); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,134 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines abstractions for managing loading and storing fragments to shared memory in the
efficient GEMM pipeline.
*/
#pragma once
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/gemm_shared_tile.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The load iterator.
typename Iterator_,
/// The transformer to be applied after the data has been copied from shared memory.
typename Transformer_ = Copy<typename Iterator_::Fragment> >
struct SharedLoadStream {
/// The load iterator.
typedef Iterator_ Iterator;
/// The transformer.
typedef Transformer_ Transformer;
/// The fragment that is copied from shared memory.
typedef typename Iterator::Fragment FetchedFragment;
/// The fragment that is obtained after the transformation by the transformer.
typedef typename Transformer::OutputFragment TransformedFragment;
/// Make sure the fragments match.
static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
"");
/// The output fragment.
typedef TransformedFragment Fragment;
/// Scalar data type
typedef typename Iterator::Scalar Scalar;
/// Reference type to a tensor
typedef TensorRef<Scalar, 4> TensorRef;
/// The params.
struct Params {
/// The iterator params.
typename Iterator::Params iterator;
/// Setup the params.
CUTLASS_HOST_DEVICE int initialize() { return iterator.initialize(); }
};
/// The storage in shared memory needed by that stream.
typedef typename Iterator::Storage SharedStorage;
/// Ctor.
CUTLASS_DEVICE SharedLoadStream() {}
/// Ctor.
CUTLASS_DEVICE SharedLoadStream(Params const &params, TensorRef const &ref) {
this->initialize(params, ref);
}
/// Initialize the stream.
CUTLASS_DEVICE void initialize(Params const &params, TensorRef const &ref) {
// The iterator.
iterator = Iterator(params.iterator, ref.data());
// The transformer.
transformer = Transformer();
}
/// Load the data from shared memory to the fetch fragment.
CUTLASS_DEVICE void copy() {
iterator.load_post_increment(fetched[0]);
}
/// Load the data from shared memory to the fetch fragment.
CUTLASS_DEVICE void copy(int step) { iterator.load(fetched[step % 2], step); }
/// Commit the data.
CUTLASS_DEVICE void commit() { transformer.transform(fetched[0], transformed[0]); }
/// Commit the data.
CUTLASS_DEVICE void commit(int step) {
transformer.transform(fetched[step % 2], transformed[step % 2]);
}
/// Returns the fragment for the given step
CUTLASS_DEVICE TransformedFragment &fragment(int step = 0) { return transformed[step % 2]; }
/// Returns the fragment for the given step
CUTLASS_DEVICE TransformedFragment const &fragment(int step = 0) const {
return transformed[step % 2];
}
/// Increment the stage.
CUTLASS_DEVICE void inc_stage() { iterator.inc_stage(); }
/// The iterator.
Iterator iterator;
/// Fetched fragment
FetchedFragment fetched[2];
/// The transformer.
Transformer transformer;
/// Transformed fragment
TransformedFragment transformed[2];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,417 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines iterators for efficiently loading and storing tiles to and from shared memory.
*/
#pragma once
#include "cutlass/gemm/gemm_operand.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerSts_>
struct GemmSharedStoreTileAbTraits {
/// The scalar.
typedef typename platform::remove_const<Scalar_>::type Scalar;
/// The pointer.
typedef Scalar_* Pointer;
/// The tile.
typedef typename ReshapeTile<Tile_, kScalarsPerSts_>::Tile Tile;
/// The threads.
typedef Threads_ Threads;
/// The strides to compute the base position of the thread.
typedef Shape<0, ShapeCount<Tile>::kWc, Tile::kC, kScalarsPerSts_> ThreadsStrides;
/// The skew.
static int const kSkew = 0;
/// The number of scalars per LDG/STG.
static int const kAccessSize = kScalarsPerSts_;
/// The memory space.
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
/// The number of iterations needed to load/store the tile.
typedef Shape<1,
Tile::kH / Threads::kH,
Tile::kW / Threads::kW,
Tile::kC / Threads::kC / kAccessSize>
Iterations;
/// The strides in each dimension between different loads/stores.
typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kAccessSize> Delta;
/// The strides in each dimension between different loads/stores.
typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kAccessSize>
ImmediateOffsetStrides;
struct ThreadOffset {
CUTLASS_HOST_DEVICE
Coord<4> operator()() const {
int offset = ComputeThreadOffsetFromStrides<Threads, ThreadsStrides>::get();
return make_Coord(0, 0, offset, 0);
}
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerSts_, int kSkew_>
struct GemmSharedStoreWithSkewTileAbTraits {
/// The scalar.
typedef typename platform::remove_const<Scalar_>::type Scalar;
/// The pointer.
typedef Scalar_* Pointer;
/// The tile without skews.
typedef typename ReshapeTile<Tile_, kScalarsPerSts_>::Tile TileWithoutSkew;
/// The tile.
typedef typename ReshapeTile<Shape<Tile_::kD, Tile_::kH, Tile_::kW + kSkew_>,
kScalarsPerSts_>::Tile Tile;
/// The threads.
typedef Threads_ Threads;
/// The skew.
static int const kSkew = kSkew_;
/// The number of scalars per STS.
static int const kAccessSize = kScalarsPerSts_;
/// The memory space.
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
/// The number of iterations needed to load/store the tile.
typedef Shape<1, TileWithoutSkew::kH / Threads::kW, TileWithoutSkew::kW / Threads::kH> Iterations;
/// The strides in each dimension between different loads/stores.
typedef Shape<0, ShapeCount<Tile>::kWc, Threads::kH * kAccessSize> Delta;
/// The strides in each dimension between different loads/stores.
typedef Shape<0, ShapeCount<Tile>::kWc, Threads::kH * kAccessSize> ImmediateOffsetStrides;
struct ThreadOffset {
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
int offset = ComputeThreadOffsetFromStrides<Threads, ThreadsStrides>::get();
return make_Coord(0, 0, offset, 0);
}
};
protected:
/// The strides to compute the base position of the thread.
typedef Shape<0, kScalarsPerSts_, ShapeCount<Tile>::kHwc / Threads::kW> ThreadsStrides;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_,
typename OutputTile_,
typename Warps_,
typename ThreadsPerWarp_,
typename InstructionShape_,
int kStages_,
int kScalarsPerLds_,
int kSkew_ = 0>
struct GemmSharedLoadTileATraits {
static GemmOperand::Kind const kOperand = GemmOperand::kA;
/// The scalar.
typedef typename platform::remove_const<Scalar_>::type Scalar;
/// The pointer.
typedef Scalar_* Pointer;
/// The tile without skew.
typedef Shape<kStages_,
OutputTile_::kD / InstructionShape_::kD,
GetExtent<kOperand, OutputTile_>::kExtent * InstructionShape_::kD>
TileWithoutSkew_;
/// The tile with skew.
typedef Shape<kStages_, TileWithoutSkew_::kH, TileWithoutSkew_::kW + kSkew_> TileWithSkew;
/// The tile without skew after reshaping.
typedef typename ReshapeTile<TileWithoutSkew_, kScalarsPerLds_>::Tile TileWithoutSkew;
/// The tile.
typedef typename ReshapeTile<TileWithSkew, kScalarsPerLds_>::Tile Tile;
/// The number of warps.
typedef Warps_ Warps;
/// The threads in a warp.
typedef ThreadsPerWarp_ ThreadsPerWarp;
/// The number of scalars per LDG/STG.
// static int const kScalarsPerLds = kScalarsPerLds_;
static int const kAccessSize = kScalarsPerLds_;
/// The skew.
static int const kSkew = kSkew_;
/// The memory space.
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
/// The number of warps.
static int const kWarps = GetExtent<kOperand, Warps>::kExtent;
/// The number of threads in one dimension of the warp.
static int const kThreadsPerWarp = GetExtent<kOperand, ThreadsPerWarp>::kExtent;
/// The number of iterations needed to load/store the tile.
typedef Shape<1, 1, TileWithoutSkew::kW / kWarps / kThreadsPerWarp /* / kScalarsPerLds*/>
Iterations;
/// The strides in each dimension between different loads/stores.
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0>
ImmediateOffsetStrides;
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0> Delta;
/// Computes the thread offset in (H, W) based on thread ID
struct ThreadOffset {
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
// Extract the warp.
int const warp = threadIdx.x / kWarpSize;
// Extract the slice.
int const slice = warp / (Warps::kH * Warps::kW);
// Compute the row offset for each warp.
int const warp_row = warp % Warps::kW;
// Compute the row offset for each thread.
int const lane_row = (threadIdx.x & 0x0e) / 2;
// The offset.
int const offset =
slice * Tile::kW * Tile::kC + (warp_row * ThreadsPerWarp::kW + lane_row) * kAccessSize;
// Embed the offset in a 4D coordinate vector.
return make_Coord(0, 0, offset, 0);
}
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_,
typename OutputTile_,
typename Warps_,
typename ThreadsPerWarp_,
typename InstructionShape_,
int kStages_,
int kScalarsPerLds_,
int kSkew_ = 0>
struct GemmSharedLoadTileBTraits {
static GemmOperand::Kind const kOperand = GemmOperand::kB;
/// The scalar.
typedef typename platform::remove_const<Scalar_>::type Scalar;
/// The pointer.
typedef Scalar_* Pointer;
/// The tile without skew.
typedef Shape<kStages_,
OutputTile_::kD / InstructionShape_::kD,
GetExtent<kOperand, OutputTile_>::kExtent * InstructionShape_::kD>
TileWithoutSkew_;
/// The tile with skew.
typedef Shape<kStages_, TileWithoutSkew_::kH, TileWithoutSkew_::kW + kSkew_> TileWithSkew;
/// The tile without skew after reshaping.
typedef typename ReshapeTile<TileWithoutSkew_, kScalarsPerLds_>::Tile TileWithoutSkew;
/// The tile.
typedef typename ReshapeTile<TileWithSkew, kScalarsPerLds_>::Tile Tile;
/// The number of warps.
typedef Warps_ Warps;
/// The threads in a warp.
typedef ThreadsPerWarp_ ThreadsPerWarp;
/// The number of scalars per LDG/STG.
static int const kAccessSize = kScalarsPerLds_;
/// The skew.
static int const kSkew = kSkew_;
/// The memory space.
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
/// The number of warps.
static int const kWarps = GetExtent<kOperand, Warps>::kExtent;
/// The number of threads in one dimension of the warp.
static int const kThreadsPerWarp = GetExtent<kOperand, ThreadsPerWarp>::kExtent;
/// The number of iterations needed to load/store the tile.
typedef Shape<1, 1, TileWithoutSkew::kW / kWarps / kThreadsPerWarp /* / kAccessSize*/> Iterations;
/// The strides in each dimension between different loads/stores.
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0>
ImmediateOffsetStrides;
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0> Delta;
/// Computes the thread offset in (H, W) based on thread ID
struct ThreadOffset {
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
// Extract the warp.
int const warp = threadIdx.x / kWarpSize;
// Extract the slice.
int const slice = warp / (Warps::kH * Warps::kW);
// The warp in the slice.
int const warp_in_slice = warp % (Warps::kH * Warps::kW);
// Compute the row offset for each warp.
int const warp_col = warp_in_slice / Warps::kW;
// Compute the row offset for each thread.
int const lane_col = (threadIdx.x & 0x10) / 8 + (threadIdx.x & 0x01);
// The offset.
int const offset =
slice * Tile::kW * Tile::kC + (warp_col * ThreadsPerWarp::kH + lane_col) * kAccessSize;
// Embed the offset in a 4D coordinate.
return make_Coord(0, 0, offset, 0);
}
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_,
typename OutputTile_,
typename Warps_,
typename ThreadsPerWarp_,
int kScalarsPerSts_,
int kSkew_ = 0>
struct GemmSharedStoreTileDTraits {
/// The scalar.
typedef typename platform::remove_const<Scalar_>::type Scalar;
/// The pointer.
typedef Scalar_* Pointer;
/// The dimension of the output tile.
typedef OutputTile_ OutputTile;
/// The warps in the tile.
typedef Warps_ Warps;
/// The threads in the warps.
typedef ThreadsPerWarp_ ThreadsPerWarp;
/// The number of scalars per LDG/STG.
static int const kAccessSize = kScalarsPerSts_;
/// The skew.
static int const kSkew = kSkew_;
/// The memory space.
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
/// The number of scalars per thread.
static int const kScalarsPerThread = OutputTile_::kW / Warps::kW / ThreadsPerWarp::kW;
/// The number of threads.
static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;
/// The number of scalars per row. We build a tile with 2 rows (to avoid bank conflicts).
static int const kScalarsPerRow = kThreads / 2 * kScalarsPerThread + kSkew;
/// The tile.
typedef Shape<1, 2, kScalarsPerRow / kAccessSize, kAccessSize> Tile;
/// The number of iterations needed to store the tile.
typedef Shape<1, 1, kScalarsPerThread / kAccessSize> Iterations;
/// The strides in each dimension between different loads/stores.
typedef Shape<0, 0, Warps::kW * ThreadsPerWarp::kW * kAccessSize> Delta;
/// The strides in each dimension between different loads/stores.
typedef Shape<0, 0, Warps::kW * ThreadsPerWarp::kW * kAccessSize> ImmediateOffsetStrides;
/// Computes the thread offset in (H, W) based on thread ID
struct ThreadOffset {
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
// The warp.
int const warp = threadIdx.x / kWarpSize;
// The position of the warp in the 2D tile.
int const warp_row = warp % Warps::kW;
int const warp_col = warp / Warps::kW;
// We assume that the elements are distributed in a warps as 4 columns of 8 elements. The
// columns are stored in threads col0=[0, 2, 4, 6, 8, 10, 12, 14], col1=[1, 3, 5, 7, .., 15],
// col2=[16, 18, 20, ..., 30] and col3=[17, 19, ..., 31].
int hi_halfwarp_offset = ((threadIdx.x >> 4) & 0x1) * OutputTile::kW;
int lo_halfwarp_offset = ((threadIdx.x >> 1) & 0x7) + ThreadsPerWarp::kW * warp_row;
// Odd threads go to the second half of shared memory.
int const row = threadIdx.x & 0x01;
int col = warp_col * (ThreadsPerWarp::kH / 2) * OutputTile::kW +
lo_halfwarp_offset * kAccessSize + hi_halfwarp_offset;
// Embed the offset in a 4D coords.
return make_Coord(0, 0, row * kScalarsPerRow + col, 0);
}
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_,
typename OutputTile_,
typename Warps_,
typename ThreadsPerWarp_,
int kTileH_,
int kScalarsPerLds_,
int kSkew_ = 0>
struct GemmSharedLoadTileDTraits {
/// The scalar.
typedef typename platform::remove_const<Scalar_>::type Scalar;
/// The pointer.
typedef Scalar_* Pointer;
/// The dimension of the output tile.
typedef OutputTile_ OutputTile;
/// The warps in the tile.
typedef Warps_ Warps;
/// The threads in the warps.
typedef ThreadsPerWarp_ ThreadsPerWarp;
/// The number of scalars per LDG/STG.
static int const kAccessSize = kScalarsPerLds_;
/// The skew.
static int const kSkew = kSkew_;
/// The memory space.
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
/// The number of scalars per thread.
static int const kScalarsPerThread = OutputTile_::kW / Warps::kW / ThreadsPerWarp::kW;
/// The number of threads.
static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;
/// The number of scalars per row. We build a tile with 2 rows (to avoid bank conflicts).
static int const kScalarsPerRow = kThreads / 2 * kScalarsPerThread + kSkew;
/// The tile. We have 2 rows of scalars. We use those two rows to make sure we do not have bank
/// conflicts in the epilogue.
typedef Shape<1, 2, kScalarsPerRow / kAccessSize, kAccessSize> Tile;
// Compute the number of iterations per warp in the Tile::kH dimension.
static int const kIterationsInHPerWarp = kTileH_ / ShapeCount<Warps>::kCount;
// As explained above, the shared memory tile is composed of 2 rows and each rows is made of
// kScalarsPerRow. A warp is expected to read from the 1st row, then move to the 2nd row and go
// back to the 1st row. To model that scheme we define the Iterations shape as Shape<X, 2, ...>.
// However, in some cases, we have only 1 iteration per warp. In that case, we must define the
// shape as Shape<1, 1, ...>. The following code does that except that we hijack the kH dimension
// to keep the number of elements to reduce for split-K.
static int const kIterationsH = kIterationsInHPerWarp == 1 ? 1 : 2;
// As soon as we know kIterationsH, it is trivial to compute kIterationsD:
static int const kIterationsD = kIterationsInHPerWarp / kIterationsH;
// If we have split-K enabled, we have to jump over the elements from the "odd/even" column of
// threads to grab the other elements.
static int const kSplitK = OutputTile::kW * ThreadsPerWarp::kH / 2 * Warps::kH;
/// The number of iterations needed to store the tile.
typedef Shape<kIterationsD, kIterationsH, OutputTile::kW / kWarpSize / kAccessSize, Warps::kD>
Iterations;
/// The strides in each dimension between different loads/stores.
typedef Shape<OutputTile::kW, kScalarsPerRow, kWarpSize * kAccessSize, kSplitK>
ImmediateOffsetStrides;
/// The strides in each dimension between different loads/stores.
typedef Shape<OutputTile::kW, kScalarsPerRow, kWarpSize * kAccessSize, kSplitK> Delta;
/// Computes the thread offset in (H, W) based on thread ID
struct ThreadOffset {
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
// Each warp works on a different column.
int const h = threadIdx.x / kWarpSize;
// Compute the row.
int const w = (threadIdx.x & (kWarpSize - 1)) * kAccessSize;
int offset = 0;
if (Iterations::kH == 1) {
int const row = h & 0x1;
int const col = h / 2;
offset = row * ShapeCount<Tile>::kWc + col * OutputTile::kW * Iterations::kD + w;
} else {
offset = h * OutputTile::kW * Iterations::kD + w;
}
return make_Coord(0, 0, offset, 0);
}
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,258 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines a pair of GEMM tile streams
*/
#pragma once
#include "cutlass/convert.h"
#include "cutlass/matrix_traits.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/tile_allocation.h"
#include "cutlass/tile_iterator.h"
#include "cutlass/gemm/clear_accumulators.h"
#include "cutlass/gemm/gemm_config.h"
#include "cutlass/gemm/gemm_global_stream.h"
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/gemm/gemm_shared_stream.h"
#include "cutlass/gemm/threadblock_swizzle.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Collect the global load streams for multiplicands.
template <typename StreamA_, typename StreamB_, bool kResidueInProlog_>
struct GlobalLoadStreamPair {
//
// Type definitions
//
/// Stream for A multiplicand
typedef StreamA_ StreamA;
/// Stream for B multiplicand
typedef StreamB_ StreamB;
/// Parameters object
struct Params {
/// Parameters object for StreamA
typename StreamA::Params stream_a;
/// Parameters object for StreamB
typename StreamB::Params stream_b;
/// Default constructor
CUTLASS_HOST_DEVICE
Params() {}
/// Constructs a global load stream pair Params object
CUTLASS_HOST_DEVICE
Params(typename StreamA::Params const &_params_A, typename StreamB::Params const &_params_B)
: stream_a(_params_A), stream_b(_params_B) {}
};
/// Assumes the A stream defines the index type
typedef typename StreamA::Index Index;
/// Shared memory allocation for threadblock-scoped GEMM tile
typedef ZipTileAllocation<typename StreamA::ThreadblockTileStorage,
typename StreamB::ThreadblockTileStorage>
ThreadblockTileStorage;
/// ZipTensorRef to threadblock tiles
typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef;
/// Defines a structure containing shared storage for each pair
struct SharedStorage {
typename StreamA::SharedStorage stream_a;
typename StreamB::SharedStorage stream_b;
};
//
// Data members
//
/// Stream for A multiplicand
StreamA stream_a;
/// Stream for B multiplicand
StreamB stream_b;
//
// Methods
//
/// Ctor.
CUTLASS_DEVICE GlobalLoadStreamPair(Params const &params,
SharedStorage &shared_storage,
ThreadblockTileRef const &threadblock_tile_ref,
Coord<3> const bounds,
Coord<3> const &block_offset = make_Coord(0, 0, 0))
: stream_a(params.stream_a,
shared_storage.stream_a,
threadblock_tile_ref.first,
bounds,
block_offset),
stream_b(params.stream_b,
shared_storage.stream_b,
threadblock_tile_ref.second,
bounds,
block_offset) {}
CUTLASS_DEVICE
GlobalLoadStreamPair & operator+=(Coord<3> const offset) {
stream_a += offset;
stream_b += offset;
return *this;
}
CUTLASS_DEVICE
GlobalLoadStreamPair & add_batch_offset(int batch_id) {
stream_a.add_batch_offset(batch_id);
stream_b.add_batch_offset(batch_id);
return *this;
}
/// Trigger the copies from shared memory to registers.
CUTLASS_DEVICE void copy() {
stream_a.copy();
stream_b.copy();
}
/// Commit the data.
CUTLASS_DEVICE void commit() {
stream_a.commit();
stream_b.commit();
}
/// Execute the residue code.
CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
stream_a.residue(k, skip_clear);
stream_b.residue(k, skip_clear);
}
/// Move to residue.
CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {
if (kResidueInProlog_) {
stream_a.move_to_residue(k, kTileK);
stream_b.move_to_residue(k, kTileK);
} else if (k < kTileK) {
residue(k, true);
}
}
/// Rollback to beginning of first tile.
CUTLASS_DEVICE void rollback(bool kRollback) {
if (kResidueInProlog_ && kRollback) {
stream_a.rollback();
stream_b.rollback();
}
}
};
/// Collect the global load streams for multiplicands.
template <typename StreamA_, typename StreamB_>
struct SharedStreamPair {
//
// Type definitions
//
/// Stream for A multiplicand
typedef StreamA_ StreamA;
/// Stream for B multiplicand
typedef StreamB_ StreamB;
/// Parameters object passed to load iterators
struct Params {
///
typename StreamA::Params stream_a;
///
typename StreamB::Params stream_b;
};
/// Shared memory allocation for threadblock-scoped GEMM tile
typedef ZipTensorRef<typename StreamA::TensorRef,
typename StreamB::TensorRef >
ThreadblockTileRef;
//
// Data members
//
/// The stream for A.
StreamA stream_a;
/// The stream for B.
StreamB stream_b;
//
// Methods
//
/// Construct with the composable structure
CUTLASS_DEVICE SharedStreamPair(Params const &params, ThreadblockTileRef const &threadblock_tile_ref)
: stream_a(params.stream_a, threadblock_tile_ref.first),
stream_b(params.stream_b, threadblock_tile_ref.second) {}
/// Trigger the copies from shared memory to registers.
CUTLASS_DEVICE void copy(int step) {
stream_a.copy(step);
stream_b.copy(step);
}
/// Commit the data.
CUTLASS_DEVICE void commit(int step) {
stream_a.commit(step);
stream_b.commit(step);
}
/// The fragment A.
CUTLASS_DEVICE
typename StreamA::TransformedFragment const &fragment_a(int step) const {
return stream_a.fragment(step);
}
/// The fragment B.
CUTLASS_DEVICE
typename StreamB::TransformedFragment const &fragment_b(int step) const {
return stream_b.fragment(step);
}
/// Increment the stage.
CUTLASS_DEVICE void inc_stage() {
stream_a.inc_stage();
stream_b.inc_stage();
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,797 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines structural properties of complete GEMM computation.
*/
#pragma once
#include "cutlass/convert.h"
#include "cutlass/matrix_traits.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/tile_allocation.h"
#include "cutlass/tile_iterator.h"
#include "cutlass/kernel_launch.h"
#include "cutlass/gemm/clear_accumulators.h"
#include "cutlass/gemm/gemm_config.h"
#include "cutlass/gemm/gemm_desc.h"
#include "cutlass/gemm/gemm_stream_pair.h"
#include "cutlass/gemm/gemm_global_stream.h"
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/gemm/gemm_shared_stream.h"
#include "cutlass/gemm/threadblock_swizzle.h"
#include "cutlass/gemm/gemm.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <enum MatrixLayout::Kind, typename GemmConfig_>
struct GemmTileTraitsHelperA {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmConfig_>
struct GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
/// The layout.
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
/// The input scalar.
typedef typename GemmConfig_::ScalarA Scalar;
/// The scalar stored in shared memory.
typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
/// The traits class to build the iterator to load data from global memory for A^N.
typedef GemmGlobalTileTraits<
// That's A.
GemmOperand::kA,
// A is column-major.
MatrixLayout::kColumnMajor,
// The pointer is float const.
Scalar const,
// The tile has size KxM in GEMM's terminology.
Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kW>,
// The threads are distributed as warps x 32 (the traits may reorganize).
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
GemmConfig_::kScalarsPerLdgA>
GlobalTileTraits;
/// The traits class to build the iterator to store data to shared memory for A^N.
typedef GemmSharedStoreTileAbTraits<
// The pointer is float.
MultiplyAddScalar,
// The tile has size KxM in GEMM's terminology.
Shape<GemmConfig_::kStages,
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
// The threads are distributed as warps x 32 (the traits may reorganize).
typename GlobalTileTraits::Threads,
// The number of scalars per STS (STS.32 or STS.128, etc).
GemmConfig_::kScalarsPerStsA>
SharedStoreTileTraits;
/// The traits class to build the iterator to load from shared memory for A^N.
typedef GemmSharedLoadTileATraits<
// The pointer is float const.
MultiplyAddScalar const,
// The output tile size.
typename GemmConfig_::OutputTile,
// The number of warps.
typename GemmConfig_::Warps,
// The number of threads per warp.
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
// The shape of the FMA instruction.
typename GemmConfig_::InstructionShape,
// The number of stages.
GemmConfig_::kStages,
// The number of scalars per LDS.
GemmConfig_::kScalarsPerLdsA,
// The skew.
0>
SharedLoadTileTraits;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmConfig_>
struct GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
/// The layout.
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
/// The input scalar.
typedef typename GemmConfig_::ScalarA Scalar;
/// The scalar stored in shared memory.
typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
/// The traits class to build the iterator to load data from global memory for A^T.
typedef GemmGlobalTileTraits<
// That's A.
GemmOperand::kA,
// A is row-major.
MatrixLayout::kRowMajor,
// The pointer is float const.
Scalar const,
// The tile has size MxK in GEMM's terminology.
Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
// The threads are distributed as (threads / K) x K (the traits may reorganize).
Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
GemmConfig_::kScalarsPerLdgA>
GlobalTileTraits;
/// The number of scalars in 4B.
static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
/// The skew for A.
static int const kSkewA = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsA /
GlobalTileTraits::Threads::kW * kScalarsIn4B;
/// The traits class to build the iterator to store data to shared memory for A^T.
typedef GemmSharedStoreWithSkewTileAbTraits <
// The pointer is float.
MultiplyAddScalar,
// The tile has size KxM in GEMM's terminology.
Shape<GemmConfig_::kStages,
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
// The threads are distributed as (threads / K) x K (the traits may reorganize).
typename GlobalTileTraits::Threads,
// The number of scalars per STS.
GemmConfig_::kScalarsPerStsA,
// The skew to avoid bank conflicts added in the tile W dimension.
kSkewA<GemmConfig_::kScalarsPerLdsA ? GemmConfig_::kScalarsPerLdsA : kSkewA>
SharedStoreTileTraits;
/// The traits class to build the iterator to load from shared memory for A^T.
typedef GemmSharedLoadTileATraits<
// The pointer is float const.
MultiplyAddScalar const,
// The output tile size.
typename GemmConfig_::OutputTile,
// The number of warps.
typename GemmConfig_::Warps,
// The number of threads per warp.
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
// The shape of the FMA instruction.
typename GemmConfig_::InstructionShape,
// The number of stages.
GemmConfig_::kStages,
// The number of scalars per LDS.
GemmConfig_::kScalarsPerLdsA,
// The skew.
SharedStoreTileTraits::kSkew>
SharedLoadTileTraits;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <enum MatrixLayout::Kind, typename GemmConfig_>
struct GemmTileTraitsHelperB {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmConfig_>
struct GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
/// The layout.
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
/// The input scalar.
typedef typename GemmConfig_::ScalarB Scalar;
/// The scalar stored in shared memory.
typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
/// The traits class to build the iterator to load data from global memory for B^N.
typedef GemmGlobalTileTraits<
// That's B.
GemmOperand::kB,
// B is column-major.
MatrixLayout::kColumnMajor,
// The pointer is float const.
Scalar const,
// The tile has size MxK in GEMM's terminology.
Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
// The threads are distributed as (threads / K) x K (the traits may reorganize).
Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
GemmConfig_::kScalarsPerLdgB>
GlobalTileTraits;
/// The number of scalars in 4B.
static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
/// The skew for B.
static int const kSkewB = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsB /
GlobalTileTraits::Threads::kW * kScalarsIn4B;
/// The traits class to build the iterator to store data to shared memory for B^N.
typedef GemmSharedStoreWithSkewTileAbTraits <
// The pointer is float.
MultiplyAddScalar,
// The tile has size KxN in GEMM's terminology.
Shape<GemmConfig_::kStages,
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
// The threads are distributed as (threads / K) x K (the traits may reorganize).
typename GlobalTileTraits::Threads,
// The number of scalars per STS.
GemmConfig_::kScalarsPerStsB,
// The skew to avoid bank conflicts added in the tile W dimension.
kSkewB<GemmConfig_::kScalarsPerLdsB ? GemmConfig_::kScalarsPerLdsB : kSkewB>
SharedStoreTileTraits;
/// The traits class to build the iterator to load from shared memory for B^N.
typedef GemmSharedLoadTileBTraits<
// The pointer is float const.
MultiplyAddScalar const,
// The output tile size.
typename GemmConfig_::OutputTile,
// The number of warps.
typename GemmConfig_::Warps,
// The number of threads per warp.
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
// The shape of the FMA instruction.
typename GemmConfig_::InstructionShape,
// The number of stages.
GemmConfig_::kStages,
// The number of scalars per LDS.
GemmConfig_::kScalarsPerLdsB,
// The skew.
SharedStoreTileTraits::kSkew>
SharedLoadTileTraits;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmConfig_>
struct GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
/// The layout.
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
/// The input scalar.
typedef typename GemmConfig_::ScalarB Scalar;
/// The scalar stored in shared memory.
typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
/// The traits class to build the iterator to load data from global memory for B^T.
typedef GemmGlobalTileTraits<
// That's B.
GemmOperand::kB,
// B is row-major.
MatrixLayout::kRowMajor,
// The pointer is float const.
Scalar const,
// The tile has size KxN in GEMM's terminology.
Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kH>,
// The threads are distributed as warps x 32 (the traits may reorganize).
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
GemmConfig_::kScalarsPerLdgB>
GlobalTileTraits;
/// The traits class to build the iterator to store data to shared memory for B^T.
typedef GemmSharedStoreTileAbTraits<
// The pointer is float.
MultiplyAddScalar,
// The tile has size KxN in GEMM's terminology.
Shape<GemmConfig_::kStages,
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
// The threads are distributed as warps x 32 (the traits may reorganize).
typename GlobalTileTraits::Threads,
// The number of scalars per STS (STS.32 or STS.128, etc).
GemmConfig_::kScalarsPerStsB>
SharedStoreTileTraits;
/// The traits class to build the iterator to load from shared memory for B^T.
typedef GemmSharedLoadTileBTraits<
// The pointer is float const.
MultiplyAddScalar const,
// The output tile size.
typename GemmConfig_::OutputTile,
// The number of warps.
typename GemmConfig_::Warps,
// The number of threads per warp.
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
// The shape of the FMA instruction.
typename GemmConfig_::InstructionShape,
// The number of stages.
GemmConfig_::kStages,
// The number of scalars per LDS.
GemmConfig_::kScalarsPerLdsB,
// The skew.
0>
SharedLoadTileTraits;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The GEMM configuration.
typename GemmConfig_,
/// The stream to load A from global memory to shared memory.
typename GlobalLoadStreamA_,
/// The stream to load B from global memory to shared memory.
typename GlobalLoadStreamB_,
/// The stream to load A from shared memory.
typename SharedLoadStreamA_,
/// The stream to load B from shared memory.
typename SharedLoadStreamB_,
/// The epilogue.
typename Epilogue_,
/// The block swizzle to reorganize the grid.
typename BlockSwizzle_ = IdentityBlockSwizzle,
/// The index.
typename Index_ = int,
/// The tool used to clear accumulators.
typename ClearAccumulators_ = ClearAccumulators<typename GemmConfig_::Accumulators::Element> >
struct GemmTraits {
/// This traits
typedef GemmTraits<GemmConfig_,
GlobalLoadStreamA_,
GlobalLoadStreamB_,
SharedLoadStreamA_,
SharedLoadStreamB_,
Epilogue_,
BlockSwizzle_,
Index_,
ClearAccumulators_> This_;
/// The struct that consumes this Traits
typedef typename cutlass::gemm::Gemm<This_> KernelClass;
/// The configuration.
typedef GemmConfig_ GemmConfig;
/// The output tile.
typedef typename GemmConfig::OutputTile OutputTile;
/// The stream to load A from global memory to shared memory.
typedef GlobalLoadStreamA_ GlobalLoadStreamA;
/// The layout of A.
static MatrixLayout::Kind const kLayoutA = GlobalLoadStreamA::kLayout;
/// The scalar for A.
typedef typename GlobalLoadStreamA_::Scalar ScalarA;
/// The stream to load B from global memory to shared memory.
typedef GlobalLoadStreamB_ GlobalLoadStreamB;
/// The layout of B.
static MatrixLayout::Kind const kLayoutB = GlobalLoadStreamB::kLayout;
/// The scalar for B.
typedef typename GlobalLoadStreamB_::Scalar ScalarB;
/// The iterator for A to load from shared memory.
typedef SharedLoadStreamA_ SharedLoadStreamA;
/// The iterator for B to load from shared memory.
typedef SharedLoadStreamB_ SharedLoadStreamB;
/// The multiply-add functor.
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
/// The epilogue.
typedef Epilogue_ Epilogue;
/// The scalars in the epilogue.
typedef typename Epilogue::ScalarC ScalarC;
typedef typename Epilogue::ScalarD ScalarD;
/// The block swizzle to reorganize the grid.
typedef BlockSwizzle_ BlockSwizzle;
/// The index.
typedef Index_ Index;
/// Clear the accumulators.
typedef ClearAccumulators_ ClearAccumulators;
/// Assemble the global load streams for A/B.
typedef GlobalLoadStreamPair<GlobalLoadStreamA,
GlobalLoadStreamB,
GemmConfig::kResidueInProlog>
GlobalLoadStream;
/// Memory needed to store the threadblock-scoped GEMM tile
typedef typename GlobalLoadStream::ThreadblockTileStorage ThreadblockTileStorage;
/// Assemble the shared load streams for A/B.
typedef SharedStreamPair<SharedLoadStreamA, SharedLoadStreamB> SharedStream;
/// Parameters object constructable on the host.
struct Params : public KernelLaunchConfiguration {
/// GEMM problem size
GemmCoord problem_size;
/// The K range for every partition except the last one
int partitionK_range;
/// Parameters object for the global load stream
typename GlobalLoadStream::Params global_to_shared_stream;
/// Parameters object for the shared load stream
typename SharedStream::Params shared_stream;
/// The params for the epilogue.
typename Epilogue::Params epilogue;
/// Initialize the parameters.
template <typename GemmDesc_>
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
// Set the problem size.
problem_size = desc.problem_size;
// there is no partitionK in the default case
partitionK_range = problem_size[0];
// Compute grid dimensions
BlockSwizzle block_swizzle;
this->block = dim3(GemmConfig::kThreads);
this->grid = block_swizzle.get_grid_layout(
problem_size,
make_Coord_from_shape<OutputTile>());
// Compute offset to residue.
// partitionK_range <= problem_size[0]
Index gemm_k = problem_size[0];
Index offset_to_residue_last_partition = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
Index offset_to_residue = (partitionK_range % OutputTile::kD) ? partitionK_range - (partitionK_range % OutputTile::kD) : 0;
// Initialize parameters objects for
int error_code = global_to_shared_stream.stream_a.initialize(
desc.A.data(),
desc.batch_stride_A,
desc.A.leading_dim(),
offset_to_residue,
offset_to_residue_last_partition
);
if (error_code) {
return error_code;
}
error_code = global_to_shared_stream.stream_b.initialize(
desc.B.data(),
desc.batch_stride_B,
desc.B.leading_dim(),
offset_to_residue,
offset_to_residue_last_partition
);
if (error_code) {
return error_code;
}
// The epilogue.
return epilogue.initialize(desc);
}
/// Helper to construct a GEMM params using a BLAS-like API
CUTLASS_HOST_DEVICE int initialize(Index m,
Index n,
Index k,
typename Epilogue::Scalar alpha,
ScalarA const* d_a,
Index lda,
ScalarB const* d_b,
Index ldb,
typename Epilogue::Scalar beta,
ScalarC const* d_c,
Index ldc,
ScalarD* d_d,
Index ldd) {
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
GemmCoord(k, n, m, 1),
alpha,
TensorRef<ScalarA const, 2>(d_a, lda),
TensorRef<ScalarB const, 2>(d_b, ldb),
beta,
TensorRef<ScalarC const, 2>(d_c, ldc),
TensorRef<ScalarD, 2>(d_d, ldd)
);
return this->initialize(desc);
}
/// Helper to construct a batched GEMM params
CUTLASS_HOST_DEVICE int initialize(Index m,
Index n,
Index k,
typename Epilogue::Scalar alpha,
ScalarA const* d_a,
Index lda,
long long int batch_stride_A,
ScalarB const* d_b,
Index ldb,
long long int batch_stride_B,
typename Epilogue::Scalar beta,
ScalarC const* d_c,
Index ldc,
long long int batch_stride_C,
ScalarD* d_d,
Index ldd,
long long int batch_stride_D,
Index batch_count) {
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
GemmCoord(k, n, m, batch_count),
alpha,
TensorRef<ScalarA const, 2>(d_a, lda),
batch_stride_A,
TensorRef<ScalarB const, 2>(d_b, ldb),
batch_stride_B,
beta,
TensorRef<ScalarC const, 2>(d_c, ldc),
batch_stride_C,
TensorRef<ScalarD, 2>(d_d, ldd),
batch_stride_D
);
return this->initialize(desc);
}
/// Helper to construct a partitionedK GEMM params
template <typename GemmDesc_>
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& partitonK_desc, Index partitionK_count_) {
// partitionK GEMM is a specialized batched stried gemm with different K ranges per batch
// the problem_size of each batch is (lastK_size, n, m)
// add more comments here
// the k range for every batch excpet the last one
//assert(partitionK_count_ > 0);
partitionK_range = partitonK_desc.problem_size.k() / partitionK_count_;
// the k range of the last batch
// int lastK_range = (partitonK_desc.problem_size.k() % partitionK_range) + partitionK_range;
int lastK_range = partitonK_desc.problem_size.k() - partitionK_range * (partitionK_count_ - 1);
int k_size = lastK_range;
int lda = partitonK_desc.A.stride(0);
int ldb = partitonK_desc.B.stride(0);
int ldc = partitonK_desc.C.stride(0);
int ldd = partitonK_desc.D.stride(0);
int n = partitonK_desc.problem_size.n();
long long int batch_stride_A = (kLayoutA == cutlass::MatrixLayout::kColumnMajor) ? lda * partitionK_range : partitionK_range;
long long int batch_stride_B = (kLayoutB == cutlass::MatrixLayout::kColumnMajor) ? partitionK_range : partitionK_range * ldb;
long long int batch_stride_C = ldc * n;
long long int batch_stride_D = ldd * n;
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
//we pass lastK_size as per batch K. there is also a range that will match partitionK_size
GemmCoord(k_size, partitonK_desc.problem_size.n(), partitonK_desc.problem_size.m(), partitionK_count_),
partitonK_desc.alpha,
partitonK_desc.A,
batch_stride_A,
partitonK_desc.B,
batch_stride_B,
partitonK_desc.beta,
partitonK_desc.C,
batch_stride_C,
partitonK_desc.D,
batch_stride_D
);
// Set the problem size.
problem_size = desc.problem_size;
// Compute grid dimensions
BlockSwizzle block_swizzle;
this->block = dim3(GemmConfig::kThreads);
this->grid = block_swizzle.get_grid_layout(
problem_size,
make_Coord_from_shape<OutputTile>());
// Compute offset to residue.
// partitionK_range <= problem_size[0]
Index gemm_k = problem_size[0];
Index offset_to_residue_last_partition = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
Index offset_to_residue = (partitionK_range % OutputTile::kD) ? partitionK_range - (partitionK_range % OutputTile::kD) : 0;
// Initialize parameters objects for
int error_code = global_to_shared_stream.stream_a.initialize(
desc.A.data(),
desc.batch_stride_A,
desc.A.leading_dim(),
offset_to_residue,
offset_to_residue_last_partition
);
if (error_code) {
return error_code;
}
error_code = global_to_shared_stream.stream_b.initialize(
desc.B.data(),
desc.batch_stride_B,
desc.B.leading_dim(),
offset_to_residue,
offset_to_residue_last_partition
);
if (error_code) {
return error_code;
}
// The epilogue.
return epilogue.initialize(desc);
}
/// Helper to construct a partitionedK GEMM params
CUTLASS_HOST_DEVICE int initialize(Index m,
Index n,
Index k,
typename Epilogue::Scalar alpha,
ScalarA const* d_a,
Index lda,
ScalarB const* d_b,
Index ldb,
typename Epilogue::Scalar beta,
ScalarC const* d_c,
Index ldc,
ScalarD* d_d,
Index ldd,
Index partitionK_count_) {
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
GemmCoord(k, n, m, 1),
alpha,
TensorRef<ScalarA const, 2>(d_a, lda),
TensorRef<ScalarB const, 2>(d_b, ldb),
beta,
TensorRef<ScalarC const, 2>(d_c, ldc),
TensorRef<ScalarD, 2>(d_d, ldd)
);
return this->initialize(desc, partitionK_count_);
}
};
// The storage for the main loop + prologue.
struct MainLoopSharedStorage {
/// Stores the threadblock tile
ThreadblockTileStorage threadblock_tile;
/// Storage for GEMM global stream
typename GlobalLoadStream::SharedStorage global_to_shared_stream;
/// Storage for clearing accumulators
typename ClearAccumulators::SharedStorage clear;
};
/// The storage in shared memory.
union SharedStorage {
// The storage for the main loop.
MainLoopSharedStorage main_loop;
// The storage for the epilogue.
typename Epilogue::SharedStorage epilogue;
};
/// The memory fence for shared loads.
static CUTLASS_DEVICE void shared_load_fence(bool in_loop) {
if (SharedLoadStreamA::Iterator::kRequiresLoadFence ||
SharedLoadStreamB::Iterator::kRequiresLoadFence) {
__syncthreads();
}
}
/// The memory fence for shared stores.
static CUTLASS_DEVICE void shared_store_fence(bool in_loop) {
__syncthreads();
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmTileTraitsHelperA_, typename GemmTileTraitsHelperB_, typename Index_>
struct SimplifiedGemmTraitsHelper {
/// The global iterator to load A from global memory.
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA_::GlobalTileTraits, Index_>
GlobalLoadIteratorA;
/// The data converter for A before storing to shared memory.
typedef Copy<typename GlobalLoadIteratorA::Fragment> GlobalTransformerA;
/// The iterator to store A to shared memory.
typedef TileStoreIterator<typename GemmTileTraitsHelperA_::SharedStoreTileTraits,
typename GemmTileTraitsHelperA_::SharedStoreTileTraits::Scalar,
IteratorAdvance::kH,
MemorySpace::kShared>
SharedStoreIteratorA;
/// The stream to load A from global memory to shared memory.
typedef GlobalLoadStream<GemmOperand::kA,
GlobalLoadIteratorA,
SharedStoreIteratorA,
GlobalTransformerA>
GlobalLoadStreamA;
/// The global iterator to load B from global memory.
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB_::GlobalTileTraits, Index_>
GlobalLoadIteratorB;
/// The data converter for B before storing to shared memory.
typedef Copy<typename GlobalLoadIteratorB::Fragment> GlobalTransformerB;
/// The iterator to store B to shared memory.
typedef TileStoreIterator<typename GemmTileTraitsHelperB_::SharedStoreTileTraits,
typename GemmTileTraitsHelperB_::SharedStoreTileTraits::Scalar,
IteratorAdvance::kH,
MemorySpace::kShared>
SharedStoreIteratorB;
/// The stream to load B from global memory to shared memory.
typedef GlobalLoadStream<GemmOperand::kB,
GlobalLoadIteratorB,
SharedStoreIteratorB,
GlobalTransformerB>
GlobalLoadStreamB;
/// The iterator to load A from shared memory.
typedef TileLoadIterator<typename GemmTileTraitsHelperA_::SharedLoadTileTraits,
typename GemmTileTraitsHelperA_::Scalar,
IteratorAdvance::kH,
MemorySpace::kShared>
SharedLoadIteratorA;
/// The stream to load A from shared memory.
typedef SharedLoadStream<SharedLoadIteratorA> SharedLoadStreamA;
/// The iterator to load B from shared memory.
typedef TileLoadIterator<typename GemmTileTraitsHelperB_::SharedLoadTileTraits,
typename GemmTileTraitsHelperB_::Scalar,
IteratorAdvance::kH,
MemorySpace::kShared>
SharedLoadIteratorB;
/// The stream to load B from shared memory.
typedef SharedLoadStream<SharedLoadIteratorB> SharedLoadStreamB;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The layout for A.
MatrixLayout::Kind kLayoutA_,
/// The layout for B.
MatrixLayout::Kind kLayoutB_,
/// The config for the GEMM.
typename GemmConfig_,
/// The epilogue.
typename Epilogue_,
/// The index.
typename Index_ = int,
// The configuration for the A matrix.
typename GemmTileTraitsHelperA_ = GemmTileTraitsHelperA<kLayoutA_, GemmConfig_>,
// The configuration for the B matrix.
typename GemmTileTraitsHelperB_ = GemmTileTraitsHelperB<kLayoutB_, GemmConfig_>,
// The helper class to create the streams and iterators.
typename Helper_ =
SimplifiedGemmTraitsHelper<GemmTileTraitsHelperA_, GemmTileTraitsHelperB_, Index_> >
struct SimplifiedGemmTraits : public GemmTraits<
// The config.
GemmConfig_,
// The stream to load A from global memory to shared memory.
typename Helper_::GlobalLoadStreamA,
// The stream to load B from global memory to shared memory.
typename Helper_::GlobalLoadStreamB,
// The stream to load A from shared memory.
typename Helper_::SharedLoadStreamA,
// The stream to load B from shared memory.
typename Helper_::SharedLoadStreamB,
// The epilogue.
Epilogue_,
// The block swizzle to reorganize the grid.
IdentityBlockSwizzle,
// The index.
Index_,
// The tool used to clear accumulators.
ClearAccumulators<typename GemmConfig_::Accumulators::Element> > {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,90 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tile traits used to construct global tile iterator for HGEMM. This is intended to
partition the thread block-level tile into 2D subtiles loaded by the threads and facilitate
memory accesses larger than 16 bits.
*/
#pragma once
#include "cutlass/coord.h"
#include "cutlass/gemm/gemm_global_tile.h"
#include "cutlass/matrix_traits.h"
#include "cutlass/reshape_tile.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GemmOperand::Kind kOperand_,
MatrixLayout::Kind kLayout_,
typename Scalar_,
typename Tile_,
typename Threads_,
int kAccessSize_>
struct HgemmCrosswiseGlobalTileTraits : public GemmGlobalTileTraits<
// Which GEMM operand?
kOperand_,
// The layout.
kLayout_,
// The scalar.
Scalar_,
// The tile.
Tile_,
// The threads.
Threads_,
// The number of scalars per LDG/STG.
kAccessSize_> {
/// The base class.
typedef GemmGlobalTileTraits<kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_> Base;
/// The threads.
typedef typename Base::Threads Threads;
/// The threads strides.
typedef Shape<1, 2, Base::VectorizedTile::kC> ThreadsDelta;
/// The strides in each dimension between different loads/stores.
typedef Shape<Base::Threads::kH * 2, 1, Base::Threads::kW, Base::kAccessSize> Delta;
/// The number of iterations needed to load/store the tile.
typedef Shape<Base::VectorizedTile::kH / Base::Threads::kH / 2,
2,
Base::VectorizedTile::kW / Base::Threads::kW,
Base::VectorizedTile::kC / Base::kAccessSize>
Iterations;
/// Computes the thread offset in (H, W) based on thread ID
struct ThreadOffset {
CUTLASS_HOST_DEVICE
Coord<4> operator()() const {
int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
}
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,106 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Specialization implementing multiply-add operation on half-precision floating point
fragments.
*/
#pragma once
#include "cutlass/fragment.h"
#include "cutlass/gemm/thread_multiply_add.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Template performing matrix multiply-add operation within a thread
template <typename ThreadGemmShape_, typename ThreadsPerWarp_>
struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, half, half, half> {
/// The shape of the instruction.
typedef Shape<1, 1, 2, 1> InstructionShape;
/// The number of accumulators per thread.
typedef ThreadGemmShape_ ThreadGemmShape;
/// Aliased for compatibility. Will be removed for CUTLASS v2.0.
typedef ThreadGemmShape AccumulatorsPerThread;
/// The number of threads per warp.
typedef ThreadsPerWarp_ ThreadsPerWarp;
/// The number of accumulators per warp.
typedef typename ShapeMul<ThreadGemmShape, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
/// The type for A.
typedef half ScalarA;
/// The fragment for A.
typedef Fragment<ScalarA, AccumulatorsPerThread::kW> FragmentA;
/// The type for B.
typedef half ScalarB;
/// The fragment for B.
typedef Fragment<ScalarB, AccumulatorsPerThread::kH> FragmentB;
/// The type for C and D.
typedef half ScalarC;
/// The accumulators.
typedef Fragment<half, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW> Accumulators;
/// Make sure there's an even number of elements in both dimensions.
static_assert(AccumulatorsPerThread::kH % 2 == 0, "Invalid size");
static_assert(AccumulatorsPerThread::kW % 2 == 0, "Invalid size");
/// Ctor.
CUTLASS_DEVICE ThreadMultiplyAdd() {}
/// Multiply : d = a*b + c.
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
FragmentB const& b,
Accumulators const& c,
Accumulators& d) {
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
// The inputs.
__half2 const* a_half2 = reinterpret_cast<__half2 const*>(&a[0]);
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
__half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);
// The output.
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
for (int j = 0; j < AccumulatorsPerThread::kH / 2; ++j) {
for (int i = 0; i < AccumulatorsPerThread::kW / 2; ++i) {
// The offsets in the output fragment.
int const k0 = (2 * j + 0) * (AccumulatorsPerThread::kW / 2) + i;
int const k1 = (2 * j + 1) * (AccumulatorsPerThread::kW / 2) + i;
// Compute the product a[i] * b[j].low.
d_half2[k0] = __hfma2(a_half2[i], __low2half2(b_half2[j]), c_half2[k0]);
// Compute the product a[i] * b[j].high.
d_half2[k1] = __hfma2(a_half2[i], __high2half2(b_half2[j]), c_half2[k1]);
}
}
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,94 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Transposes a tile of 16b elements. Used by HGEMM to construct a K-strided layout in
shared memory for multiplicands.
*/
#pragma once
#include <cuda_fp16.h>
#include "cutlass/fragment.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GlobalIterator_>
struct HgemmSwizzle {
/// The global iterator.
typedef GlobalIterator_ GlobalIterator;
/// The source fragment.
typedef typename GlobalIterator::Fragment Fragment;
/// The shape of the source fragment.
typedef typename GlobalIterator::FragmentShape FragmentShape;
/// The input fragment.
typedef Fragment InputFragment;
/// The output fragment.
typedef Fragment OutputFragment;
/// The src/dst must be half fragments.
static_assert((platform::is_same<typename Fragment::Element, half>::value), "Works on half");
/// The number of elements must be a multiple of 2.
static_assert(FragmentShape::kH == 2 && ShapeCount<FragmentShape>::kWc == 2, "Not multiple of 2");
/// Ctor.
CUTLASS_DEVICE HgemmSwizzle() {}
/// Transform a fragment.
CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) {
// Expose src/dst as int arrays.
int const* src_int = reinterpret_cast<int const*>(&src[0]);
int* dst_int = reinterpret_cast<int*>(&dst[0]);
// Transpose the data.
for (int d = 0; d < FragmentShape::kD; ++d) {
// The indices to read two consecutive "rows".
int const i0 = 2 * d + 0;
int const i1 = 2 * d + 1;
int a0 = src_int[i0];
int a1 = src_int[i1];
int b0, b1;
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(a0), "r"(a1));
asm volatile("prmt.b32 %0, %1, %2, 0x7632;" : "=r"(b1) : "r"(a0), "r"(a1));
// The indices to store with "strides".
int const j0 = 0 * (ShapeCount<FragmentShape>::kDhw / 2) + d;
int const j1 = 1 * (ShapeCount<FragmentShape>::kDhw / 2) + d;
dst_int[j0] = b0;
dst_int[j1] = b1;
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,406 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defies structural properties of half-precision GEMM computation.
*/
#pragma once
#include "cutlass/convert.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/gemm_epilogue.h"
#include "cutlass/gemm/gemm_epilogue_traits.h"
#include "cutlass/gemm/gemm_global_tile.h"
#include "cutlass/gemm/gemm_shared_tile.h"
#include "cutlass/gemm/gemm_traits.h"
#include "cutlass/gemm/hgemm_global_tile.h"
#include "cutlass/gemm/hgemm_multiply_add.h"
#include "cutlass/gemm/hgemm_swizzle.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The tile size for the GEMM KxNxM.
typename OutputTile_,
/// Tile size for thread-level GEMM (K-by-N-by-M)
typename ThreadGemmShape_,
/// The number of scalars per LDG for A.
int kScalarsPerLdgA_ = 2,
/// The number of scalars per LDG for B.
int kScalarsPerLdgB_ = 2>
struct HgemmConfig : public GemmConfig<
/// The scalar type for A.
half,
/// The scalar type for B.
half,
/// The scalar type for C.
half,
/// The scalar type for D.
half,
/// The tile size for the GEMM KxNxM.
OutputTile_,
/// The functor to do the math in the main loop.
ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, half, half, half>,
/// The number of scalars per LDG for A.
kScalarsPerLdgA_,
/// The number of scalars per STS for A.
kScalarsPerLdgA_,
/// The number of scalars per LDS for A.
8,
/// The number of scalars per LDG for B.
kScalarsPerLdgB_,
/// The number of scalars per STS for B.
kScalarsPerLdgB_,
/// The number of scalars per LDS for B.
8,
/// The number of scalars per LDG for C and STG for D.
2,
/// The number of scalars per STS for D.
8,
/// The number of scalars per LDS for D.
2,
/// The number of stages in shared memory.
2,
/// kResidueSeparate
false,
/// kResidueInPrologue
true,
/// kLaunchBounds
false
> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
struct HgemmTransformerA {};
template <typename Iterator_>
struct HgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
typedef Convert<typename Iterator_::Fragment, typename Iterator_::Fragment> Transformer;
};
template <typename Iterator_>
struct HgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
typedef HgemmSwizzle<Iterator_> Transformer;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
struct HgemmTransformerB {};
template <typename Iterator_>
struct HgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
typedef Convert<typename Iterator_::Fragment, typename Iterator_::Fragment> Transformer;
};
template <typename Iterator_>
struct HgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
typedef HgemmSwizzle<Iterator_> Transformer;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
struct HgemmTileTraitsHelperA : public GemmTileTraitsHelperA<kLayout_, GemmConfig_> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmConfig_>
struct HgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_>
: public GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
/// The base config.
typedef GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> Base;
/// The traits class to build the iterator to load data from global memory for A^T.
typedef HgemmCrosswiseGlobalTileTraits<
GemmOperand::kA,
// The layout.
MatrixLayout::kRowMajor,
// The pointer.
half const,
// The tile has size MxK in GEMM's terminology.
Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
// The threads are distributed as (threads / K ) x K (the traits may reorganize).
Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
// The number of scalars per LDG (LDG.32 or LDG.128, etc)
GemmConfig_::kScalarsPerLdgA>
GlobalTileTraits;
static int const kSkewA = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;
/// The traits class to build the iterator to store data to shared memory for A^T.
typedef GemmSharedStoreWithSkewTileAbTraits <
// The pointer.
half,
// The tile has size KxM in GEMM's terminology.
Shape<GemmConfig_::kStages,
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
// The threads are distributed as warps x 32(the traits may reorganize).
typename GlobalTileTraits::Threads,
// The number of scalars per STS (STS.32 or STS.128, etc).
2,
// The skew to avoid bank conflicts added in the tile W dimension.
kSkewA<GemmConfig_::kScalarsPerLdsA ? GemmConfig_::kScalarsPerLdsA : kSkewA>
SharedStoreTileTraits;
/// The traits class to build the iterator to load from shared memory for A^T.
typedef GemmSharedLoadTileATraits<
// The pointer.
half const,
// The output tile size.
typename GemmConfig_::OutputTile,
// The number of warps.
typename GemmConfig_::Warps,
// The number of threads per warp.
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
// The shape of the FMA instruction.
typename GemmConfig_::InstructionShape,
// The number of stages.
GemmConfig_::kStages,
// The number of scalars per LDS.
8,
// The skew.
SharedStoreTileTraits::kSkew>
SharedLoadTileTraits;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
struct HgemmTileTraitsHelperB : public GemmTileTraitsHelperB<kLayout_, GemmConfig_> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmConfig_>
struct HgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_>
: public GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
/// The base config.
typedef GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> Base;
/// The traits class to build the iterator to load data from global memory for B^N.
typedef HgemmCrosswiseGlobalTileTraits<
GemmOperand::kB,
// The layout.
MatrixLayout::kColumnMajor,
// The pointer.
half const,
// The tile has size KxN in GEMM's terminology.
Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
// The threads are distributed as (threads / K) x K (the traits may reorganize).
Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
// The number of scalars per LDG (LDG.32 or LDG.128, etc)
GemmConfig_::kScalarsPerLdgB>
GlobalTileTraits;
static int const kSkewB = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;
/// The traits class to build the iterator to store data to shared memory for B^N.
typedef GemmSharedStoreWithSkewTileAbTraits <
// The pointer.
half,
// The tile has size KxN in GEMM's terminology.
Shape<GemmConfig_::kStages,
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
// The threads are distributed as (threads / K) x K (the traits may reorganize).
typename GlobalTileTraits::Threads,
// The number of scalars per STS (STS.32 or STS.128, etc).
2,
// The skew to avoid bank conflicts added in the tile W dimension.
kSkewB<GemmConfig_::kScalarsPerLdsB ? GemmConfig_::kScalarsPerLdsB : kSkewB>
SharedStoreTileTraits;
/// The traits class to build the iterator to load from shared memory for B^N.
typedef GemmSharedLoadTileBTraits<
// The pointer.
half const,
// The output tile size.
typename GemmConfig_::OutputTile,
// The number of warps.
typename GemmConfig_::Warps,
// The number of threads per warp.
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
// The shape of the FMA instruction.
typename GemmConfig_::InstructionShape,
// The number of stages.
GemmConfig_::kStages,
// The number of scalars per LDS.
8,
// The skew.
SharedStoreTileTraits::kSkew>
SharedLoadTileTraits;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The layout for A.
MatrixLayout::Kind kLayoutA_,
/// The layout for B.
MatrixLayout::Kind kLayoutB_,
/// The output tile.
typename OutputTile_,
/// The functor to do the math in the epilogue.
typename EpilogueFunctor_,
/// Tile size for thread-level GEMM (K-by-N-by-M)
typename ThreadGemmShape_,
/// The number of halfs loaded in one LDG for A.
int kScalarsPerLdgA_ = 2,
/// The number of halfs loaded in one LDG for B.
int kScalarsPerLdgB_ = 2,
/// The index.
typename Index_ = int>
struct HgemmTraitsHelper {
/// The HGEMM config.
typedef HgemmConfig<OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_> GemmConfig;
/// The GEMM config for A.
typedef HgemmTileTraitsHelperA<kLayoutA_, GemmConfig> GemmTileTraitsHelperA;
/// The GEMM config for B.
typedef HgemmTileTraitsHelperB<kLayoutB_, GemmConfig> GemmTileTraitsHelperB;
/// The iterator to load A from global memory.
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA::GlobalTileTraits, Index_>
GlobalLoadIteratorA;
/// The default transformer for A.
typedef typename HgemmTransformerA<GemmTileTraitsHelperA::kLayout,
GlobalLoadIteratorA>::Transformer GlobalTransformerA;
/// The iterator to store A to shared memory.
typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
IteratorAdvance::kH,
MemorySpace::kShared>
SharedStoreIteratorA;
/// The stream to load A from global memory to shared memory.
typedef GlobalLoadStream<GemmOperand::kA,
GlobalLoadIteratorA,
SharedStoreIteratorA,
GlobalTransformerA>
GlobalLoadStreamA;
/// The iterator to load B from global memory.
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB::GlobalTileTraits, Index_>
GlobalLoadIteratorB;
// The default transformer for B.
typedef typename HgemmTransformerB<GemmTileTraitsHelperB::kLayout,
GlobalLoadIteratorB>::Transformer GlobalTransformerB;
/// The iterator to store B to shared memory.
typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
IteratorAdvance::kH,
MemorySpace::kShared>
SharedStoreIteratorB;
/// The stream to load B from global memory to shared memory.
typedef GlobalLoadStream<GemmOperand::kB,
GlobalLoadIteratorB,
SharedStoreIteratorB,
GlobalTransformerB>
GlobalLoadStreamB;
/// The iterator to load A from shared memory
typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
IteratorAdvance::kH,
MemorySpace::kShared>
SharedLoadIteratorA;
/// The stream to load A from shared memory.
typedef SharedLoadStream<SharedLoadIteratorA> SharedLoadStreamA;
/// The iterator to load B from shared memory.
typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
IteratorAdvance::kH,
MemorySpace::kShared>
SharedLoadIteratorB;
/// The stream to load B from shared memory.
typedef SharedLoadStream<SharedLoadIteratorB> SharedLoadStreamB;
/// The functor to do the multiply-add in the main loop.
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
/// The object to clear accumulators.
typedef ClearAccumulators<typename MultiplyAdd::ScalarC> ClearAccumulators;
/// The traits class for the epilogue.
typedef SimplifiedGemmEpilogueTraits<GemmConfig, EpilogueFunctor_, Index_> GemmEpilogueTraits;
/// The epilogue.
typedef GemmEpilogue<GemmEpilogueTraits> Epilogue;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The layout for A.
MatrixLayout::Kind kLayoutA_,
/// The layout for B.
MatrixLayout::Kind kLayoutB_,
/// The output tile.
typename OutputTile_ = Shape<8, 128, 128>,
/// The functor to do the math in the epilogue.
typename EpilogueFunctor_ = LinearScaling<half>,
/// Tile size for warp-level GEMM (K-by-N-by-M)
typename ThreadGemmShape_ = Shape<8, 8, 16>,
/// The number of halfs loaded in one LDG for A.
int kScalarsPerLdgA_ = 2,
/// The number of halfs loaded in one LDG for B.
int kScalarsPerLdgB_ = 2,
/// The index.
typename Index_ = int,
/// The helper class.
typename Helper_ = HgemmTraitsHelper<kLayoutA_,
kLayoutB_,
OutputTile_,
EpilogueFunctor_,
ThreadGemmShape_,
kScalarsPerLdgA_,
kScalarsPerLdgB_,
Index_> >
struct HgemmTraits : public GemmTraits<
// The config.
typename Helper_::GemmConfig,
// The stream to load A from global memory to shared memory.
typename Helper_::GlobalLoadStreamA,
// The stream to load B from global memory to shared memory.
typename Helper_::GlobalLoadStreamB,
// The stream to load A from shared memory.
typename Helper_::SharedLoadStreamA,
// The stream to load B from shared memory.
typename Helper_::SharedLoadStreamB,
// The epilogue.
typename Helper_::Epilogue,
// The block swizzle to reorganize the grid.
IdentityBlockSwizzle,
// The index.
Index_,
// The tool used to clear accumulators.
typename Helper_::ClearAccumulators> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,318 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines the epilogue phase of the GEMM computation for IGEMM, supporting integer and
floating-point output matrix formats.
*/
#pragma once
#include "cutlass/convert.h"
#include "cutlass/fragment.h"
#include "cutlass/gemm/gemm_global_stream.h"
#include "cutlass/gemm/gemm_shared_stream.h"
#include "cutlass/gemm/igemm_global_tile.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/tile_iterator.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int kElements_>
struct IgemmFloatToInt8Converter {
/// The input fragment.
typedef Fragment<float, kElements_> InputFragment;
/// The output fragment.
typedef Fragment<int8_t, kElements_> OutputFragment;
// We are packing 4 floats into int32 registers so we need kElements to be multiple of 4.
static_assert(kElements_ % 4 == 0, "kElements must be multiple of 4");
/// Ctor.
CUTLASS_DEVICE IgemmFloatToInt8Converter() {}
/// Transform a fragment.
CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
transform(src, 0, dst);
}
/// Transform a fragment.
template <typename Fragment_>
CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
// The inputs.
float4 const* src_f4 = reinterpret_cast<float4 const*>(&src[0]);
// The outputs.
int* dst_int = reinterpret_cast<int*>(&dst[0]);
// Iterate over the floats and pack them together to produce ints.
for (int i = 0; i < kElements_ / 4; ++i) {
// Read the float4.
float4 f4 = src_f4[i];
// Clamp the 4 elements of the floats to the [-128, +127] range.
float x = fmaxf(-128.f, fminf(127.f, f4.x));
float y = fmaxf(-128.f, fminf(127.f, f4.y));
float z = fmaxf(-128.f, fminf(127.f, f4.z));
float w = fmaxf(-128.f, fminf(127.f, f4.w));
// Convert to integers.
int ix = (int)x;
int iy = (int)y;
int iz = (int)z;
int iw = (int)w;
// Extract the lower bytes to build an int32 with 4 int8.
asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(ix) : "r"(iy));
asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(iz) : "r"(iw));
asm volatile("prmt.b32 %0, %0, %1, 0x5410;" : "+r"(ix) : "r"(iz));
// Store the int.
dst_int[i] = ix;
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename InputScalar_, typename OutputFragment_>
struct IgemmGlobalStoreTransformer {
typedef Convert<Fragment<InputScalar_, OutputFragment_::kElements>, OutputFragment_> Transformer;
};
template <int kElements_>
struct IgemmGlobalStoreTransformer<float, Fragment<int8_t, kElements_> > {
typedef IgemmFloatToInt8Converter<kElements_> Transformer;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int kElements_>
struct IgemmInt8ToFloatConverter {
/// The input fragment.
typedef Fragment<int8_t, kElements_> InputFragment;
/// The output fragment.
typedef Fragment<float, kElements_> OutputFragment;
// We are unpacking 4 int8s from int32.
static_assert(kElements_ % 4 == 0, "kElements must be multiple of 4");
/// Ctor.
CUTLASS_DEVICE IgemmInt8ToFloatConverter() {}
/// Transform a fragment.
CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
transform(src, 0, dst);
}
/// Transform a fragment.
template <typename Fragment_>
CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
// The inputs.
int const* src_int = reinterpret_cast<int const*>(&src[0]);
// The outputs.
float4* dst_f4 = reinterpret_cast<float4*>(&dst[0]);
// Iterate over the int8 and unpack them together to produce floats.
for (int i = 0; i < kElements_ / 4; ++i) {
// Read the int.
int ix, iy, iz, iw = src_int[i];
// Extract the 4 bytes.
asm volatile("prmt.b32 %0, 0x0, %1, 0x4440;" : "=r"(ix) : "r"(iw));
asm volatile("prmt.b32 %0, 0x0, %1, 0x4441;" : "=r"(iy) : "r"(iw));
asm volatile("prmt.b32 %0, 0x0, %1, 0x4442;" : "=r"(iz) : "r"(iw));
asm volatile("prmt.b32 %0, 0x0, %1, 0x4443;" : "=r"(iw) : "r"(iw));
// The floats.
float fx, fy, fz, fw;
// Convert to floats (make sure we generate I2F.F32.S8).
asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fx) : "r"(ix));
asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fy) : "r"(iy));
asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fz) : "r"(iz));
asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fw) : "r"(iw));
// Store the float4.
dst_f4[i] = make_float4(fx, fy, fz, fw);
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename InputFragment_, typename OutputScalar_>
struct IgemmGlobalLoadTransformer {
typedef Convert<InputFragment_, Fragment<OutputScalar_, InputFragment_::kElements> > Transformer;
};
template <int kElements_>
struct IgemmGlobalLoadTransformer<Fragment<int8_t, kElements_>, float> {
typedef IgemmInt8ToFloatConverter<kElements_> Transformer;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename InputScalar_, typename OutputFragment_>
struct IgemmSharedStoreTransformer {
typedef Convert<Fragment<InputScalar_, OutputFragment_::kElements>, OutputFragment_> Transformer;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename IgemmConfig_, typename EpilogueFunctor_, typename Index_>
struct IgemmEpilogueTraitsHelper
: public GemmEpilogueTraitsHelper<IgemmConfig_, EpilogueFunctor_, Index_> {
/// The base class.
typedef GemmEpilogueTraitsHelper<IgemmConfig_, EpilogueFunctor_, Index_> Base;
/// The config.
typedef IgemmConfig_ IgemmConfig;
/// The scalar type of the epilogue.
typedef typename Base::Scalar Scalar;
/// The iterations.
typedef typename Base::Iterations Iterations;
/// The iterations strides.
typedef typename Base::Delta Delta;
/// The traits class for the iterator.
typedef typename Base::GlobalLoadTileTraits GlobalLoadTileTraits;
/// The iterator to store to shared memory.
typedef GemmGlobalIteratorCd<GlobalLoadTileTraits> GlobalLoadIteratorC;
/// The fragment that needs to be produced by the load iterator.
typedef typename GlobalLoadIteratorC::Fragment GlobalFragmentC;
/// The transformer from loaded data to math fragment.
typedef
typename IgemmGlobalLoadTransformer<GlobalFragmentC, Scalar>::Transformer GlobalTransformerC;
/// The traits class for the iterator.
typedef typename Base::GlobalStoreTileTraits GlobalStoreTileTraits;
/// The iterator to store to shared memory.
typedef GemmGlobalIteratorCd<GlobalStoreTileTraits> GlobalStoreIteratorD;
/// The fragment that needs to be passed to that store iterator.
typedef typename GlobalStoreIteratorD::Fragment GlobalFragmentD;
/// The transformer from accumulators to shared memory fragments.
typedef
typename IgemmGlobalStoreTransformer<Scalar, GlobalFragmentD>::Transformer GlobalTransformerD;
/// The traits class for the shared iterator to store D to shared memory.
typedef typename Base::SharedStoreTileTraits SharedStoreTileTraits;
/// The shared iterator to store D to shared memory.
typedef TileStoreIterator<SharedStoreTileTraits,
typename SharedStoreTileTraits::Scalar,
IteratorAdvance::kH,
MemorySpace::kGlobal>
SharedStoreIteratorD;
/// The fragment that needs to be passed to that store iterator.
typedef typename SharedStoreIteratorD::Fragment SharedStoreFragmentD;
/// The transformer from accumulators to shared memory fragments.
typedef typename IgemmSharedStoreTransformer<typename IgemmConfig::Accumulators::Element,
SharedStoreFragmentD>::Transformer
SharedStoreTransformerD;
/// The traits class for the shared iterator to load D from shared memory.
typedef typename Base::SharedLoadTileTraits SharedLoadTileTraits;
/// The shared iterator to load D from shared memory.
typedef TileLoadIterator<SharedLoadTileTraits,
typename SharedLoadTileTraits::Scalar,
IteratorAdvance::kH,
MemorySpace::kShared>
SharedLoadIteratorD;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The config.
typename IgemmConfig_,
/// The functor to do the math in the epilogue.
typename EpilogueFunctor_,
/// The index.
typename Index_ = int,
/// The helper class to assemble the traits.
typename Helper_ = IgemmEpilogueTraitsHelper<IgemmConfig_, EpilogueFunctor_, Index_> >
struct IgemmEpilogueTraits : public GemmEpilogueTraits<
// The output tile.
typename IgemmConfig_::OutputTile,
// The accumulators.
typename IgemmConfig_::Accumulators,
// The global iterator for C.
typename Helper_::GlobalLoadIteratorC,
// The transformer for C.
typename Helper_::GlobalTransformerC,
// The transformer for D.
typename Helper_::GlobalTransformerD,
// The global iterator for D.
typename Helper_::GlobalStoreIteratorD,
// The iterator to store D to shared memory.
typename Helper_::SharedStoreIteratorD,
// The shared store transformer for D.
typename Helper_::SharedStoreTransformerD,
// The stream to load D from shared memory.
typename Helper_::SharedLoadStreamD,
// The iterations.
typename Helper_::Iterations,
// The strides between iterations.
typename Helper_::Delta,
// The functor to be used in the epilogue.
EpilogueFunctor_,
// The index.
Index_> {
/// Do we output in int8?
static bool const kInt8Output =
platform::is_same<typename IgemmConfig_::ScalarC, int8_t>::value != 0;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmEpilogueTraits_, bool = GemmEpilogueTraits_::kInt8Output>
struct IgemmEpilogue : public GemmEpilogue<GemmEpilogueTraits_> {
/// The base class.
typedef GemmEpilogue<GemmEpilogueTraits_> Base;
/// Ctor.
CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const& params_,
typename Base::SharedStorage& shared_storage_,
Coord<3> const& _problem_size)
: Base(params_, shared_storage_, _problem_size) {}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmEpilogueTraits_>
struct IgemmEpilogue<GemmEpilogueTraits_, true> : public GemmEpilogue<GemmEpilogueTraits_> {
/// The base class.
typedef GemmEpilogue<GemmEpilogueTraits_> Base;
/// Ctor.
CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const& params_,
typename Base::SharedStorage& shared_storage_,
Coord<3> const& _problem_size)
: Base(params_, shared_storage_, _problem_size) {}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,135 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements tile iterators to partition the thread block tile into 2D subtiles and
efficiently load each. Applies permute transformation to construct 'interleaved K-strided'
data layout in which 4-element dot products from the same K index are arranged in consecutive
locations within shared memory.
Supports efficient loads from shared memory to target the DP4A instruction.
*/
#pragma once
#include "cutlass/coord.h"
#include "cutlass/gemm/gemm_global_tile.h"
#include "cutlass/matrix_traits.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <GemmOperand::Kind kOperand_,
MatrixLayout::Kind kLayout_,
typename Scalar_,
typename Tile_,
typename Threads_,
int kAccessSize_>
struct IgemmGlobalTileTraits : public GemmGlobalTileTraits<
// Which GEMM operand?
kOperand_,
// The layout.
kLayout_,
// The scalar.
Scalar_,
// The tile.
Tile_,
// The threads.
Threads_,
// The number of scalars per LDG/STG.
kAccessSize_> {
/// The base class.
typedef GemmGlobalTileTraits<kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_> Base;
/// The threads.
typedef typename Base::Threads Threads;
/// The strides in each dimension between different loads/stores.
typedef Shape<Base::Threads::kH * 4, 1, Base::Threads::kW, Base::kAccessSize> Delta;
/// The number of iterations needed to load/store the tile.
typedef Shape<Base::VectorizedTile::kH / Base::Threads::kH / 4,
4,
Base::VectorizedTile::kW / Base::Threads::kW,
Base::VectorizedTile::kC / Base::kAccessSize>
Iterations;
/// Computes the thread offset in (H, W) based on thread ID
struct ThreadOffset {
CUTLASS_HOST_DEVICE
Coord<4> operator()() const {
int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
}
};
public:
/// The threads strides.
typedef Shape<1, 4, Base::VectorizedTile::kC> ThreadsDelta;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename TileTraits_, typename Index_ = int>
struct IgemmGlobalIteratorAb : public GemmGlobalIteratorAb<TileTraits_, Index_> {
/// The base class.
typedef GemmGlobalIteratorAb<TileTraits_, Index_> Base;
/// The functor to compute the thread offset.
typedef typename TileTraits_::ThreadOffset ThreadOffset;
/// Constructor.
CUTLASS_DEVICE IgemmGlobalIteratorAb(typename Base::Params const& _params,
const Coord<3>& threadblock_offset,
ThreadOffset thread_offset_func = ThreadOffset())
: Base(_params, threadblock_offset, thread_offset_func), mask_(0xffffffff) { }
CUTLASS_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& threadblock_offset) {
Base::initialize_predicates(bounds, threadblock_offset);
// The number of elements read in a single iteration.
int const kBlock = TileTraits_::Tile::kW;
// The residue.
int const kResidue = (int)(bounds[1] % kBlock);
// Compute the number of elements that are valid.
int const left = kResidue - Base::thread_offset[2];
if (left > 0 && left < 4) {
mask_ = (1u << (8 * left)) - 1u;
}
}
CUTLASS_DEVICE void load_element(
typename Base::AccessType& value, int d, int h, int w, int c) const {
Base::load_element(value, d, h, w, c);
reinterpret_cast<uint32_t&>(value) &= mask_;
}
/// The mask to clean up the values.
uint32_t mask_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,94 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements matrix multiply accumulate operation of 8-bit integer data using DP4A
instruction.
*/
#pragma once
#include "cutlass/fragment.h"
#include "cutlass/gemm/thread_multiply_add.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Template performing matrix multiply-add operation within a thread
template <typename ThreadGemmShape_, typename ThreadsPerWarp_>
struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int> {
/// The shape of the instruction.
typedef Shape<4, 1, 1> InstructionShape;
/// Shape of the thread-level GEMM (K-by-N-by-M)
typedef ThreadGemmShape_ ThreadGemmShape;
/// Aliased for compatibility. Will be removed in CUTLASS v2.0
typedef ThreadGemmShape AccumulatorsPerThread;
/// The number of threads per warp.
typedef ThreadsPerWarp_ ThreadsPerWarp;
/// The number of accumulators per warp.
typedef typename ShapeMul<ThreadGemmShape, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
/// The type for A.
typedef int8_t ScalarA;
/// The fragment for A.
typedef Fragment<ScalarA, AccumulatorsPerThread::kW * 4> FragmentA;
/// The type for B.
typedef int8_t ScalarB;
/// The fragment for B.
typedef Fragment<ScalarB, AccumulatorsPerThread::kH * 4> FragmentB;
/// The type for C and D.
typedef int ScalarC;
/// The accumulators.
typedef Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW> Accumulators;
/// Ctor.
CUTLASS_DEVICE ThreadMultiplyAdd() {}
/// Multiply : d = a*b + c.
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
FragmentB const& b,
Accumulators const& c,
Accumulators& d) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)
// The inputs.
int const* a_int = reinterpret_cast<int const*>(&a[0]);
int const* b_int = reinterpret_cast<int const*>(&b[0]);
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
: "=r"(d[j * AccumulatorsPerThread::kW + i])
: "r"(a_int[i]), "r"(b_int[j]), "r"(c[j * AccumulatorsPerThread::kW + i]));
}
}
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,125 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Transposes a fragment of data containing packed 8-bit integer elements.
*/
#pragma once
#include "cutlass/fragment.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GlobalIterator_>
struct IgemmSwizzle {
/// The global iterator.
typedef GlobalIterator_ GlobalIterator;
/// The source fragment.
typedef typename GlobalIterator::Fragment Fragment;
/// The shape of the source fragment.
typedef typename GlobalIterator::FragmentShape FragmentShape;
/// The source fragment.
typedef Fragment InputFragment;
/// The destination fragment.
typedef Fragment OutputFragment;
/// The src/dst must be int8 fragments.
static_assert((platform::is_same<typename Fragment::Element, int8_t>::value), "Works on int8");
/// The number of elements must be a multiple of 4.
static_assert(FragmentShape::kH % 4 == 0 && ShapeCount<FragmentShape>::kWc % 4 == 0,
"Not multiple of 4");
/// Ctor.
CUTLASS_DEVICE IgemmSwizzle() {}
/// Transform a fragment.
CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) {
// Expose src/dst as int arrays.
int const* src_int = reinterpret_cast<int const*>(&src[0]);
int* dst_int = reinterpret_cast<int*>(&dst[0]);
// Transpose the data.
for (int d = 0; d < FragmentShape::kD; ++d) {
for (int h = 0; h < FragmentShape::kH / 4; ++h) {
for (int w = 0; w < ShapeCount<FragmentShape>::kWc / 4; ++w) {
int const i0 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
(4 * h + 0) * (ShapeCount<FragmentShape>::kWc / 4) + w;
int const i1 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
(4 * h + 1) * (ShapeCount<FragmentShape>::kWc / 4) + w;
int const i2 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
(4 * h + 2) * (ShapeCount<FragmentShape>::kWc / 4) + w;
int const i3 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
(4 * h + 3) * (ShapeCount<FragmentShape>::kWc / 4) + w;
int a0 = src_int[i0];
int a1 = src_int[i1];
int a2 = src_int[i2];
int a3 = src_int[i3];
// // DEBUG.
// if (threadIdx.x == 0) {
// printf("a=0x%08x 0x%08x 0x%08x 0x%08x\n", a0, a1, a2, a3);
// }
int b0, b1, b2, b3, c0;
asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(b0) : "r"(a0), "r"(a1));
asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(c0) : "r"(a2), "r"(a3));
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(b0), "r"(c0));
asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(b1) : "r"(a0), "r"(a1));
asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(c0) : "r"(a2), "r"(a3));
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b1) : "r"(b1), "r"(c0));
asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(b2) : "r"(a0), "r"(a1));
asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(c0) : "r"(a2), "r"(a3));
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b2) : "r"(b2), "r"(c0));
asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(b3) : "r"(a0), "r"(a1));
asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(c0) : "r"(a2), "r"(a3));
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b3) : "r"(b3), "r"(c0));
// // DEBUG.
// if (threadIdx.x == 0) {
// printf("b=0x%08x 0x%08x 0x%08x 0x%08x\n", b0, b1, b2, b3);
// }
dst_int[i0] = b0;
dst_int[i1] = b1;
dst_int[i2] = b2;
dst_int[i3] = b3;
}
}
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,550 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defies structural properties of mixed-precision integer GEMM. Multiplicands are assumed
to be packed 8bit integers, accumulators are assumed to be 32b signed integers, and output
formats vary.
*/
#pragma once
#include "cutlass/convert.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/gemm_epilogue.h"
#include "cutlass/gemm/gemm_epilogue_traits.h"
#include "cutlass/gemm/gemm_global_tile.h"
#include "cutlass/gemm/gemm_shared_tile.h"
#include "cutlass/gemm/gemm_traits.h"
#include "cutlass/gemm/igemm_epilogue.h"
#include "cutlass/gemm/igemm_global_tile.h"
#include "cutlass/gemm/igemm_multiply_add.h"
#include "cutlass/gemm/igemm_swizzle.h"
#include "cutlass/reshape_tile.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The tile size for the GEMM KxNxM.
typename OutputTile_,
/// The output type.
typename ScalarD_,
/// Tile size for thread-level GEMM (K-by-N-by-M)
typename ThreadGemmShape_>
struct IgemmConfig : public GemmConfig<
/// The scalar type for A.
int8_t,
/// The scalar type for B.
int8_t,
/// The scalar type for C.
ScalarD_,
/// The scalar type for D.
ScalarD_,
/// The tile size for the GEMM KxNxM.
OutputTile_,
/// The functor to do the math in the main loop.
ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, int8_t, int8_t, int>,
/// The number of scalars per LDG for A.
4,
/// The number of scalars per STS for A.
4,
/// The number of scalars per LDS for A.
16,
/// The number of scalars per LDG for B.
4,
/// The number of scalars per STS for B.
4,
/// The number of scalars per LDS for B.
16,
/// The number of scalars per LDG for C and STG for D.
1,
/// The number of scalars per STS for D.
4,
/// The number of scalars per LDS for D.
1,
/// The number of stages in shared memory.
2,
/// kResidueSeparate
false,
/// kResidueInPrologue
false,
/// kLaunchBounds
false> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename OutputTile_, typename ThreadGemmShape_>
struct IgemmConfig<OutputTile_, int8_t, ThreadGemmShape_>
: public GemmConfig<
/// The scalar type for A.
int8_t,
/// The scalar type for B.
int8_t,
/// The scalar type for C.
int8_t,
/// The scalar type for D.
int8_t,
/// The tile size for the GEMM KxNxM.
OutputTile_,
/// The functor to do the math in the main loop.
ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, int8_t, int8_t, int>,
/// The number of scalars per LDG for A.
4,
/// The number of scalars per STS for A.
4,
/// The number of scalars per LDS for A.
16,
/// The number of scalars per LDG for B.
4,
/// The number of scalars per STS for B.
4,
/// The number of scalars per LDS for B.
16,
/// The number of scalars per LDG for C and STG for D.
4,
/// The number of scalars per STS for D.
4,
/// The number of scalars per LDS for D.
4,
/// The number of stages in shared memory.
2,
/// If true, separate mainloop is instantiated from residue
false,
/// Compute residue in prolog?
true,
/// Launch bounds?
false> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperA : public GemmTileTraitsHelperA<kLayout_, GemmConfig_> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_, Index_>
: public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
/// The base config.
typedef GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> Base;
/// The number of scalars per LDG/STS/LDS for A.
static int const kScalarsPerStsA = 16;
/// The traits class to build the iterator to load data from global memory for A^N.
typedef IgemmGlobalTileTraits<
GemmOperand::kA,
// The layout.
MatrixLayout::kColumnMajor,
// The pointer is float const.
int8_t const,
// The tile has size KxM in GEMM's terminology.
Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kW>,
// The threads are distributed as warps x 32 (the traits may reorganize).
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
GemmConfig_::kScalarsPerLdgA>
GlobalTileTraits;
/// The global load iterator.
typedef GemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
/// The traits class to build the iterator to store data to shared memory for A^N.
typedef GemmSharedStoreTileAbTraits<
// The pointer is float.
int8_t,
// The tile has size KxM in GEMM's terminology.
Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kW * 4>,
// The threads are distributed as warps x 32 (the traits may reorganize).
typename GlobalTileTraits::Threads,
// The number of scalars per STS (STS.32 or STS.128, etc).
kScalarsPerStsA>
SharedStoreTileTraits;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Index_> {
/// The layout.
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
/// The input scalar.
typedef int8_t Scalar;
/// The scalar stored in shared memory.
typedef int8_t MultiplyAddScalar;
/// The number of scalars per LDG/STS/LDS for A.
static int const kScalarsPerStsA = 16;
/// The traits class to build the iterator to load data from global memory for A^T.
typedef IgemmGlobalTileTraits<
GemmOperand::kA,
// The layout.
MatrixLayout::kRowMajor,
// The pointer is float const.
int8_t const,
// The tile has size NxK in GEMM's terminology.
Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
// The threads are distributed as warps x 32 (the traits may reorganize).
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
GemmConfig_::kScalarsPerLdgA>
GlobalTileTraits;
/// The global load iterator.
typedef IgemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
/// The traits class to build the iterator to store data to shared memory for A^N.
typedef GemmSharedStoreWithSkewTileAbTraits<
// The pointer is int8.
int8_t,
// The tile has size KxN in GEMM's terminology.
Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kW * 4>,
// The threads are distributed as (threads / K) x K (the traits may reorganize).
typename GlobalTileTraits::Threads,
// The number of scalars per STS.
kScalarsPerStsA,
// The skew to avoid bank conflicts added in the tile W dimension.
16>
SharedStoreTileTraits;
/// The traits class to build the iterator to load from shared memory for A^N.
typedef GemmSharedLoadTileATraits<
// The pointer is float const.
int8_t const,
// The output tile size.
typename GemmConfig_::OutputTile,
// The number of warps.
typename GemmConfig_::Warps,
// The number of threads per warp.
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
// The shape of the FMA instruction.
typename GemmConfig_::InstructionShape,
// The number of stages.
GemmConfig_::kStages,
// The number of scalars per LDS.
16,
// The skew.
SharedStoreTileTraits::kSkew>
SharedLoadTileTraits;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperB : public GemmTileTraitsHelperB<kLayout_, GemmConfig_> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Index_> {
/// The layout.
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
/// The input scalar.
typedef int8_t Scalar;
/// The scalar stored in shared memory.
typedef int8_t MultiplyAddScalar;
/// The number of scalars per LDG/STS/LDS for B.
static int const kScalarsPerStsB = 16;
/// The traits class to build the iterator to load data from global memory for B^T.
typedef IgemmGlobalTileTraits<
GemmOperand::kB,
// The layout.
MatrixLayout::kColumnMajor,
// The pointer is float const.
int8_t const,
// The tile has size NxK in GEMM's terminology.
Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
// The threads are distributed as warps x 32 (the traits may reorganize).
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
GemmConfig_::kScalarsPerLdgB>
GlobalTileTraits;
/// The global load iterator.
typedef IgemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
/// The traits class to build the iterator to store data to shared memory for B^N.
typedef GemmSharedStoreWithSkewTileAbTraits<
// The pointer is int8.
int8_t,
// The tile has size KxN in GEMM's terminology.
Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kH * 4>,
// The threads are distributed as (threads / K) x K (the traits may reorganize).
typename GlobalTileTraits::Threads,
// The number of scalars per STS.
kScalarsPerStsB,
// The skew to avoid bank conflicts added in the tile W dimension.
16>
SharedStoreTileTraits;
/// The traits class to build the iterator to load from shared memory for B^N.
typedef GemmSharedLoadTileBTraits<
// The pointer is float const.
int8_t const,
// The output tile size.
typename GemmConfig_::OutputTile,
// The number of warps.
typename GemmConfig_::Warps,
// The number of threads per warp.
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
// The shape of the FMA instruction.
typename GemmConfig_::InstructionShape,
// The number of stages.
GemmConfig_::kStages,
// The number of scalars per LDS.
16,
// The skew.
SharedStoreTileTraits::kSkew>
SharedLoadTileTraits;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_, Index_>
: public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
/// The base config.
typedef GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> Base;
/// The number of scalars per LDG/STS/LDS for B.
static int const kScalarsPerStsB = 16;
/// The traits class to build the iterator to load data from global memory for B^T.
typedef IgemmGlobalTileTraits<
GemmOperand::kB,
// The layout.
MatrixLayout::kRowMajor,
// The pointer is float const.
int8_t const,
// The tile has size KxM in GEMM's terminology.
Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kH>,
// The threads are distributed as warps x 32 (the traits may reorganize).
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
GemmConfig_::kScalarsPerLdgB>
GlobalTileTraits;
/// The global load iterator.
typedef GemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
/// The traits class to build the iterator to store data to shared memory for B^N.
typedef GemmSharedStoreTileAbTraits<
// The pointer is float.
int8_t,
// The tile has size KxM in GEMM's terminology.
Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kH * 4>,
// The threads are distributed as warps x 32 (the traits may reorganize).
typename GlobalTileTraits::Threads,
// The number of scalars per STS (STS.32 or STS.128, etc).
kScalarsPerStsB>
SharedStoreTileTraits;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
struct IgemmTransformerA {};
template <typename Iterator_>
struct IgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
typedef Copy<typename Iterator_::Fragment> Transformer;
};
template <typename Iterator_>
struct IgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
typedef IgemmSwizzle<Iterator_> Transformer;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
struct IgemmTransformerB {};
template <typename Iterator_>
struct IgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
typedef Copy<typename Iterator_::Fragment> Transformer;
};
template <typename Iterator_>
struct IgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
typedef IgemmSwizzle<Iterator_> Transformer;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The layout for A.
MatrixLayout::Kind kLayoutA_,
/// The layout for B.
MatrixLayout::Kind kLayoutB_,
/// The output tile.
typename OutputTile_,
/// The output type.
typename ScalarD_,
/// The functor to do the math in the epilogue.
typename EpilogueFunctor_,
/// Tile size for thread-level GEMM (K-by-N-by-M)
typename ThreadGemmShape_ = Shape<32, 8, 8>,
/// The index.
typename Index_ = int>
struct IgemmTraitsHelper {
/// The IGEMM config.
typedef IgemmConfig<OutputTile_, ScalarD_, ThreadGemmShape_> GemmConfig;
/// The GEMM config for A.
typedef IgemmTileTraitsHelperA<kLayoutA_, GemmConfig, Index_> GemmTileTraitsHelperA;
/// The GEMM config for B.
typedef IgemmTileTraitsHelperB<kLayoutB_, GemmConfig, Index_> GemmTileTraitsHelperB;
/// The iterator to load A from global memory.
typedef typename GemmTileTraitsHelperA::GlobalLoadIterator GlobalLoadIteratorA;
/// The default transformer for A.
typedef typename IgemmTransformerA<GemmTileTraitsHelperA::kLayout,
GlobalLoadIteratorA>::Transformer GlobalTransformerA;
/// The iterator to store A to shared memory.
typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
IteratorAdvance::kH,
MemorySpace::kShared>
SharedStoreIteratorA;
/// The stream to load A from global memory to shared memory.
typedef GlobalLoadStream<GemmOperand::kA,
GlobalLoadIteratorA,
SharedStoreIteratorA,
GlobalTransformerA>
GlobalLoadStreamA;
/// The iterator to load B from global memory.
typedef typename GemmTileTraitsHelperB::GlobalLoadIterator GlobalLoadIteratorB;
// The default transformer for B.
typedef typename IgemmTransformerB<GemmTileTraitsHelperB::kLayout,
GlobalLoadIteratorB>::Transformer GlobalTransformerB;
/// The iterator to store B to shared memory.
typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
IteratorAdvance::kH,
MemorySpace::kShared>
SharedStoreIteratorB;
/// The stream to load B from global memory to shared memory.
typedef GlobalLoadStream<GemmOperand::kB,
GlobalLoadIteratorB,
SharedStoreIteratorB,
GlobalTransformerB>
GlobalLoadStreamB;
/// The iterator to load A from shared memory.
typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
IteratorAdvance::kH,
MemorySpace::kShared>
SharedLoadIteratorA;
/// The stream to load A from shared memory.
typedef SharedLoadStream<SharedLoadIteratorA, Copy<typename SharedLoadIteratorA::Fragment> >
SharedLoadStreamA;
/// The iterator to load B from shared memory.
typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
IteratorAdvance::kH,
MemorySpace::kShared>
SharedLoadIteratorB;
/// The stream to load B from shared memory.
typedef SharedLoadStream<SharedLoadIteratorB, Copy<typename SharedLoadIteratorB::Fragment> >
SharedLoadStreamB;
/// The multiply-add functor.
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
/// The object to clear accumulators.
typedef ClearAccumulators<typename MultiplyAdd::ScalarC> ClearAccumulators;
/// The epilogue.
typedef IgemmEpilogue<IgemmEpilogueTraits<GemmConfig, EpilogueFunctor_> > Epilogue;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename ScalarD_>
struct IgemmEpilogueScalar {
typedef float Scalar;
};
template <>
struct IgemmEpilogueScalar<int> {
typedef int Scalar;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The layout for A.
MatrixLayout::Kind kLayoutA_,
/// The layout for B.
MatrixLayout::Kind kLayoutB_,
/// The output tile.
typename OutputTile_ = Shape<32, 128, 128>,
/// The output type.
typename ScalarD_ = int,
/// The functor to do the math in the epilogue.
typename EpilogueFunctor_ = LinearScaling<typename IgemmEpilogueScalar<ScalarD_>::Scalar>,
/// Tile size for thread-level GEMM (K-by-N-by-M)
typename ThreadGemmShape_ = Shape<32, 8, 8>,
/// The index.
typename Index_ = int,
/// The helper class.
typename Helper_ = IgemmTraitsHelper<kLayoutA_,
kLayoutB_,
OutputTile_,
ScalarD_,
EpilogueFunctor_,
ThreadGemmShape_,
Index_> >
struct IgemmTraits : public GemmTraits<
// The config.
typename Helper_::GemmConfig,
// The stream to load A from global memory to shared memory.
typename Helper_::GlobalLoadStreamA,
// The stream to load B from global memory to shared memory.
typename Helper_::GlobalLoadStreamB,
// The stream to load A from shared memory.
typename Helper_::SharedLoadStreamA,
// The stream to load B from shared memory.
typename Helper_::SharedLoadStreamB,
// The epilogue.
typename Helper_::Epilogue,
// The block swizzle to reorganize the grid.
IdentityBlockSwizzle,
// The index.
Index_,
// The tool used to clear accumulators.
typename Helper_::ClearAccumulators> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,169 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements the BLAS linear scaling function alpha*AB + beta*C
*/
#pragma once
#include "cutlass/fragment_multiply_add.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
CUTLASS_DEVICE bool is_zero(T x) {
return x == T(0);
}
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
CUTLASS_DEVICE bool is_zero(half x) { return reinterpret_cast<int16_t&>(x) == int16_t(0); }
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Functor to compute linear combination of fragments
template <typename Scalar_, typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_> >
struct LinearScaling {
// The scalar.
typedef Scalar_ Scalar;
// The accumulator Type
typedef typename FragmentMultiplyAdd_::ScalarAccum ScalarAccum;
// The adapater.
typedef FragmentMultiplyAdd_ FragmentMultiplyAdd;
/// The parameters.
struct Params {
/// The alpha/beta scaling params.
Scalar alpha, beta;
//
// Methods
//
// Constructor
CUTLASS_HOST_DEVICE
Params(Scalar _alpha = 0, Scalar _beta = 0) : alpha(_alpha), beta(_beta) {}
/// Initialize the parameters
CUTLASS_HOST_DEVICE int initialize(Scalar _alpha, Scalar _beta) {
alpha = _alpha;
beta = _beta;
return 0;
}
/// Initialize the parameters.
template <typename GemmDesc_>
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
alpha = desc.alpha;
beta = desc.beta;
return 0;
}
};
//
// Data members
//
Params params;
//
// Methods
//
/// Ctor.
CUTLASS_DEVICE LinearScaling() { }
/// Ctor.
CUTLASS_DEVICE LinearScaling(Params const& _params) : params(_params) {}
/// Method to determine whether the source accumulator matrix C is ever needed. This method
/// may always safely return true, though better performance is possible if the source accumulator
/// matrix is never loaded unnecessarily.
CUTLASS_DEVICE
bool source_required() const {
return !is_zero(params.beta);
}
/// Evaluate the functor.
template <typename FragmentA_, typename FragmentB_>
CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_& output) {
FragmentMultiplyAdd mad;
mad.multiply(params.alpha, accum, output);
}
/// Evaluate the functor, without using fragment in the API
template <typename ScalarAccum, typename ScalarOutput, int size>
CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput *output) {
Fragment<ScalarAccum, size> FragAccum;
Fragment<ScalarOutput, size> FragOutput;
#pragma unroll
for (int i = 0; i < size; i++) {
FragAccum[i] = accum[i];
FragOutput[i] = output[i];
}
evaluate(FragAccum, FragOutput);
#pragma unroll
for (int i = 0; i < size; i++) {
output[i] = FragOutput[i];
}
}
/// Evaluate the functor.
template <typename FragmentA_, typename FragmentB_>
CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_ const& old, FragmentB_& output) {
FragmentMultiplyAdd mad;
FragmentB_ tmp;
mad.multiply(params.beta, old, tmp);
mad.multiply_add(params.alpha, accum, tmp, output);
}
/// Evaluate the functor, without using fragment in the API
template <typename ScalarAccum, typename ScalarOutput, int size>
CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput const *old, ScalarOutput *output) {
Fragment<ScalarAccum, size> FragAccum;
Fragment<ScalarOutput, size> FragOutput;
Fragment<ScalarOutput, size> FragOld;
#pragma unroll
for (int i = 0; i < size; i++) {
FragAccum[i] = accum[i];
FragOutput[i] = output[i];
FragOld[i] = old[i];
}
evaluate(FragAccum, FragOld, FragOutput);
#pragma unroll
for (int i = 0; i < size; i++) {
output[i] = FragOutput[i];
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,149 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements the BLAS linear scaling function alpha*AB + beta*C
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/scalar_or_pointer.h"
#include "cutlass/gemm/linear_scaling.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Functor to compute linear combination of fragments. This is intended to support passing scalars
/// either by value from the host or by reference to device-side scalar elements. This is inspired
/// by cuBLAS's device pointer mode.
template <typename Scalar_, typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_> >
struct LinearScalingDevicePtr : public LinearScaling<Scalar_, FragmentMultiplyAdd_> {
/// Linear Scaling class used
typedef LinearScaling<Scalar_, FragmentMultiplyAdd_> Base;
// The scalar.
typedef typename Base::Scalar Scalar;
/// The parameters.
class Params {
private:
/// Alpha scalar
detail::ScalarOrPointer<Scalar> alpha_;
/// Beta sclaar
detail::ScalarOrPointer<Scalar> beta_;
public:
//
// Methods
//
// Constructor
CUTLASS_HOST_DEVICE
Params() {}
// Constructor
CUTLASS_HOST_DEVICE
Params(
Scalar alpha,
Scalar beta
):
alpha_(alpha),
beta_(beta) {}
// Constructor
CUTLASS_HOST_DEVICE
Params(
Scalar const *alpha_ptr,
Scalar const *beta_ptr
):
alpha_(alpha_ptr),
beta_(alpha_ptr) {}
/// Initialize the parameters
CUTLASS_HOST_DEVICE int initialize(
Scalar alpha,
Scalar beta) {
alpha_ = alpha;
beta_ = beta;
return 0;
}
/// Initialize the parameters
CUTLASS_HOST_DEVICE int initialize(
Scalar const *alpha,
Scalar const *beta) {
alpha_ = alpha;
beta_= beta;
return 0;
}
/// Initialize the parameters.
template <typename GemmDesc_>
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
alpha_ = desc.alpha;
beta_ = desc.beta;
return 0;
}
/// Gets the alpha scalar
CUTLASS_HOST_DEVICE
Scalar alpha() const {
return alpha_;
}
/// Gets the beta scalar
CUTLASS_HOST_DEVICE
Scalar beta() const {
return beta_;
}
};
//
// Methods
//
/// Ctor.
CUTLASS_HOST_DEVICE LinearScalingDevicePtr(Params const& _params) {
this->params.alpha = _params.alpha();
this->params.beta = _params.beta();
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,129 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements the BLAS linear scaling function alpha*AB + beta*C
*/
#pragma once
#include "cutlass/cutlass.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
/// Helper class defines an object which operates as either a scalar or a pointer. If the pointer
/// is non-null, it is dereferenced when the object is accessed.
template <typename Scalar_>
class ScalarOrPointer {
public:
/// Underlying scalar type
typedef Scalar_ Scalar;
private:
//
// Data members
//
/// Scalar value
Scalar scalar;
/// Pointer to use if non null
Scalar const *ptr;
public:
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
ScalarOrPointer(): scalar(0), ptr(nullptr) {}
/// Object behaves as a scalar
CUTLASS_HOST_DEVICE
ScalarOrPointer(Scalar const &val): scalar(val), ptr(nullptr) {}
/// Object behaves as a scalar
CUTLASS_HOST_DEVICE
ScalarOrPointer(Scalar const *ptr_): scalar(0), ptr(ptr_) {}
/// Returns true if is pointer
CUTLASS_HOST_DEVICE
bool is_pointer() const {
return bool(ptr);
}
/// Gets the pointer value
CUTLASS_HOST_DEVICE
Scalar const *get_ptr() const {
return ptr;
}
/// Gets the pointer value
CUTLASS_HOST_DEVICE
Scalar get_scalar() const {
return scalar;
}
/// Assigns to a scalar and sets pointer to nullptr
CUTLASS_HOST_DEVICE
ScalarOrPointer &operator=(Scalar const &scalar_) {
scalar = scalar_;
ptr = nullptr;
return *this;
}
/// Assigns to a pointer value
CUTLASS_HOST_DEVICE
ScalarOrPointer &operator=(Scalar const *ptr_) {
ptr = ptr_;
return *this;
}
/// Access the element
CUTLASS_HOST_DEVICE
Scalar get() const {
if (ptr) {
return *ptr;
}
return scalar;
}
/// Accesses the element
CUTLASS_HOST_DEVICE
operator Scalar() const {
return get();
}
};
} // namespace detail
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,172 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defies structural properties of single-precision GEMM.
*/
#pragma once
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/gemm_epilogue.h"
#include "cutlass/gemm/gemm_epilogue_traits.h"
#include "cutlass/gemm/gemm_global_tile.h"
#include "cutlass/gemm/gemm_shared_tile.h"
#include "cutlass/gemm/gemm_traits.h"
#include "cutlass/gemm/thread_multiply_add.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The tile size for the GEMM KxNxM.
typename OutputTile_,
/// Tile size for thread-level GEMM (K-by-N-by-M)
typename ThreadGemmShape_,
/// The number of scalars per LDG for A.
int kScalarsPerLdgA_ = 1,
/// The number of scalars per LDG for B.
int kScalarsPerLdgB_ = 1,
/// Whether to specify launch bounds
bool kLaunchBounds = true>
struct SgemmConfig : public GemmConfig<
/// The scalar type for A.
float,
/// The scalar type for B.
float,
/// The scalar type for C.
float,
/// The scalar type for D.
float,
/// The tile size for the GEMM KxNxM.
OutputTile_,
/// The functor to do the math in the main loop.
ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, float, float, float>,
/// The number of scalars per LDG for A.
kScalarsPerLdgA_,
/// The number of scalars per STS for A.
kScalarsPerLdgA_,
/// The number of scalars per LDS for A.
4,
/// The number of scalars per LDG for B.
kScalarsPerLdgB_,
/// The number of scalars per STS for B.
kScalarsPerLdgB_,
/// The number of scalars per LDS for B.
4,
/// The number of scalars per LDG for C and STG for D.
1,
/// The number of scalars per STS for D.
4,
/// The number of scalars per LDS for D.
1,
/// The number of stages in shared memory.
2,
/// kResidueSeparate
false,
/// kResidueInPrologue
true,
/// kLaunchBounds
kLaunchBounds> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// The layout for A.
MatrixLayout::Kind kLayoutA_,
/// The layout for B.
MatrixLayout::Kind kLayoutB_,
/// The output tile.
typename OutputTile_ = Shape<8, 128, 128>,
/// The functor to use in the epilogue.
typename EpilogueFunctor_ = LinearScaling<float>,
/// Tile size for thread-level GEMM (K-by-N-by-M)
typename ThreadGemmShape_ = Shape<8, 8, 8>,
/// The number of floats loaded in one LDG for A.
int kScalarsPerLdgA_ = 1,
/// The number of floats loaded in one LDG for B.
int kScalarsPerLdgB_ = 1,
/// The index.
typename Index_ = int,
/// The SGEMM config.
typename GemmConfig_ =
SgemmConfig<OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_, false>,
/// The traits class for the epilogue.
typename GemmEpilogueTraits_ =
SimplifiedGemmEpilogueTraits<GemmConfig_, EpilogueFunctor_, Index_> >
struct SgemmTraits : public SimplifiedGemmTraits<
// The layout for A.
kLayoutA_,
// The layout for B.
kLayoutB_,
// The config.
GemmConfig_,
// The epilogue.
GemmEpilogue<GemmEpilogueTraits_>,
// The index.
Index_> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to define SGEMM traits using Launch Bounds
template <
/// The layout for A.
MatrixLayout::Kind kLayoutA_,
/// The layout for B.
MatrixLayout::Kind kLayoutB_,
/// The output tile.
typename OutputTile_ = Shape<8, 128, 128>,
/// The functor to use in the epilogue.
typename EpilogueFunctor_ = LinearScaling<float>,
/// Tile size for thread-level GEMM (K-by-N-by-M)
typename ThreadGemmShape_ = Shape<8, 8, 8>,
/// The number of floats loaded in one LDG for A.
int kScalarsPerLdgA_ = 1,
/// The number of floats loaded in one LDG for B.
int kScalarsPerLdgB_ = 1,
/// The index.
typename Index_ = int,
/// The SGEMM config.
typename GemmConfig_ =
SgemmConfig<OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_, true>,
/// The traits class for the epilogue.
typename GemmEpilogueTraits_ =
SimplifiedGemmEpilogueTraits<GemmConfig_, EpilogueFunctor_, Index_> >
struct SgemmLBTraits : public SimplifiedGemmTraits<
// The layout for A.
kLayoutA_,
// The layout for B.
kLayoutB_,
// The config.
GemmConfig_,
// The epilogue.
GemmEpilogue<GemmEpilogueTraits_>,
// The index.
Index_> {};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,96 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template implementing matrix multiply-add operations on fragments.
*/
#pragma once
#include "cutlass/fragment.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Template performing matrix multiply-add operation within a thread
template <typename ThreadGemmShape_,
typename ThreadsPerWarp_,
typename ScalarA_,
typename ScalarB_,
typename ScalarC_,
MatrixLayout::Kind kLayout_ = MatrixLayout::kColumnMajor>
struct ThreadMultiplyAdd {
/// The shape of the instruction.
typedef Shape<1, 1, 1, 1> InstructionShape;
/// The shape of a thread-leveel matrix multiply accumulate.
typedef ThreadGemmShape_ ThreadGemmShape;
/// Aliased to "AccumulatorsPerThread" for compatibility. Expect to be renamed in CUTLASS v2.0
typedef ThreadGemmShape AccumulatorsPerThread;
/// The number of threads per warp.
typedef ThreadsPerWarp_ ThreadsPerWarp;
/// The number of accumulators per warp.
typedef typename ShapeMul<ThreadGemmShape, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
/// The type for A.
typedef ScalarA_ ScalarA;
/// The fragment for A.
typedef Fragment<ScalarA, AccumulatorsPerThread::kW> FragmentA;
/// The type for B.
typedef ScalarB_ ScalarB;
/// The fragment for B.
typedef Fragment<ScalarB, AccumulatorsPerThread::kH> FragmentB;
/// The type for C and D.
typedef ScalarC_ ScalarC;
/// The accumulators.
typedef Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16> Accumulators;
/// Ctor.
CUTLASS_DEVICE ThreadMultiplyAdd() {}
/// Multiply : d = a*b + c.
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
FragmentB const& b,
Accumulators const& c,
Accumulators& d) {
if(kLayout_ == MatrixLayout::kColumnMajor) {
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
d[j * AccumulatorsPerThread::kW + i] = a[i] * b[j] + c[j * AccumulatorsPerThread::kW + i];
}
}
}
else {
for(int i = 0; i < AccumulatorsPerThread::kW; ++i) {
for(int j = 0; j < AccumulatorsPerThread::kH; ++j) {
d[i * AccumulatorsPerThread::kH + j] = a[i] * b[j] + c[i * AccumulatorsPerThread::kH + j];
}
}
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,447 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defies functors for mapping blockIdx to partitions of the GEMM computation.
*/
#pragma once
#include "cutlass/coord.h"
#include "cutlass/gemm/gemm_coord.h"
namespace cutlass {
namespace gemm {
struct swizzleDirection {
enum Kind { Boustrophedon, OneDirection };
};
// helper template function
template <enum swizzleDirection::Kind>
CUTLASS_DEVICE int getLinearIdx(int groups) {
// groupCols is not needed for OneDirection Swizzle
return blockIdx.y * gridDim.x + blockIdx.x;
}
template <>
CUTLASS_DEVICE int getLinearIdx<swizzleDirection::Boustrophedon>(int groups) {
// reverse blockIdx.x for some columns
if ((blockIdx.y / groups) % 2 == 1)
return blockIdx.y * gridDim.x + (gridDim.x - blockIdx.x - 1);
else
return blockIdx.y * gridDim.x + blockIdx.x;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/*!@defgroup IdentityBlockSwizzle Identity Block Swizzle
@{
Block Swizzle provides the mapping logic between a block in the physical memory of Matrix C and
Thread Block
Identiy Block Swizzle effective maps blocks in leading dimension order (column major) with
thread block
in leading dimension order (blockIdx.x)
blockIdx.z is mapped with batch_count for batched GEMM
@}
*/
struct IdentityBlockSwizzle {
/// Ctor. aka ColumnMajorBlockSwizzle<1>
CUTLASS_HOST_DEVICE IdentityBlockSwizzle() {}
/// Swizzle the block index.
CUTLASS_DEVICE dim3 swizzle() { return blockIdx; }
///
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size,
Coord<3> const &OutputTile) {
/*OutputTile and problem_size are both in KNM order*/
dim3 grid;
grid.x = (problem_size.m() + OutputTile[2] - 1) / OutputTile[2];
grid.y = (problem_size.n() + OutputTile[1] - 1) / OutputTile[1];
grid.z = problem_size.batch();
return grid;
}
///get threadblock offset, without considering tha batch dim
CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) {
dim3 block = swizzle();
Coord<3> threadblock_offset =
make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
return threadblock_offset;
}
///
CUTLASS_DEVICE int get_batch_id() {
dim3 block = swizzle();
return block.z;
}
/// check if at the last partition
CUTLASS_DEVICE bool is_last_partition() {
if (get_batch_id() == (gridDim.z - 1))
return true;
else
return false;
}
///
CUTLASS_DEVICE Coord<3> get_threadblock_bounds(GemmCoord const &problem_size,
int partitionK_range) {
// every partition except the last one has a smaller range
// partitionK_range is the bounds for every partition except the last one
// the last partition's bounds is the same with problem size
if(is_last_partition())
return problem_size.knm();
else
return make_Coord(partitionK_range, problem_size.n(), problem_size.m());
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/*
ColumnMajorBlockSwizzle<1, OneDirection> is equivalent with IdentityBlockSwizzle
groupCols has the effect of controlling the schedulling of thread blocks
settings with different groupCols can contribute to the overall performance by affecting L2 cache
hit rate
consider a regular thread block mapping btween matrix C and different thread blocks
note that C is column major, and the leading dimension of thread block id is blockIdx.x
let's look at an example where gridIdx.x = 6, gridIdx.y = 7, gridIdx.z = 1
(blockIdx.x, blockIdx.y)
mapping between threadblockID and C matrix:
-------------------------------------------------------
(0,0) | (0,1) | (0,2) | (0,3) | (0,4) | (0,5) | (0,6) |
-------------------------------------------------------
(1,0) | (1,1) | (1,2) | (1,3) | (1,4) | (1,5) | (1,6) |
-------------------------------------------------------
(2,0) | (2,1) | (2,2) | (2,3) | (2,4) | (2,5) | (2,6) |
-------------------------------------------------------
(3,0) | (3,1) | (3,2) | (3,3) | (3,4) | (3,5) | (3,6) |
-------------------------------------------------------
(4,0) | (4,1) | (4,2) | (4,3) | (4,4) | (4,5) | (4,6) |
-------------------------------------------------------
(5,0) | (5,1) | (5,2) | (5,3) | (5,4) | (5,5) | (5,6) |
-------------------------------------------------------
A ColumnMajorBlockSwizzle<1, OneDirection> will imply the above order where threadblocks are
launched in a column major
A ColumnMajorBlockSwizzle<2, OneDirection> swizzles things a little,
-------------------------------------------------------
(0,0) | (3,0) | (0,2) | (3,2) | (0,4) | (3,4) | (0,6) |
-------------------------------------------------------
(0,1) | (3,1) | (0,3) | (3,3) | (0,5) | (3,5) | (1,6) |
-------------------------------------------------------
(1,0) | (4,0) | (1,2) | (4,2) | (1,4) | (4,4) | (2,6) |
-------------------------------------------------------
(1,1) | (4,1) | (1,3) | (4,3) | (1,5) | (4,5) | (3,6) |
-------------------------------------------------------
(2,0) | (5,0) | (2,2) | (5,2) | (2,4) | (5,4) | (4,6) |
-------------------------------------------------------
(2,1) | (5,1) | (2,3) | (5,3) | (2,5) | (5,5) | (5,6) |
-------------------------------------------------------
so in memory, it would apprear that we work on 2 columns at a time rather than 1
Note that the index here really represent how each block maps to memory
A ColumnMajorBlockSwizzle<1, Boustrophedon> is similar to ColumnMajorBlockSwizzle<1, OneDirection>
except that every column flips the ordering against the previous one
-------------------------------------------------------
(0,0) | (5,1) | (0,2) | (5,3) | (0,4) | (5,5) | (0,6) |
-------------------------------------------------------
(1,0) | (4,1) | (1,2) | (4,3) | (1,4) | (4,5) | (1,6) |
-------------------------------------------------------
(2,0) | (3,1) | (2,2) | (3,3) | (2,4) | (3,5) | (2,6) |
-------------------------------------------------------
(3,0) | (2,1) | (3,2) | (2,3) | (3,4) | (2,5) | (3,6) |
-------------------------------------------------------
(4,0) | (1,1) | (4,2) | (1,3) | (4,4) | (1,5) | (4,6) |
-------------------------------------------------------
(5,0) | (0,1) | (5,2) | (0,3) | (5,4) | (0,5) | (5,6) |
-------------------------------------------------------
similarily, A ColumnMajorBlockSwizzle<2, Boustrophedon> looks like
-------------------------------------------------------
(0,0) | (3,0) | (2,3) | (5,3) | (0,4) | (3,4) | (5,6) |
-------------------------------------------------------
(0,1) | (3,1) | (2,2) | (5,2) | (0,5) | (3,5) | (4,6) |
-------------------------------------------------------
(1,0) | (4,0) | (1,3) | (4,3) | (1,4) | (4,4) | (3,6) |
-------------------------------------------------------
(1,1) | (4,1) | (1,2) | (4,2) | (1,5) | (4,5) | (2,6) |
-------------------------------------------------------
(2,0) | (5,0) | (0,3) | (3,3) | (2,4) | (5,4) | (1,6) |
-------------------------------------------------------
(2,1) | (5,1) | (0,2) | (3,2) | (2,5) | (5,5) | (0,6) |
-------------------------------------------------------
*/
template <int groupCols, enum swizzleDirection::Kind swDirection>
struct ColumnMajorBlockSwizzle {
/// Ctor.
CUTLASS_HOST_DEVICE ColumnMajorBlockSwizzle() {}
/// Swizzle the block index.
CUTLASS_DEVICE dim3 swizzle() {
assert(gridDim.z == 1);
int linearIdx = getLinearIdx<swDirection>(groupCols);
dim3 swizzledBlockIdx;
int currGroupCols = groupCols;
int prevGroupCols = groupCols;
if ((gridDim.y % groupCols != 0) && ((blockIdx.y + (gridDim.y % groupCols)) >= gridDim.y)) {
// last colmuns if gridDim.y is not divisble by groupCols
currGroupCols = gridDim.y % groupCols;
}
swizzledBlockIdx.x = (linearIdx / currGroupCols) % gridDim.x;
swizzledBlockIdx.y =
linearIdx % currGroupCols + prevGroupCols * (linearIdx / (prevGroupCols * gridDim.x));
swizzledBlockIdx.z = blockIdx.z;
return swizzledBlockIdx;
}
///
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size,
Coord<3> const &OutputTile) {
dim3 grid;
grid.x = (problem_size.m() + OutputTile[2] - 1) / OutputTile[2];
grid.y = (problem_size.n() + OutputTile[1] - 1) / OutputTile[1];
grid.z = problem_size.batch();
return grid;
}
///
CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) {
dim3 block = swizzle();
Coord<3> threadblock_offset =
make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
return threadblock_offset;
}
///
CUTLASS_DEVICE int get_batch_id() {
dim3 block = swizzle();
return block.z;
}
/// check if at the last partition
CUTLASS_DEVICE bool is_last_partition() {
if (get_batch_id() == (gridDim.z - 1))
return true;
else
return false;
}
///
CUTLASS_DEVICE Coord<3> get_threadblock_bounds(GemmCoord const &problem_size,
int partitionK_range) {
// every partition except the last one has a smaller range
// partitionK_range is the bounds for every partition except the last one
// the last partition's bounds is the same with problem size
if (is_last_partition())
return problem_size.knm();
else
return make_Coord(partitionK_range, problem_size.n(), problem_size.m());
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/*
consider a regular thread block mapping btween matrix C and different thread blocks
note that C is column major, and the leading dimension of thread block id is blockIdx.x
let's look at an example where gridIdx.x = 6, gridIdx.y = 7, gridIdx.z = 1
(blockIdx.x, blockIdx.y)
mapping between threadblockID and C matrix:
-------------------------------------------------------
(0,0) | (0,1) | (0,2) | (0,3) | (0,4) | (0,5) | (0,6) |
-------------------------------------------------------
(1,0) | (1,1) | (1,2) | (1,3) | (1,4) | (1,5) | (1,6) |
-------------------------------------------------------
(2,0) | (2,1) | (2,2) | (2,3) | (2,4) | (2,5) | (2,6) |
-------------------------------------------------------
(3,0) | (3,1) | (3,2) | (3,3) | (3,4) | (3,5) | (3,6) |
-------------------------------------------------------
(4,0) | (4,1) | (4,2) | (4,3) | (4,4) | (4,5) | (4,6) |
-------------------------------------------------------
(5,0) | (5,1) | (5,2) | (5,3) | (5,4) | (5,5) | (5,6) |
-------------------------------------------------------
A RowMajorBlockSwizzle<1, OneDirection> will effectively transpose the map
-----------------------------------------------
(0,0) | (1,0) | (2,0) | (3,0) | (4,0) | (5,0) |
-----------------------------------------------
(0,1) | (1,1) | (2,1) | (3,1) | (4,1) | (5,1) |
-----------------------------------------------
(0,2) | (1,2) | (2,2) | (3,2) | (4,2) | (5,2) |
-----------------------------------------------
(0,3) | (1,3) | (2,3) | (3,3) | (4,3) | (5,3) |
-----------------------------------------------
(0,4) | (1,4) | (2,4) | (3,4) | (4,4) | (5,4) |
---------------------------------------------
(0,5) | (1,5) | (2,5) | (3,5) | (4,5) | (5,5) |
-----------------------------------------------
(0,6) | (1,6) | (2,6) | (3,6) | (4,6) | (5,6) |
-----------------------------------------------
It would aprear in memory we are working on 1 row at a time
A ColumnMajorBlockSwizzle<2, OneDirection> swizzles things a little bit more
-----------------------------------------------
(0,0) | (1,3) | (2,0) | (3,3) | (4,0) | (5,3) |
-----------------------------------------------
(1,0) | (0,4) | (3,0) | (2,4) | (5,0) | (4,4) |
-----------------------------------------------
(0,1) | (1,4) | (2,1) | (3,4) | (4,1) | (5,4) |
-----------------------------------------------
(1,1) | (0,5) | (3,1) | (2,5) | (5,1) | (4,5) |
-----------------------------------------------
(0,2) | (1,5) | (2,2) | (3,5) | (4,2) | (5,5) |
---------------------------------------------
(1,2) | (0,6) | (3,2) | (2,6) | (5,2) | (4,6) |
-----------------------------------------------
(0,3) | (1,6) | (2,3) | (3,6) | (4,3) | (5,6) |
-----------------------------------------------
so in memory, it would apprear that we work on 2 rows at a time rather than 1 row
Note that the index here really represent how each block maps to memory
A RowMajorBlockSwizzle<1, Boustrophedon> is similar to RowMajorBlockSwizzle<1, OneDirection>
except that every column flips the ordering against the previous one
-----------------------------------------------
(0,0) | (1,6) | (2,0) | (3,6) | (4,0) | (5,6) |
-----------------------------------------------
(0,1) | (1,5) | (2,1) | (3,5) | (4,1) | (5,5) |
-----------------------------------------------
(0,2) | (1,4) | (2,2) | (3,4) | (4,2) | (5,4) |
-----------------------------------------------
(0,3) | (1,3) | (2,3) | (3,3) | (4,3) | (5,3) |
-----------------------------------------------
(0,4) | (1,2) | (2,4) | (3,2) | (4,4) | (5,2) |
---------------------------------------------
(0,5) | (1,1) | (2,5) | (3,1) | (4,5) | (5,1) |
-----------------------------------------------
(0,6) | (1,0) | (2,6) | (3,0) | (4,6) | (5,0) |
-----------------------------------------------
similarily, A RowMajorBlockSwizzle<2, Boustrophedon> looks like
-----------------------------------------------
(0,0) | (1,3) | (2,3) | (3,6) | (4,0) | (5,3) |
-----------------------------------------------
(1,0) | (0,4) | (3,2) | (2,6) | (5,0) | (4,4) |
-----------------------------------------------
(0,1) | (1,4) | (2,2) | (3,5) | (4,1) | (5,4) |
-----------------------------------------------
(1,1) | (0,5) | (3,1) | (2,5) | (5,1) | (4,5) |
-----------------------------------------------
(0,2) | (1,5) | (2,1) | (3,4) | (4,2) | (5,5) |
---------------------------------------------
(1,2) | (0,6) | (3,0) | (2,4) | (5,2) | (4,6) |
-----------------------------------------------
(0,3) | (1,6) | (2,0) | (3,3) | (4,3) | (5,6) |
-----------------------------------------------
*/
template <int groupRows, enum swizzleDirection::Kind swDirection>
struct RowMajorBlockSwizzle {
/// Ctor.
CUTLASS_HOST_DEVICE RowMajorBlockSwizzle() {}
/// Swizzle the block index.
CUTLASS_DEVICE dim3 swizzle() {
assert(gridDim.z == 1);
int linearIdx = getLinearIdx<swDirection>(groupRows);
dim3 swizzledBlockIdx;
int currGroupRows = groupRows;
int prevGroupRows = groupRows;
if ((gridDim.y % groupRows != 0) && ((blockIdx.y + (gridDim.y % groupRows)) >= gridDim.y)) {
// last columns
currGroupRows = gridDim.y % groupRows;
}
swizzledBlockIdx.x =
linearIdx % currGroupRows + prevGroupRows * (linearIdx / (prevGroupRows * gridDim.x));
swizzledBlockIdx.y = (linearIdx / currGroupRows) % gridDim.x;
swizzledBlockIdx.z = blockIdx.z;
return swizzledBlockIdx;
}
///
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size,
Coord<3> const &OutputTile) {
dim3 grid;
grid.x = (problem_size.n() + OutputTile[1] - 1) / OutputTile[1];
grid.y = (problem_size.m() + OutputTile[2] - 1) / OutputTile[2];
grid.z = problem_size.batch();
return grid;
}
///
CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) {
dim3 block = swizzle();
Coord<3> threadblock_offset =
make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
return threadblock_offset;
}
///
CUTLASS_DEVICE int get_batch_id() {
dim3 block = swizzle();
return block.z;
}
/// check if at the last partition
CUTLASS_DEVICE bool is_last_partition() {
if (get_batch_id() == (gridDim.z - 1) )
return true;
else
return false;
}
///
CUTLASS_DEVICE Coord<3> get_threadblock_bounds(GemmCoord const &problem_size,
int partitionK_range) {
// every partition except the last one has a smaller range
// partitionK_range is the bounds for every partition except the last one
// the last partition's bounds is the same with problem size
if (is_last_partition())
return problem_size.knm();
else
return make_Coord(partitionK_range, problem_size.n(), problem_size.m());
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,167 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines structural properties of WMMA GEMM's epilogue phase.
*/
#pragma once
#include "cutlass/wmma_matrix.h"
#ifdef CUTLASS_USE_WMMA_API
#include "cutlass/convert.h"
#include "cutlass/coord.h"
#include "cutlass/gemm/gemm_global_stream.h"
#include "cutlass/gemm/gemm_shared_stream.h"
#include "cutlass/gemm/linear_scaling.h"
#include "cutlass/gemm/wmma_gemm_global_tile.h"
#include "cutlass/gemm/wmma_gemm_shared_tile.h"
#include "cutlass/reshape_tile.h"
#include "cutlass/tile_iterator.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename GemmConfig_, typename Accumulator_, typename EpilogueFunctor_, typename Index_ = int>
struct WmmaGemmEpilogueTraitsHelper {
/// The scalar.
typedef typename EpilogueFunctor_::Scalar Scalar;
/// The output tile.
typedef typename GemmConfig_::OutputTile OutputTile;
/// The number of WMMAs in the H dimension.
static int const kWmmasPerH =
GemmConfig_::AccumulatorsPerWarp::kH / GemmConfig_::InstructionShape::kH;
/// The number of iterations in the epilogue. That's the number of "horizontal" WMMAs.
typedef Shape<1, 1, kWmmasPerH> Iterations;
// The iteration strides in the H/W dimension.
typedef Shape<0, 0, 0> Delta;
/// The functor to do the math in the epilogue.
typedef EpilogueFunctor_ Functor;
/// The traits class to build the iterator to store to shared memory for D.
typedef WmmaGemmSharedStoreTileDTraits<
// The output layout.
MatrixLayout::kColumnMajor,
// The pointer is float.
typename Functor::Scalar,
// The output tile size.
typename GemmConfig_::OutputTile,
// The number of warps.
typename GemmConfig_::Warps,
// The shape of the instruction.
typename GemmConfig_::InstructionShape>
SharedStoreTileTraits;
typedef WmmaMatrix<GemmOperand::kC,
MatrixLayout::kColumnMajor,
Scalar,
typename GemmConfig_::InstructionShape>
WmmaMatrix;
/// The iterator to store D to shared memory.
typedef TileStoreIterator<SharedStoreTileTraits,
typename SharedStoreTileTraits::Scalar,
IteratorAdvance::kH,
MemorySpace::kShared,
Index_,
WmmaMatrix,
FragmentElementType::kWmmaMatrix>
SharedStoreIteratorD;
/// The shared store transformer for D.
typedef Copy<typename SharedStoreIteratorD::Fragment> SharedStoreTransformerD;
/// The traits class to build the iterator to load from shared memory for D.
typedef WmmaGemmSharedLoadTileDTraits<
// The pointer.
typename Functor::Scalar,
// The tile size.
typename SharedStoreIteratorD::Tile,
// The number of threads.
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
// The number of scalars per LDS.
GemmConfig_::kScalarsPerLdsD,
// this parameter helps with swizzling when accum is fp32 and output is fp16
sizeof(Accumulator_) / sizeof(typename GemmConfig_::ScalarD)
>
SharedLoadTileTraits;
/// The iterator to load D from shared memory.
typedef TileLoadIterator<SharedLoadTileTraits,
typename SharedLoadTileTraits::Scalar,
IteratorAdvance::kH,
MemorySpace::kShared>
SharedLoadIteratorD;
/// The stream to load D.
typedef SharedLoadStream<SharedLoadIteratorD> SharedLoadStreamD;
/// The traits class to build the iterator to load data from global memory for C^N.
typedef WmmaGemmGlobalIteratorCdTraits<
// The pointer is float const.
typename GemmConfig_::ScalarC const,
// The tile has size (N / Iterations)xM in GEMM's terminology.
Shape<1,
GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
GemmConfig_::OutputTile::kW>,
// The threads are distributed as warps x 32 (the traits may reorganize).
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
GemmConfig_::kScalarsPerLdgC>
GlobalLoadTileTraits;
/// The iterator to load C.
typedef WmmaGemmGlobalIteratorCd<GlobalLoadTileTraits, Index_> GlobalLoadIteratorC;
/// The transformer for C.
typedef Copy<typename GlobalLoadIteratorC::Fragment> GlobalTransformerC;
/// The traits class to build the iterator to store data to global memory for D^N.
typedef WmmaGemmGlobalIteratorCdTraits<
// The pointer is float.
typename GemmConfig_::ScalarD,
// The tile has size (N / Iterations)xM in GEMM's terminology.
Shape<1,
GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
GemmConfig_::OutputTile::kW>,
// The threads are distributed as warps x 32 (the traits may reorganize).
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
GemmConfig_::kScalarsPerStgD>
GlobalStoreTileTraits;
/// The iterator to store D.
typedef WmmaGemmGlobalIteratorCd<GlobalStoreTileTraits, Index_> GlobalStoreIteratorD;
/// The transformer for D.
typedef Copy<typename GlobalStoreIteratorD::Fragment> GlobalTransformerD;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass
#endif // defined CUTLASS_USE_WMMA_API

View File

@ -1,167 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines tile iterator traits for loading thread block-level tile from global memory.
*/
#pragma once
#include "cutlass/gemm/gemm_global_tile.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_, typename Tile_, typename Threads_, int kAccessSize_>
struct WmmaGemmGlobalIteratorCdTraits : public GemmGlobalTileTraits<GemmOperand::kC,
MatrixLayout::kColumnMajor,
Scalar_,
Tile_,
Threads_,
kAccessSize_> {
/// The base class.
typedef GemmGlobalTileTraits<GemmOperand::kC,
MatrixLayout::kColumnMajor,
Scalar_,
Tile_,
Threads_,
kAccessSize_>
Base;
/// Override the strides in each dimension between different loads/stores.
typedef Shape<0, 0, Base::Delta::kW, Base::Delta::kC> Delta;
/// Computes the thread offset in (H, W) based on thread ID
struct ThreadOffset {
CUTLASS_HOST_DEVICE
Coord<4> operator()() const {
int thread_offset_h = threadIdx.x / Base::Threads::kW;
int thread_offset_w = threadIdx.x % Base::Threads::kW * Base::ThreadsDelta::kW;
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
}
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename TileTraits_, typename Index_ = int>
struct WmmaGemmGlobalIteratorCd : public GemmGlobalIteratorCd<TileTraits_, Index_> {
/// This class.
typedef WmmaGemmGlobalIteratorCd<TileTraits_, Index_> This_;
/// The traits.
typedef TileTraits_ Traits;
/// The base class.
typedef GemmGlobalIteratorCd<Traits, Index_> Base;
/// Override the strides in each dimension between different loads/stores.
typedef Shape<0, 0, Base::Delta::kW, Base::Delta::kC> ImmediateOffsetStrides;
/// The layout.
static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
/// The scalar.
typedef typename TileTraits_::Scalar Scalar;
/// The pointer.
typedef typename TileTraits_::Pointer Pointer;
/// The threads.
typedef typename TileTraits_::Threads Threads;
/// The index.
typedef Index_ Index;
/// The thread offset functor.
typedef typename TileTraits_::ThreadOffset ThreadOffset;
/// Base parameters.
typedef typename Base::Params BaseParams;
/// The params.
struct Params : public BaseParams {
/// Setup the params.
CUTLASS_HOST_DEVICE int initialize(Pointer pointer,
long long batch_stride,
Index ldm,
Index n,
Index epilogue_stride_w,
Index epilogue_delta_w) {
// The pointer.
this->pointer = pointer;
// Stride between GEMMs
this->stride_d = batch_stride;
// Setup the base stride. One "group of threads" per column.
this->stride_h = ldm;
// Each thread output 1 column per iteration. .
this->inc_h = ldm * TileTraits_::Threads::kH;
this->inc_advance = this->inc_h + epilogue_stride_w;
this->predicate_offset = n;
this->predicate_inc_h = TileTraits_::Threads::kH;
this->predicate_inc_advance = this->predicate_inc_h + epilogue_delta_w;
return 0;
}
};
/// Ctor.
CUTLASS_DEVICE WmmaGemmGlobalIteratorCd(Params const& params,
const Coord<3>& bounds,
const Coord<3>& block,
int const pointer_offset = 0,
int const pred_offset = 0,
ThreadOffset thread_offset_func = ThreadOffset())
: Base(params, bounds, block, pointer_offset, pred_offset, thread_offset_func) {}
/// Loads a single fragment element from memory
CUTLASS_DEVICE void load_element(
typename Base::AccessType& value, int d, int h, int w, int c) const {
Base::load_element(value, d, h, w, c);
}
/// Stores a single fragment element into memory
CUTLASS_DEVICE void store_element(
typename Base::AccessType const& value, int d, int h, int w, int c) {
int const offset =
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, 0);
Store<Scalar,
Base::kAccessSize,
Base::kMemorySpace,
Base::kFragmentElementType,
typename Base::FragmentElement,
Base::Tile::kW>::store(value, Base::params.pointer, offset);
}
public:
template <typename Fragment>
CUTLASS_DEVICE void load_post_increment(Fragment& fragment) {
Base::load_post_increment(fragment);
}
template <typename Fragment>
CUTLASS_DEVICE void store_post_increment(Fragment& fragment) {
Base::store_post_increment(fragment);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass

View File

@ -1,355 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements warp-level matrix multiply-accumulate operation using CUDA WMMA API.
*/
#pragma once
#include "cutlass/wmma_matrix.h"
#ifdef CUTLASS_USE_WMMA_API
#include "cutlass/fragment.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <MatrixLayout::Kind kLayoutA_,
typename ScalarA_,
MatrixLayout::Kind kLayoutB_,
typename ScalarB_,
MatrixLayout::Kind kLayoutC_,
typename ScalarC_,
typename WarpGemmShape_,
typename InstructionShape_>
struct WmmaGemmMultiplyAdd {
/// The shape of the instruction.
typedef InstructionShape_ InstructionShape;
/// The number of threads per warp. That's a dummy configuration.
typedef Shape<1, InstructionShape_::kH, InstructionShape_::kW> ThreadsPerWarp;
/// Dimensions of the warp-level GEMM (K-by-N-by-M)
typedef WarpGemmShape_ WarpGemmShape;
/// Aliased for compatibility. Will be removed in CUTLASS v2.0
typedef WarpGemmShape_ AccumulatorsPerWarp;
/// The type for A.
typedef ScalarA_ ScalarA;
/// The type for B.
typedef ScalarB_ ScalarB;
/// The type for C and D.
typedef ScalarC_ ScalarC;
/// The number of iterations.
typedef typename ShapeDiv<AccumulatorsPerWarp, InstructionShape>::Shape Iterations;
/// The element for A.
typedef WmmaMatrix<GemmOperand::kA, kLayoutA_, ScalarA, InstructionShape> ElementA;
/// The fragment for A.
typedef Fragment<ElementA, Iterations::kW> FragmentA;
/// The element for B.
typedef WmmaMatrix<GemmOperand::kB, kLayoutB_, ScalarB, InstructionShape> ElementB;
/// The fragment for B.
typedef Fragment<ElementB, Iterations::kH> FragmentB;
/// The element for C.
typedef WmmaMatrix<GemmOperand::kC, kLayoutC_, ScalarC, InstructionShape> ElementC;
/// The fragment for C.
typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;
/// Ctor.
CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}
/// Multiply : d = a*b.
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
FragmentB const& b,
Accumulators const& c,
Accumulators& d) {
for (int j = 0; j < Iterations::kH; ++j) {
for (int i = 0; i < Iterations::kW; ++i) {
// The input elements.
ElementA const& elt_a = a[i];
ElementB const& elt_b = b[j];
ElementC const& elt_c = c[j * Iterations::kW + i];
// The output element.
ElementC& elt_d = d[j * Iterations::kW + i];
// The wmma instruction.
nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
}
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef CUTLASS_USE_SUBBYTE_WMMA
/// Specialization for WMMA GEMM with binary operands
template<typename WarpGemmShape_>
struct WmmaGemmMultiplyAdd <MatrixLayout::kRowMajor,
Vector<bin1_t, 32>,
MatrixLayout::kColumnMajor,
Vector<bin1_t, 32>,
MatrixLayout::kColumnMajor,
int,
WarpGemmShape_,
Shape<128, 8, 8> >{
/// The shape of the instruction.
typedef Shape<128, 8, 8> InstructionShape;
/// The number of threads per warp. That's a dummy configuration.
typedef Shape<1, 4, 8> ThreadsPerWarp;
/// Dimensions of the warp-level GEMM (K-by-N-by-M)
typedef WarpGemmShape_ WarpGemmShape;
/// Aliased for compatibility. Will be removed in CUTLASS v2.0
typedef WarpGemmShape_ AccumulatorsPerWarp;
/// The type for A.
typedef Vector<bin1_t, 32> ScalarA;
/// The type for B.
typedef Vector<bin1_t, 32> ScalarB;
/// The type for C and D.
typedef int ScalarC;
/// The number of iterations.
typedef typename ShapeDiv<AccumulatorsPerWarp, InstructionShape>::Shape Iterations;
/// The element for A.
typedef WmmaMatrix<GemmOperand::kA,
MatrixLayout::kRowMajor,
Vector<bin1_t, 32>,
InstructionShape> ElementA;
/// The fragment for A.
typedef Fragment<ElementA, Iterations::kW> FragmentA;
/// The element for B.
typedef WmmaMatrix<GemmOperand::kB,
MatrixLayout::kColumnMajor,
Vector<bin1_t, 32>,
InstructionShape> ElementB;
/// The fragment for B.
typedef Fragment<ElementB, Iterations::kH> FragmentB;
/// The element for C.
typedef WmmaMatrix<GemmOperand::kC,
MatrixLayout::kColumnMajor,
int,
InstructionShape> ElementC;
/// The fragment for C.
typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;
/// Ctor.
CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}
/// Multiply : d = a*b.
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
FragmentB const& b,
Accumulators const& c,
Accumulators& d) {
for (int j = 0; j < Iterations::kH; ++j) {
for (int i = 0; i < Iterations::kW; ++i) {
// The input elements.
ElementA const& elt_a = a[i];
ElementB const& elt_b = b[j];
ElementC const& elt_c = c[j * Iterations::kW + i];
// The output element.
ElementC& elt_d = d[j * Iterations::kW + i];
// The wmma instruction.
nvcuda::wmma::bmma_sync(elt_d,
elt_a,
elt_b,
elt_c,
nvcuda::wmma::experimental::bmmaBitOpXOR,
nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
}
}
}
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef CUTLASS_USE_SUBBYTE_WMMA
/// Specialization for WMMA GEMM with signed 4-bit integer operands
template<typename WarpGemmShape_>
struct WmmaGemmMultiplyAdd <MatrixLayout::kRowMajor,
Vector<int4_t, 8>,
MatrixLayout::kColumnMajor,
Vector<int4_t, 8>,
MatrixLayout::kColumnMajor,
int,
WarpGemmShape_,
Shape<32, 8, 8> >{
/// The shape of the instruction.
typedef Shape<32, 8, 8> InstructionShape;
/// The number of threads per warp. That's a dummy configuration.
typedef Shape<1, 4, 8> ThreadsPerWarp;
/// Dimensions of the warp-level GEMM (K-by-N-by-M)
typedef WarpGemmShape_ WarpGemmShape;
/// Aliased for compatibility. Will be removed in CUTLASS v2.0
typedef WarpGemmShape_ AccumulatorsPerWarp;
/// The type for A.
typedef Vector<int4_t, 8> ScalarA;
/// The type for B.
typedef Vector<int4_t, 8> ScalarB;
/// The type for C and D.
typedef int ScalarC;
/// The number of iterations.
typedef typename ShapeDiv<AccumulatorsPerWarp, InstructionShape>::Shape Iterations;
/// The element for A.
typedef WmmaMatrix<GemmOperand::kA,
MatrixLayout::kRowMajor,
Vector<int4_t, 8>,
InstructionShape> ElementA;
/// The fragment for A.
typedef Fragment<ElementA, Iterations::kW> FragmentA;
/// The element for B.
typedef WmmaMatrix<GemmOperand::kB,
MatrixLayout::kColumnMajor,
Vector<int4_t, 8>,
InstructionShape> ElementB;
/// The fragment for B.
typedef Fragment<ElementB, Iterations::kH> FragmentB;
/// The element for C.
typedef WmmaMatrix<GemmOperand::kC,
MatrixLayout::kColumnMajor,
int,
InstructionShape> ElementC;
/// The fragment for C.
typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;
/// Ctor.
CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}
/// Multiply : d = a*b.
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
FragmentB const& b,
Accumulators const& c,
Accumulators& d) {
for (int j = 0; j < Iterations::kH; ++j) {
for (int i = 0; i < Iterations::kW; ++i) {
// The input elements.
ElementA const& elt_a = a[i];
ElementB const& elt_b = b[j];
ElementC const& elt_c = c[j * Iterations::kW + i];
// The output element.
ElementC& elt_d = d[j * Iterations::kW + i];
// The wmma instruction.
nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
}
}
}
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef CUTLASS_USE_SUBBYTE_WMMA
/// Specialization for WMMA GEMM with unsigned 4-bit integer operands
template<typename WarpGemmShape_>
struct WmmaGemmMultiplyAdd <MatrixLayout::kRowMajor,
Vector<uint4_t, 8>,
MatrixLayout::kColumnMajor,
Vector<uint4_t, 8>,
MatrixLayout::kColumnMajor,
int,
WarpGemmShape_,
Shape<32, 8, 8> >{
/// The shape of the instruction.
typedef Shape<32, 8, 8> InstructionShape;
/// The number of threads per warp. That's a dummy configuration.
typedef Shape<1, 4, 8> ThreadsPerWarp;
/// Dimensions of the warp-level GEMM (K-by-N-by-M)
typedef WarpGemmShape_ WarpGemmShape;
/// Aliased for compatibility. Will be removed in CUTLASS v2.0
typedef WarpGemmShape_ AccumulatorsPerWarp;
/// The type for A.
typedef Vector<uint4_t, 8> ScalarA;
/// The type for B.
typedef Vector<uint4_t, 8> ScalarB;
/// The type for C and D.
typedef int ScalarC;
/// The number of iterations.
typedef typename ShapeDiv<AccumulatorsPerWarp, InstructionShape>::Shape Iterations;
/// The element for A.
typedef WmmaMatrix<GemmOperand::kA,
MatrixLayout::kRowMajor,
Vector<uint4_t, 8>,
InstructionShape> ElementA;
/// The fragment for A.
typedef Fragment<ElementA, Iterations::kW> FragmentA;
/// The element for B.
typedef WmmaMatrix<GemmOperand::kB,
MatrixLayout::kColumnMajor,
Vector<uint4_t, 8>,
InstructionShape> ElementB;
/// The fragment for B.
typedef Fragment<ElementB, Iterations::kH> FragmentB;
/// The element for C.
typedef WmmaMatrix<GemmOperand::kC,
MatrixLayout::kColumnMajor,
int,
InstructionShape> ElementC;
/// The fragment for C.
typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;
/// Ctor.
CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}
/// Multiply : d = a*b.
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
FragmentB const& b,
Accumulators const& c,
Accumulators& d) {
for (int j = 0; j < Iterations::kH; ++j) {
for (int i = 0; i < Iterations::kW; ++i) {
// The input elements.
ElementA const& elt_a = a[i];
ElementB const& elt_b = b[j];
ElementC const& elt_c = c[j * Iterations::kW + i];
// The output element.
ElementC& elt_d = d[j * Iterations::kW + i];
// The wmma instruction.
nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
}
}
}
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass
#endif // defined CUTLASS_USE_WMMA_API

View File

@ -1,239 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines iterator traits for efficiently loading and storing fragment to and from shared
memory, specialized for WMMA GEMM.
*/
#pragma once
#include "cutlass/wmma_matrix.h"
#ifdef CUTLASS_USE_WMMA_API
#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/reshape_tile.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <MatrixLayout::Kind kLayout_,
typename Scalar_,
typename Tile_,
typename Warps_,
int kWarpStride_,
typename Iterations_,
typename Delta_,
typename WmmaShape_>
struct WmmaGemmSharedLoadTileATraits {
/// The operand.
static GemmOperand::Kind const kOperand = GemmOperand::kA;
/// The layout.
static MatrixLayout::Kind const kLayout = kLayout_;
/// The scalar.
typedef Scalar_ Scalar;
/// The pointer.
typedef Scalar const* Pointer;
/// The access size
static int const kAccessSize = 1;
/// The tile with skew.
typedef Tile_ Tile;
/// The number of warps.
typedef Warps_ Warps;
/// The warps strides.
static int const kWarpStride = kWarpStride_;
/// The number of iterations.
typedef Iterations_ Iterations;
/// The strides between iterations.
typedef Delta_ Delta;
/// The strides between iterations.
typedef Delta_ ImmediateOffsetStrides;
/// The shape of the WMMA instruction.
typedef WmmaShape_ WmmaShape;
/// The memory space.
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
/// ThreadOffset
struct ThreadOffset {
CUTLASS_HOST_DEVICE
Coord<4> operator()() const {
// The warp id.
int const warp = threadIdx.x / kWarpSize;
// The offset.
int const offset = warp % Warps::kW * kWarpStride;
return make_Coord(0, 0, offset, 0);
}
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <MatrixLayout::Kind kLayout_,
typename Scalar_,
typename Tile_,
typename Warps_,
int kWarpStride_,
typename Iterations_,
typename Delta_,
typename WmmaShape_>
struct WmmaGemmSharedLoadTileBTraits {
/// The operand.
static GemmOperand::Kind const kOperand = GemmOperand::kB;
/// The layout.
static MatrixLayout::Kind const kLayout = kLayout_;
/// The scalar.
typedef Scalar_ Scalar;
/// The pointer.
typedef Scalar const* Pointer;
/// The access size
static int const kAccessSize = 1;
/// The tile with skew.
typedef Tile_ Tile;
/// The number of warps.
typedef Warps_ Warps;
/// The warps strides.
static int const kWarpStride = kWarpStride_;
/// The number of iterations.
typedef Iterations_ Iterations;
/// The strides between iterations.
typedef Delta_ Delta;
/// The strides between iterations.
typedef Delta_ ImmediateOffsetStrides;
/// The shape of the WMMA instruction.
typedef WmmaShape_ WmmaShape;
/// The memory space.
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
/// ThreadOffset
struct ThreadOffset {
CUTLASS_HOST_DEVICE
Coord<4> operator()() const {
// The warp id.
int const warp = threadIdx.x / kWarpSize;
// The offset.
int const offset = warp / Warps::kW * kWarpStride;
return make_Coord(0, 0, offset, 0);
}
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <MatrixLayout::Kind kLayout_,
typename Scalar_,
typename OutputTile_,
typename Warps_,
typename WmmaShape_,
int kSkew_ = 0>
struct WmmaGemmSharedStoreTileDTraits {
/// The operand.
static GemmOperand::Kind const kOperand = GemmOperand::kC;
/// The layout.
static MatrixLayout::Kind const kLayout = kLayout_;
/// The scalar.
typedef Scalar_ Scalar;
// The access size
static int const kAccessSize = 1;
/// The pointer.
typedef Scalar* Pointer;
/// The number of warps.
typedef Warps_ Warps;
/// The shape of the WMMA instruction.
typedef WmmaShape_ WmmaShape;
/// The skew.
static int const kSkew = kSkew_;
/// The memory space.
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
/// The tile with skew.
typedef Shape<1, Warps_::kH * WmmaShape_::kH, OutputTile_::kW + kSkew_> Tile;
/// The number of iterations needed to store the tile.
typedef Shape<1, 1, OutputTile_::kW / Warps::kW / WmmaShape_::kW> Iterations;
/// The strides in each dimension between different loads/stores.
typedef Shape<0, 0, Warps::kW * WmmaShape_::kW, 0> Delta;
/// The strides in each dimension between different loads/stores.
typedef Shape<0, 0, Warps::kW * WmmaShape_::kW, 0> ImmediateOffsetStrides;
/// ThreadOffset
struct ThreadOffset {
CUTLASS_HOST_DEVICE
Coord<4> operator()() const {
// The warp id.
int const warp = threadIdx.x / kWarpSize;
// The starting column.
int const h = warp / Warps::kW * WmmaShape::kH;
// The w.
int const w = warp % Warps::kW * WmmaShape::kW;
// The offset.
int const offset = h * Tile::kW + w;
return make_Coord(0, 0, offset, 0);
}
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerLds_, int kLdsPerAccess_ = 1>
struct WmmaGemmSharedLoadTileDTraits {
/// The scalar.
typedef Scalar_ Scalar;
/// The pointer.
typedef Scalar const* Pointer;
/// The access size
static int const kAccessSize = kScalarsPerLds_;
/// The tile.
typedef typename WmmaReshapeTile<Tile_, kScalarsPerLds_, kLdsPerAccess_>::Tile Tile;
/// The threads.
typedef typename ReshapeThreads<Tile, Threads_>::Threads Threads;
/// The threads strides.
typedef Shape<1, Tile::kW * Tile::kC, Tile::kC> ThreadsStrides;
/// The memory space.
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
/// The strides in each dimension between different loads/stores.
typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_> Delta;
/// The strides in each dimension between different loads/stores.
typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_, kScalarsPerLds_>
ImmediateOffsetStrides;
/// The number of iterations needed to load/store the tile.
typedef Shape<1, Tile::kH / Threads::kH, Tile::kW / Threads::kW, Tile::kC / kScalarsPerLds_>
Iterations;
/// ThreadOffset
struct ThreadOffset {
CUTLASS_HOST_DEVICE
Coord<4> operator()() const {
// The offset.
int const offset = ComputeThreadOffsetFromStrides<Threads, ThreadsStrides>::get();
return make_Coord(0, 0, offset, 0);
}
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass
#endif // defined CUTLASS_USE_WMMA_API

File diff suppressed because it is too large Load Diff

View File

@ -1,101 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Free functions for loading and storing to implementations of tile iteartor concepts.
*/
#pragma once
#include "cutlass/load_store.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/shape.h"
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
// Used by convolution
template <typename InputIterator, typename Fragment>
CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment) {
typename InputIterator::FragmentIterator frag_iterator(fragment);
for (int d = 0; d < InputIterator::Iterations::kD; ++d) {
for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
if (iterator.valid(d, h, w, c)) {
iterator.load_element(reinterpret_cast<typename InputIterator::AccessType &>(
frag_iterator.at(d, h, w, c)),
d,
h,
w,
c);
}
}
if (w < InputIterator::Iterations::kW - 1) {
iterator.inc_w();
}
}
if (h < InputIterator::Iterations::kH - 1) {
iterator.inc_h();
}
}
if (d < InputIterator::Iterations::kD - 1) {
iterator.inc_d();
}
}
iterator.inc_advance();
}
template <typename OutputIterator, typename Fragment>
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment) {
typename OutputIterator::FragmentIterator frag_iterator(fragment);
for (int d = 0; d < OutputIterator::Iterations::kD; ++d) {
for (int h = 0; h < OutputIterator::Iterations::kH; ++h) {
for (int w = 0; w < OutputIterator::Iterations::kW; ++w) {
for (int c = 0; c < OutputIterator::Iterations::kC; ++c) {
if (iterator.valid(d, h, w, c)) {
iterator.store_element(reinterpret_cast<typename OutputIterator::AccessType &>(
frag_iterator.at(d, h, w, c)),
d,
h,
w,
c);
}
}
if (w < OutputIterator::Iterations::kW - 1) {
iterator.inc_w();
}
}
if (h < OutputIterator::Iterations::kH - 1) {
iterator.inc_h();
}
}
if (d < OutputIterator::Iterations::kD - 1) {
iterator.inc_d();
}
}
iterator.inc_advance();
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,379 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines abstractions for efficiently loading and storing vectors to memory.
*/
#pragma once
#include "cutlass/vector.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
/**
* @brief Enum to specify which memory space data resides in.
*/
struct MemorySpace {
enum Kind {
kGeneric, // Data accessed through pointer dereferencing
kShared, // Data resides in shared memory
kGlobal // Data resides in global memory
};
};
/// Specifies whether iterator storage fragment consists of Scalar values or WMMA matrix
struct FragmentElementType {
enum Kind { kScalar, kWmmaMatrix };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_,
int kAccessSize,
MemorySpace::Kind Memory_,
FragmentElementType::Kind kFragmentElementType = FragmentElementType::kScalar,
typename FragmentElement_ = Scalar_,
int kStride = 1,
size_t size = (sizeof(Scalar_) * kAccessSize)>
struct Load {
/// The output type.
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
/// The load function.
static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
dst = *reinterpret_cast<AccessType const*>(pointer + offset);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for 16b loads
template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_>
struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, 1, 2> {
/// The output type.
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
/// The load function.
static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
reinterpret_cast<uint16_t&>(dst) = reinterpret_cast<uint16_t const*>(&pointer[offset])[0];
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 4> {
/// The output type.
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
/// The load function.
static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
dst.registers[0] = reinterpret_cast<uint32_t const*>(&pointer[offset])[0];
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 8> {
/// The output type.
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
/// The load function.
static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
uint2 tmp = reinterpret_cast<uint2 const*>(&pointer[offset])[0];
dst.registers[0] = tmp.x;
dst.registers[1] = tmp.y;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <MemorySpace::Kind Memory_, int kStride>
struct Load<double, 2, Memory_, FragmentElementType::kScalar, double, kStride, 16> {
/// The output type.
typedef typename Vectorize<double, 2>::Type AccessType;
/// The load function.
static CUTLASS_HOST_DEVICE void load(AccessType& dst, double const* pointer, int offset) {
double2 tmp = reinterpret_cast<double2 const*>(&pointer[offset])[0];
dst[0] = tmp.x;
dst[1] = tmp.y;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(__CUDACC_VERSION_MAJOR) && __CUDACC_VERSION_MAJOR < 10
// WAR bug in NVCC where the upper and lower half of the register end up being the same
template <MemorySpace::Kind Memory_, int kStride>
struct Load<half, 8, Memory_, FragmentElementType::kScalar, half, kStride, 16> {
/// The output type.
typedef typename Vectorize<half, 8>::Type AccessType;
/// The load function.
static CUTLASS_HOST_DEVICE void load(AccessType& dst, half const* pointer, int offset) {
int2 tmp = reinterpret_cast<int2 const*>(&pointer[offset])[0];
dst.registers[0] = tmp.x;
dst.registers[1] = tmp.y;
tmp = reinterpret_cast<int2 const*>(&pointer[offset + 4])[0];
dst.registers[2] = tmp.x;
dst.registers[3] = tmp.y;
}
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 16> {
/// The output type.
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
/// The load function.
static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
uint4 tmp = reinterpret_cast<uint4 const*>(&pointer[offset])[0];
dst.registers[0] = tmp.x;
dst.registers[1] = tmp.y;
dst.registers[2] = tmp.z;
dst.registers[3] = tmp.w;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_,
int kAccessSize,
MemorySpace::Kind Memory_,
FragmentElementType::Kind kFragmentElementType = FragmentElementType::kScalar,
typename FragmentElement_ = Scalar_,
int kStride = 1,
size_t size = (sizeof(Scalar_) * kAccessSize)>
struct Store {
/// The output type.
typedef typename Vectorize<FragmentElement_, kAccessSize>::Type AccessType;
/// The store function.
static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
pointer[offset] = *reinterpret_cast<Scalar_ const*>(&src);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_>
struct Store<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, 1, 2> {
/// The output type.
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
/// The store function.
static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
uint16_t* addr = reinterpret_cast<uint16_t*>(&pointer[offset]);
addr[0] = reinterpret_cast<uint16_t const&>(src);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
struct Store<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 4> {
/// The output type.
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
/// The store function.
static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
uint32_t* addr = reinterpret_cast<uint32_t*>(&pointer[offset]);
addr[0] = src.registers[0];
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
struct Store<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 8> {
/// The output type.
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
/// The store function.
static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
uint2* addr = reinterpret_cast<uint2*>(&pointer[offset]);
addr[0] = make_uint2(src.registers[0], src.registers[1]);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <MemorySpace::Kind Memory_, int kStride>
struct Store<double, 2, Memory_, FragmentElementType::kScalar, double, kStride, 16> {
/// The output type.
typedef typename Vectorize<double, 2>::Type AccessType;
/// The store function.
static CUTLASS_HOST_DEVICE void store(AccessType const& src, double* pointer, int offset) {
double2* addr = reinterpret_cast<double2*>(&pointer[offset]);
addr[0] = make_double2(src[0], src[1]);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
struct Store<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 16> {
/// The output type.
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
/// The store function.
static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
uint4* addr = reinterpret_cast<uint4*>(&pointer[offset]);
addr[0] = make_uint4(src.registers[0], src.registers[1], src.registers[2], src.registers[3]);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_,
int kAccessSize,
MemorySpace::Kind Memory_,
typename FragmentElement_,
int kStride,
size_t size>
struct Load<Scalar_,
kAccessSize,
Memory_,
FragmentElementType::kWmmaMatrix,
FragmentElement_,
kStride,
size> {
/// The output type.
typedef FragmentElement_ AccessType;
/// The load function.
static CUTLASS_HOST_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) {
value.load(&pointer[offset], kStride);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int kAccessSize,
MemorySpace::Kind Memory_,
typename FragmentElement_,
int kStride,
size_t size>
struct Load<Vector<bin1_t, 32>,
kAccessSize,
Memory_,
FragmentElementType::kWmmaMatrix,
FragmentElement_,
kStride,
size> {
/// The output type.
typedef FragmentElement_ AccessType;
/// The load function.
static CUTLASS_HOST_DEVICE void load(AccessType& value, Vector<bin1_t, 32> const* pointer,
int offset) {
value.load(&pointer[offset], kStride * 32);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int kAccessSize,
MemorySpace::Kind Memory_,
typename FragmentElement_,
int kStride,
size_t size>
struct Load<Vector<int4_t, 8>,
kAccessSize,
Memory_,
FragmentElementType::kWmmaMatrix,
FragmentElement_,
kStride,
size> {
/// The output type.
typedef FragmentElement_ AccessType;
/// The load function.
static CUTLASS_HOST_DEVICE void load(AccessType& value, Vector<int4_t, 8> const* pointer,
int offset) {
value.load(&pointer[offset], kStride * 8);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int kAccessSize,
MemorySpace::Kind Memory_,
typename FragmentElement_,
int kStride,
size_t size>
struct Load<Vector<uint4_t, 8>,
kAccessSize,
Memory_,
FragmentElementType::kWmmaMatrix,
FragmentElement_,
kStride,
size> {
/// The output type.
typedef FragmentElement_ AccessType;
/// The load function.
static CUTLASS_HOST_DEVICE void load(AccessType& value, Vector<uint4_t, 8> const* pointer,
int offset) {
value.load(&pointer[offset], kStride * 8);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_,
int kAccessSize,
MemorySpace::Kind Memory_,
typename FragmentElement_,
int kStride,
size_t size>
struct Store<Scalar_,
kAccessSize,
Memory_,
FragmentElementType::kWmmaMatrix,
FragmentElement_,
kStride,
size> {
/// The input type.
typedef FragmentElement_ AccessType;
/// The store function.
static CUTLASS_HOST_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) {
value.store(&pointer[offset], kStride);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,372 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines properties of matrices used to denote layout and operands to GEMM kernels.
*/
#pragma once
#include "cutlass/coord.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// MatrixCoord wraps Coord<2, int> to provide a helper for accessing named dimensions. Classes
/// expecting a coordinate in the rank=2 index space of a matrix should use MatrixCoord.
struct MatrixCoord : public Coord<2, int> {
/// Integer-valued index
typedef int Index;
/// Base type is a Coord of rank=2
typedef Coord<2, Index> Base;
/// Rows dimension
static int const kRow = 0;
/// Columns dimension
static int const kColumn = 1;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
MatrixCoord() { }
/// Constructs from Coord<2>
CUTLASS_HOST_DEVICE
MatrixCoord(Coord<2, Index> const &coord): Base(coord) { }
/// Helper to construct from a row and column
CUTLASS_HOST_DEVICE
MatrixCoord(Index row, Index column): Base(make_Coord(row, column)) { }
/// Returns the row of the coordinate
CUTLASS_HOST_DEVICE
Index const & row() const { return this->at(kRow); }
/// Returns the row of the coordinate
CUTLASS_HOST_DEVICE
Index & row() { return this->at(kRow); }
/// Returns the column of the coordinate
CUTLASS_HOST_DEVICE
Index const & column() const { return this->at(kColumn); }
/// Returns the column of the coordinate
CUTLASS_HOST_DEVICE
Index & column() { return this->at(kColumn); }
//
// Coord operators
//
/// Element-wise addition
CUTLASS_HOST_DEVICE
MatrixCoord operator+(Base const& b) const {
return MatrixCoord(Base::operator+(b));
}
/// Element-wise subtraction
CUTLASS_HOST_DEVICE
MatrixCoord operator-(Base const& b) const {
return MatrixCoord(Base::operator-(b));
}
/// Element-wise multiplication
CUTLASS_HOST_DEVICE
MatrixCoord operator*(Base const& b) const {
return MatrixCoord(Base::operator*(b));
}
/// Element-wise division
CUTLASS_HOST_DEVICE
MatrixCoord operator/(Base const& b) const {
return MatrixCoord(Base::operator/(b));
}
/// In-place addition
CUTLASS_HOST_DEVICE
MatrixCoord& operator+=(Base const& b) {
Base::operator+=(b);
return *this;
}
/// In-place subtraction
CUTLASS_HOST_DEVICE
MatrixCoord& operator-=(Base const& b) {
Base::operator-=(b);
return *this;
}
/// In-place multiplication
CUTLASS_HOST_DEVICE
MatrixCoord& operator*=(Base const& b) {
Base::operator*=(b);
return *this;
}
/// In-place division
CUTLASS_HOST_DEVICE
MatrixCoord& operator/=(Base const& b) {
Base::operator/=(b);
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines data layouts of various matrix formats usable by TensorRef and other classes.
//
// The following define classes satisfying the TensorRefMapFunc concept. These must support the
// following operations, where func is an instance of type TensorRefMapFunc.
//
// Coord<TensorRefMapFunc::kStorageRank> = func(Coord<kRank>);
//
// Though not required to be usable by TensorRef, each of the following also define a helper
// function to map the "leading dimension" to an appropriate stride vector. Implementations
// following this convention should also implement the following static method:
//
// Coord<TensorRefMapFunc::kStorageRank> stride = TensorRefMapFunc::stride(leading_dim);
//
namespace MatrixLayout {
/// Enumeration defining fundamental contiguous layouts.
enum Kind { kRowMajor, kColumnMajor };
//
// TensorRefMapFunc definitions for common layouts
//
/// Mapping function for row-major matrices
struct RowMajor {
static int const kStorageRank = 2;
/// Maps (i, j) to (i, j)
CUTLASS_HOST_DEVICE
Coord<kStorageRank> operator()(MatrixCoord const &coord) const {
return coord;
}
};
/// Mapping function for column-major matrices
struct ColumnMajor {
static int const kStorageRank = 2;
/// Maps (i, j) to (j, i)
CUTLASS_HOST_DEVICE
Coord<kStorageRank> operator()(MatrixCoord const &coord) const {
return make_Coord(coord.column(), coord.row());
}
};
/// Mapping function for interleaved matrices. Matrix is structured
/// as row-major arrangement of fixed-size columns.
template <int Interleave>
struct RowMajorInterleaved {
/// Rank of storage n-D array
static int const kStorageRank = 3;
/// Interleaving size
static int const kInterleave = Interleave;
/// Maps (row, col) to (row, col, row)
CUTLASS_HOST_DEVICE
Coord<kStorageRank> operator()(MatrixCoord const &coord) const {
return make_Coord(
coord.row() / kInterleave,
coord.column(),
coord.row() % kInterleave
);
}
/// Helper to compute stride vector from leading dimension
CUTLASS_HOST_DEVICE
static Coord<kStorageRank> stride(int ldm) {
return make_Coord(
ldm * kInterleave,
kInterleave,
1
);
}
};
/// Mapping function for interleaved matrices. Matrix is structured
/// as column-major arrangement of fixed-size rows.
template <int Interleave>
struct ColumnMajorInterleaved {
/// Rank of storage n-D array
static int const kStorageRank = 3;
/// Interleaving size
static int const kInterleave = Interleave;
/// Maps (row, col) to (col, row, col)
CUTLASS_HOST_DEVICE
Coord<kStorageRank> operator()(MatrixCoord const &coord) const {
return make_Coord(
coord.column() / kInterleave,
coord.row(),
coord.column() % kInterleave
);
}
/// Helper to compute stride vector from leading dimension
CUTLASS_HOST_DEVICE
static Coord<kStorageRank> stride(int ldm) {
return make_Coord(
ldm * kInterleave,
kInterleave,
1
);
}
};
/// Mapping function for scenario in which layout is row-major or column-major but this information
/// is only available at runtime.
struct ContiguousLayout {
/// Arbitrary storage rank
static int const kStorageRank = 3;
/// Dimension of rows
static int const kRow = 0;
/// Dimension of columns
static int const kColumn = 1;
/// Mapping function defined by runtime variable. Returns coordinates in n-D storage array
/// as (matrix row, matrix colum, 0)
CUTLASS_HOST_DEVICE
Coord<kStorageRank> operator()(MatrixCoord const &coord) const {
return make_Coord(coord.row(), coord.column(), 0);
}
/// Helper to construct a stride vector based on contiguous matrix layout and leading dimension
CUTLASS_HOST_DEVICE
static Coord<kStorageRank> stride(MatrixLayout::Kind layout, int ldm) {
if (layout == MatrixLayout::kRowMajor) {
return make_Coord(ldm, 1, 1);
}
return make_Coord(1, ldm, 1);
}
};
/// Mapping function for block-linear matrices. Matrix is structured
/// as column-major arrangement of 2D tiles (that are column-major).
template <int BlockRows, int BlockColumns>
struct ColumnMajorBlockLinear {
/// Rank of storage n-D array
static int const kStorageRank = 4;
/// Interleaving size in rows dimension
static int const kBlockRows = BlockRows;
/// Interleaving size in columns dimension
static int const kBlockColumns = BlockColumns;
/// Maps (row, col) to (col, row, col, row)
CUTLASS_HOST_DEVICE
Coord<kStorageRank> operator()(MatrixCoord const &coord) const {
return make_Coord(
coord.column() / kBlockColumns,
coord.row() / kBlockRows,
coord.column() % kBlockColumns,
coord.row() % kBlockRows
);
}
/// Helper to compute stride vector from leading dimension
CUTLASS_HOST_DEVICE
static Coord<kStorageRank> stride(int ldm) {
return make_Coord(
ldm * kBlockRows * kBlockColumns,
kBlockRows * kBlockColumns,
kBlockRows,
1
);
}
};
/// Mapping function for block-linear matrices. Matrix is structured
/// as row-major arrangement of 2D tiles (that are row-major)
template <int BlockRows, int BlockColumns>
struct RowMajorBlockLinear {
/// Rank of storage n-D array
static int const kStorageRank = 4;
/// Interleaving size in rows dimension
static int const kBlockRows = BlockRows;
/// Interleaving size in columns dimension
static int const kBlockColumns = BlockColumns;
/// Maps (row, col) to (row, col, row, col)
CUTLASS_HOST_DEVICE
Coord<kStorageRank> operator()(MatrixCoord const &coord) const {
return make_Coord(
coord.row() / kBlockRows,
coord.column() / kBlockColumns,
coord.row() % kBlockRows,
coord.column() % kBlockColumns
);
}
/// Helper to compute stride vector from leading dimension
CUTLASS_HOST_DEVICE
static Coord<kStorageRank> stride(int ldm) {
return make_Coord(
ldm * kBlockRows * kBlockColumns,
kBlockRows * kBlockColumns,
kBlockColumns,
1
);
}
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Gemm operand - D = A * B + C
struct GemmOperand {
enum Kind { kA, kB, kC, kD };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Transformation applied to matrix operands
struct MatrixTransform {
enum Kind {
kNone, /// no operation
kConjugate, /// conjugate
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,61 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defies functors for mapping blockIdx to partitions of the batched reduction computation.
*/
#pragma once
#include "cutlass/coord.h"
namespace cutlass {
namespace reduction {
struct DefaultBlockSwizzle {
/// Ctor
CUTLASS_HOST_DEVICE DefaultBlockSwizzle() {}
/// Swizzle the block index.
CUTLASS_DEVICE dim3 swizzle() { return blockIdx; }
///
CUTLASS_HOST_DEVICE dim3 get_grid_layout(Coord<3> const &problem_size,
Coord<3> const &OutputTile) {
assert(OutputTile[0] == 1 && OutputTile[1] == 1);
assert((problem_size[0] * problem_size[1] * problem_size[2]) % OutputTile[2] == 0);
dim3 grid;
grid.x = problem_size[0] * problem_size[1] * problem_size[2]
/ OutputTile[2] ;
return grid;
}
///
CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &SubTile) {
assert(SubTile[0] == 1 && SubTile[1] == 1);
dim3 block = swizzle();
Coord<3> threadblock_offset =
make_Coord(0, 0, block.x * SubTile[2]);
return threadblock_offset;
}
};
} // namespace reduction
} // namespace cutlass

View File

@ -1,74 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines a type for restructuring a tile.
*/
#pragma once
#include "cutlass/shape.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
// The following functor reshapes a tile of data. The goal is to have at least kAccessSize in
// the inner-most dimension. If the user respects that constraint, there is nothing to be done. If
// that's not the case, this functor will correct that and "extract" the right number of elements
// from the next dimension.
template <typename Tile_, int kAccessSize_, bool = (Tile_::kC < kAccessSize_)>
struct ReshapeTile {
typedef Tile_ Tile;
};
template <typename Tile_, int kAccessSize_>
struct ReshapeTile<Tile_, kAccessSize_, true> {
// Make sure the W dimension of the tile is large enough.
static_assert(Tile_::kW >= kAccessSize_, "The W dimension is too small");
// Make sure the dimension can be divided by the number of scalars.
static_assert(Tile_::kW % kAccessSize_ == 0, "Not supported");
// Collapse the W dimension.
typedef Shape<Tile_::kD, Tile_::kH, Tile_::kW / kAccessSize_, kAccessSize_> Tile;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Tile_, int kAccessSize_, int kLdsPerAccess_, bool = (Tile_::kC < (kAccessSize_ * kLdsPerAccess_))>
struct WmmaReshapeTile {
typedef Tile_ Tile;
};
template <typename Tile_, int kAccessSize_, int kLdsPerAccess_>
struct WmmaReshapeTile<Tile_, kAccessSize_, kLdsPerAccess_, true> {
// Make sure the W dimension of the tile is large enough.
static_assert(Tile_::kW >= (kAccessSize_ * kLdsPerAccess_), "The W dimension is too small");
// Make sure the dimension can be divided by the number of scalars.
static_assert(Tile_::kW % (kAccessSize_ * kLdsPerAccess_) == 0, "Not supported");
// Collapse the W dimension.
typedef Shape<Tile_::kD, Tile_::kH, Tile_::kW / (kAccessSize_ * kLdsPerAccess_), (kAccessSize_ * kLdsPerAccess_)> Tile;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,262 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines Shape implementing the Layout concept for representing a 4D hypercube of objects.
*/
#pragma once
#include "cutlass/cutlass.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
/*!@defgroup layout_concept Layout Concept
* @{
* @par Implementations of \ref layout_concept are used to describe a cube with DxHxW elements and C
scalars per element.
A HxW slice of a cube is called an image and a cube consists of D images.
*
* @par Notations
* Let Layout be an implementation of the \ref layout_concept.
*
* @par Valid Expressions
* - <b>Layout::D</b> specifies the depth of a cube
* - <b>Layout::H</b> specifies the height of a cube
* - <b>Layout::W</b> specifies the height of a cube
* - <b>Layout::C</b> specifies the number of channels of each element in a cube
* - <b>Layout::W_c</b> specifies the number of scalars of each row in one image of a cube.
* - <b>Layout::H_w</b> specifies the number of elements in an image slice.
* - <b>Layout::H_w_c</b>_specifies the number of scalars in an image slice.
* - <b>Layout::D_h_w</b> specifies the number of elements in a cube.
* - <b>Layout::D_h_w_c</b> specifies the number of scalars in a cube.
* - <b>Layout::Strides</b> is a \ref layout_concept specifying the strides.
* @}
*/
/**
* @brief A Shape implementing \ref layout_concept describing the dimensions of a cube.
* @concept{layout_concept}
*/
template <int kD_ = 1, int kH_ = 1, int kW_ = 1, int kC_ = 1>
struct Shape {
/// The depth of the cube.
static int const kD = kD_;
/// The height of the cube.
static int const kH = kH_;
/// The width of the cube.
static int const kW = kW_;
/// The number of scalars per element.
static int const kC = kC_;
};
/**
* @brief Compute derived counted of a \ref layout_concept based class
*/
template <typename Shape>
struct ShapeCount {
/// The number of elements per row.
static int const kWc = Shape::kW * Shape::kC;
/// The number of pixels per image.
static int const kHw = Shape::kH * Shape::kW;
/// The number of elements per image.
static int const kHwc = Shape::kH * kWc;
/// The number of pixels per cube.
static int const kDhw = Shape::kD * kHw;
/// The number of elements in the 4D space.
static int const kDhwc = Shape::kD * kHwc;
/// The number of elements in the 4D space.
static int const kCount = kDhwc;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename A_, int kScale_>
struct ShapeScale {
typedef Shape<A_::kD * kScale_, A_::kH * kScale_, A_::kW * kScale_, A_::kC * kScale_> Shape;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename A_, typename B_>
struct ShapeAdd {
typedef Shape<A_::kD + B_::kD, A_::kH + B_::kH, A_::kW + B_::kW, A_::kC + B_::kC> Shape;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename A_, typename B_>
struct ShapeSub {
typedef Shape<A_::kD - B_::kD, A_::kH - B_::kH, A_::kW - B_::kW, A_::kC - B_::kC> Shape;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename A_, typename B_>
struct ShapeMul {
typedef Shape<A_::kD * B_::kD, A_::kH * B_::kH, A_::kW * B_::kW, A_::kC * B_::kC> Shape;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename A_, typename B_>
struct ShapeDiv {
typedef Shape<A_::kD / B_::kD, A_::kH / B_::kH, A_::kW / B_::kW, A_::kC / B_::kC> Shape;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename A_, typename B_>
struct ShapeDivCeiling {
typedef Shape<(A_::kD + B_::kD - 1) / B_::kD,
(A_::kH + B_::kH - 1) / B_::kH,
(A_::kW + B_::kW - 1) / B_::kW,
(A_::kC + B_::kC - 1) / B_::kC>
Shape;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename A_, typename B_>
struct ShapeMax {
typedef Shape<(A_::kD > B_::kD ? A_::kD : B_::kD),
(A_::kH > B_::kH ? A_::kH : B_::kH),
(A_::kW > B_::kW ? A_::kW : B_::kW),
(A_::kC > B_::kC ? A_::kC : B_::kC)>
Shape;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename A_, typename B_>
struct ShapeMin {
typedef Shape<(A_::kD < B_::kD ? A_::kD : B_::kD),
(A_::kH < B_::kH ? A_::kH : B_::kH),
(A_::kW < B_::kW ? A_::kW : B_::kW),
(A_::kC < B_::kC ? A_::kC : B_::kC)>
Shape;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Shape_, int elementsPerAccess>
struct ShapeStrides {
typedef Shape<Shape_::kH * Shape_::kW * Shape_::kC,
Shape_::kW * Shape_::kC,
Shape_::kC,
elementsPerAccess>
Shape;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/**
* @brief Compute the offset for the given coordinates in a cube
* @tparam A \ref layout_concept where each dimension of the cube specifies the corresponding stride.
*/
template <typename Shape_>
struct ComputeOffsetFromShape {
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c) {
// clang-format off
return d * Shape_::kH * Shape_::kW * Shape_::kC +
h * Shape_::kW * Shape_::kC +
w * Shape_::kC +
c;
// clang-format on
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/**
* @brief Compute the offset for the given coordinates in a cube
* @tparam A \ref layout_concept where each dimension of the cube specifies the corresponding stride.
*/
template <typename Strides_>
struct ComputeOffsetFromStrides {
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c) {
return d * Strides_::kD + h * Strides_::kH + w * Strides_::kW + c * Strides_::kC;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/**
* @brief Decompose threadId.x into coordinate of a cube whose dimensions are specified by Threads_.
* Afterwards compute the offset of those coordinates using Strides_
* @tparam Threads_ The dimension of the cube the threadIdx.x value is mapped on
* @tparam Strides_ The strides to use when compute the offsets based on the coordinates of the cube.
*/
template <typename Threads_, typename Strides_>
struct ComputeThreadOffsetFromStrides {
static CUTLASS_DEVICE int get() {
// Decompose the thread index.
int c = threadIdx.x % Threads_::kC;
int w = threadIdx.x / Threads_::kC % Threads_::kW;
int h = threadIdx.x / Threads_::kC / Threads_::kW % Threads_::kH;
int d = threadIdx.x / Threads_::kC / Threads_::kW / Threads_::kH;
// Compute the offset.
return d * Strides_::kD + h * Strides_::kH + w * Strides_::kW + c * Strides_::kC;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/**
*@brief Specialization for D=1
*/
template <int T_h_, int T_w_, int T_c_, int S_h_, int S_w_, int S_c_>
struct ComputeThreadOffsetFromStrides<Shape<1, T_h_, T_w_, T_c_>, Shape<1, S_h_, S_w_, S_c_> > {
static CUTLASS_DEVICE int get() {
// Decompose the thread index.
int c = threadIdx.x % T_c_;
int w = threadIdx.x / T_c_ % T_w_;
int h = threadIdx.x / T_c_ / T_w_ % T_h_;
// Compute the offset.
return h * S_h_ + w * S_w_ + c * S_c_;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/**
*@brief Specialization for D=1 and C=1
*/
template <int T_h_, int T_w_, int S_h_, int S_w_>
struct ComputeThreadOffsetFromStrides<Shape<1, T_h_, T_w_, 1>, Shape<1, S_h_, S_w_, 1> > {
static CUTLASS_DEVICE int get() {
// Decompose the thread index.
int w = threadIdx.x % T_w_;
int h = threadIdx.x / T_w_;
// Compute the offset.
return h * S_h_ + w * S_w_;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,639 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines a structure containing strides, bounds, and a pointer to tensor data.
*/
#pragma once
#include "cutlass/coord.h"
#include "cutlass/cutlass.h"
#include "cutlass/vector.h"
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Default mapping function from coordinates in a tensor's index space into the n-D array held
/// in memory. Assumes StorageRank = Rank
template <int Rank>
struct IdentityTensorMapFunc {
static int const kStorageRank = Rank;
CUTLASS_HOST_DEVICE
Coord<Rank> operator()(Coord<Rank> const &coord) const {
return coord;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/* \brief Structure modeling a pointer and stride into a tensor.
A tensor consists of an index space with Rank_ dimensions. It is stored in memory modeled
as an n-D array, where n = StorageRank_. A mapping function maps the logical coordinates of the
tensor's index space into the n-D array, and a stride vector maps the n-D array to linear memory.
CUTLASS requires the n-D array's least significant, "fastest changing" dimension to
be contiguous in memory. It therefore has a stride of 1 and is not stored. Construction is offered
from vectors of full StorageRank and of the 'compact' rank, though it is in error to construct
with the least significant stride != 1.
The requirement that the least significant dimension be consecutive enables numerous optimizations
and assumptions about vectorizing memory accesses throughout CUTLASS. It also matches various
BLAS conventions in which only the "leading dimension" or most significant stride of a rank=2
matrix is provided.
This does affect the ability of constructing arbitrary "sparse" 2-D matrices in memory where all
stride elements are > 1. This can be overcome by defining a custom mapping function and a
StorageRank of 3 or more.
Examples:
(These examples use helpers for matrix layouts defined in cutlass/matrix_traits.h)
1. Column-major matrix may be represented as a rank=2 tensor:
TensorRef<float, 2, MatrixLayout::ColumnMajor> A(ptr_A, make_Coord(ldm, 1));
2. Row-major matrix may be represented as a rank=2 tensor:
TensorRef<float, 2, MatrixLayout::RowMajor> B(ptr_A, ldm);
3. An interleaved matrix may be represented as a rank=2 tensor:
TensorRef<int8_t, 2, MatrixLayout::ColumnMajorInterleaved<32> > C;
4. Defining a sparse matrix with arbitrary strides in each dimension
struct ContiguousLayout {
/// Arbitrary storage rank
static int const kStorageRank = 3;
/// Mapping function defined by runtime stride configuration
CUTLASS_HOST_DEVICE
Coord<3> operator()(MatrixCoord const &coord) const {
return make_Coord(coord.row(), coord.column(), 0);
}
};
typedef TensorRef<float, 2, ContiguousLayout> ContiguousTensorRef;
// Construct the TensorRef object from a pair of stride values
ContiguousTensorRef D(ptr_D, make_Coord(row_stride, column_stride));
5. A helper exists to define a TensorRef for a contiguous matrix whose layout
is not known at compile time.
MatrixLayout::Kind layout; // Could be MatrixLayout::kRowMajor or MatrixLayout::kColumnMajor
int ldm; // leading dimension
ContiguousTensorRef E(ptr_E, ContiguousLayout::stride(layout, ldm));
*/
template <
/// Data type of element stored within tensor
typename Storage_,
/// Rank of logical tensor
int Rank_,
/// Maps a Coord<Rank_> in the logical tensor index space to the internal n-D array
typename MapFunc_ = IdentityTensorMapFunc<Rank_>,
/// Rank of internal n-D array
int StorageRank_ = MapFunc_::kStorageRank,
/// Index type used for coordinates
typename Index_ = int,
/// Index type used for offsets and pointer differences
typename LongIndex_ = long long
>
class TensorRef {
public:
/// Data type of individual access
typedef Storage_ Storage;
/// Logical rank of tensor index space
static int const kRank = Rank_;
/// Mapping function from logical coordinate to internal n-D array
typedef MapFunc_ MapFunc;
/// Rank of internal storage
static int const kStorageRank = StorageRank_;
/// Index type
typedef Index_ Index;
/// Typically, strides in memory can be very large
typedef LongIndex_ LongIndex;
/// Coordinate in logical tensor space
typedef Coord<kRank> TensorCoord;
/// Coordinate in storage n-D array
typedef Coord<kStorageRank> StorageCoord;
/// Stride vector in storage coordinage space - assumes least significant stride
/// is 1 and does not store it.
typedef Coord<kStorageRank - 1> StrideVector;
/// Tensor reference to of constant value
typedef TensorRef<
typename platform::remove_const<Storage>::type const,
Rank_,
MapFunc_,
StorageRank_,
Index_,
LongIndex_> ConstTensorRef;
/// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a
/// scalar, but degenerate cases such as these are difficult to accommodate without
/// extensive C++ metaprogramming or support for zero-length arrays.
static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");
//
// Definitions included for backwards compatibility - to be removed in next major release
//
/// Coordinate in logical tensor space
typedef TensorCoord Coord_t;
/// Logical rank of tensor index space
static int const Rank = kRank;
private:
/// Pointer
Storage* ptr_;
/// Stride vector - fastest-changing stride assumed to be 1 and not stored
StrideVector stride_;
/// Maps a logical coordinate to an n-D array's tensor space
MapFunc coord_map_;
public:
//
// Methods
//
/// Helper for 1-D memory. All higher ranks are projected onto the fastest changing rank.
CUTLASS_HOST_DEVICE
TensorRef(Storage *ptr = nullptr): ptr_(ptr) {
for (int i = 0; i < kStorageRank - 1; ++i) {
stride_[i] = 1;
}
}
/// Helper to construct from a pointer and single stride element for 2-D pitch linear memory.
// Higher ranks are projected onto the fastest-changing rank.
CUTLASS_HOST_DEVICE
TensorRef(Storage* ptr, Index ldm) {
ptr_ = ptr;
for (int i = 0; i < kStorageRank - 1; ++i) {
stride_[i] = ldm;
}
}
/// Constructs from a single pointer and stride vector
CUTLASS_HOST_DEVICE
TensorRef(Storage* ptr, StrideVector const& stride) : ptr_(ptr), stride_(stride) {
}
/// Constructs from a pointer and a stride vector of size kRank. If fastest changing
/// stride is not 1, construction fails and subsequent calls to good() will return false.
CUTLASS_HOST_DEVICE
TensorRef(Storage* ptr, StorageCoord const& stride) {
// Fastest-changing stride must be one
if (stride.at(kStorageRank - 1) == 1) {
ptr_ = ptr;
for (int i = 0; i < kStorageRank - 1; ++i) {
stride_[i] = stride[i];
}
}
else {
// Fastest-chaning stride must be 1.
reset();
}
}
/// Enables conversion from TensorRef of non-const type
CUTLASS_HOST_DEVICE
TensorRef(
TensorRef<
typename platform::remove_const<Storage>::type,
kRank,
MapFunc,
kStorageRank,
Index,
LongIndex> const &ref
):
ptr_(ref.data()) {
for (int i = 0; i < kStorageRank - 1; ++i) {
stride_[i] = ref.stride(i);
}
}
/// Returns a reference to constant-valued tensor
CUTLASS_HOST_DEVICE
ConstTensorRef const_ref() const {
return ConstTensorRef(*this);
}
/// Updates only the pointer
CUTLASS_HOST_DEVICE
void reset(Storage* ptr = nullptr) {
ptr_ = ptr;
}
/// Updates the pointer, stride, and location within a TensorRef
CUTLASS_HOST_DEVICE
void reset(Storage* ptr, StorageCoord const & stride) {
// Fastest-changing stride must be one
if (stride.at(kStorageRank - 1) == 1) {
ptr_ = ptr;
for (int i = 0; i < kStorageRank - 1; ++i) {
stride_[i] = stride[i];
}
}
else {
// Fastest-changing stride must be 1 - this is an error.
reset();
}
}
/// Returns true if the TensorRef may be safely accessed
CUTLASS_HOST_DEVICE
bool good() const {
return ptr_ != nullptr;
}
/// Returns the pointer to referenced data
CUTLASS_HOST_DEVICE
Storage * data() const { return ptr_; }
/// Returns the stride of the tensor
CUTLASS_HOST_DEVICE
StorageCoord stride() const {
StorageCoord ld;
for (int i = 0; i < kStorageRank - 1; ++i) {
ld[i] = stride_[i];
}
ld[kStorageRank - 1] = 1;
return ld;
}
/// Returns the stride of the tensor in the given dimension
CUTLASS_HOST_DEVICE
Index stride(int dim) const {
// fastest-changing stride assumbed to be 1
if (dim + 1 >= kStorageRank) {
return 1;
}
return stride_.at(dim);
}
/// Returns the maximum stride element as the 'leading dimension'
CUTLASS_HOST_DEVICE
Index leading_dim(int idx = 0) const { return stride(idx); }
/// Maps a logical coordinate to an n-D array in memory
CUTLASS_HOST_DEVICE
StorageCoord map(TensorCoord const &coord) const {
return coord_map_(coord);
}
/// Computes the offset of an index from the origin of the tensor
CUTLASS_HOST_DEVICE
LongIndex offset(TensorCoord const& coord) const {
return stride().template dot<LongIndex>(map(coord));
}
/// Returns a reference to the element at a given Coord
CUTLASS_HOST_DEVICE
Storage& at(TensorCoord const& coord) const {
return ptr_[offset(coord)];
}
/// Returns a reference to the element at a given linear index
CUTLASS_HOST_DEVICE
Storage& at(LongIndex idx) const { return ptr_[idx]; }
/// Returns a reference to the element at a given Coord
CUTLASS_HOST_DEVICE
Storage& operator[](TensorCoord const& coord) const {
return ptr_[offset(coord)];
}
/// Returns a reference to the element at a given linear index
CUTLASS_HOST_DEVICE
Storage& operator[](LongIndex idx) const { return ptr_[idx]; }
/// Adds an offset to each pointer
CUTLASS_HOST_DEVICE
TensorRef & add_pointer_offset(LongIndex delta) {
ptr_ += delta;
return *this;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorRef operator+(TensorCoord const& b) const {
TensorRef result(*this);
result.add_pointer_offset(offset(b));
return result;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorRef& operator+=(TensorCoord const& b) {
add_pointer_offset(offset(b));
return *this;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorRef operator-(TensorCoord const& b) const {
TensorRef result(*this);
result.add_pointer_offset(-offset(b));
return result;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorRef& operator-=(TensorCoord const& b) {
add_pointer_offset(-offset(b));
return *this;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Partial specializations to handle degenerate cases.
//
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Specialization for rank=1 case with no internal StrideVector
template <
/// Data type of element stored within tensor
typename Storage_,
/// Rank of logical tensor
int Rank_,
/// Maps a Coord<Rank_> in the logical tensor index space to the internal n-D array
typename MapFunc_,
/// Index type used for coordinates
typename Index_,
/// Index type used for offsets and pointer differences
typename LongIndex_
>
class TensorRef<Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_> {
public:
/// Data type of individual access
typedef Storage_ Storage;
/// Logical rank of tensor index space
static int const kRank = Rank_;
/// Mapping function from logical coordinate to internal n-D array
typedef MapFunc_ MapFunc;
/// Rank of internal storage
static int const kStorageRank = 1;
/// Index type
typedef Index_ Index;
/// Typically, strides in memory can be very large
typedef LongIndex_ LongIndex;
/// Coordinate in logical tensor space
typedef Coord<kRank> TensorCoord;
/// Coordinate in storage n-D array
typedef Coord<kStorageRank> StorageCoord;
/// Stride vector in storage coordinage space - assumes least significant stride
/// is 1 and does not store it.
struct StrideVector { };
/// Tensor reference to of constant value
typedef TensorRef<
typename platform::remove_const<Storage>::type const,
Rank_,
MapFunc_,
kStorageRank,
Index_,
LongIndex_> ConstTensorRef;
//
// Definitions included for backwards compatibility - to be removed in next major release
//
/// Coordinate in logical tensor space
typedef TensorCoord Coord_t;
/// Logical rank of tensor index space
static int const Rank = kRank;
private:
/// Pointer
Storage* ptr_;
/// Maps a logical coordinate to an n-D array's tensor space
MapFunc coord_map_;
public:
//
// Methods
//
/// Helper for 1-D memory. All higher ranks are projected onto the fastest changing rank.
CUTLASS_HOST_DEVICE
TensorRef(Storage *ptr = nullptr): ptr_(ptr) { }
/// Constructs from a single pointer and stride vector
CUTLASS_HOST_DEVICE
TensorRef(Storage* ptr, StrideVector const& stride) : ptr_(ptr) {
}
/// Constructs from a pointer and a stride vector of size kRank. If fastest changing
/// stride is not 1, construction fails and subsequent calls to good() will return false.
CUTLASS_HOST_DEVICE
TensorRef(Storage* ptr, StorageCoord const& stride) {
// Fastest-changing stride must be one
if (stride.at(kStorageRank - 1) == 1) {
ptr_ = ptr;
}
else {
// Fastest-chaning stride must be 1.
reset();
}
}
/// Enables conversion from TensorRef of non-const type
CUTLASS_HOST_DEVICE
TensorRef(
TensorRef<
typename platform::remove_const<Storage>::type,
kRank,
MapFunc,
kStorageRank,
Index,
LongIndex> const &ref
):
ptr_(ref.data()) {
}
/// Returns a reference to constant-valued tensor
CUTLASS_HOST_DEVICE
ConstTensorRef const_ref() const {
return ConstTensorRef(*this);
}
/// Updates only the pointer
CUTLASS_HOST_DEVICE
void reset(Storage* ptr = nullptr) {
ptr_ = ptr;
}
/// Updates the pointer, stride, and location within a TensorRef
CUTLASS_HOST_DEVICE
void reset(Storage* ptr, StorageCoord const & stride) {
// Fastest-changing stride must be one
if (stride.at(kStorageRank - 1) == 1) {
ptr_ = ptr;
}
else {
// Fastest-changing stride must be 1 - this is an error.
reset();
}
}
/// Returns true if the TensorRef may be safely accessed
CUTLASS_HOST_DEVICE
bool good() const {
return ptr_ != nullptr;
}
/// Returns the pointer to referenced data
CUTLASS_HOST_DEVICE
Storage * data() const { return ptr_; }
/// Returns the stride of the tensor
CUTLASS_HOST_DEVICE
StorageCoord stride() const {
StorageCoord ld;
ld[kStorageRank - 1] = 1;
return ld;
}
/// Returns the stride of the tensor in the given dimension
CUTLASS_HOST_DEVICE
Index stride(int dim) const {
// fastest-changing stride assumbed to be 1
return 1;
}
/// Returns the maximum stride element as the 'leading dimension'
CUTLASS_HOST_DEVICE
Index leading_dim(int idx = 0) const { return 1; }
/// Maps a logical coordinate to an n-D array in memory
CUTLASS_HOST_DEVICE
StorageCoord map(TensorCoord const &coord) const {
return coord_map_(coord);
}
/// Computes the offset of an index from the origin of the tensor
CUTLASS_HOST_DEVICE
LongIndex offset(TensorCoord const& coord) const {
return stride().template dot<LongIndex>(map(coord));
}
/// Returns a reference to the element at a given Coord
CUTLASS_HOST_DEVICE
Storage& at(TensorCoord const& coord) const {
return ptr_[offset(coord)];
}
/// Returns a reference to the element at a given linear index
CUTLASS_HOST_DEVICE
Storage& at(LongIndex idx) const { return ptr_[idx]; }
/// Returns a reference to the element at a given Coord
CUTLASS_HOST_DEVICE
Storage& operator[](TensorCoord const& coord) const {
return ptr_[offset(coord)];
}
/// Returns a reference to the element at a given linear index
CUTLASS_HOST_DEVICE
Storage& operator[](LongIndex idx) const { return ptr_[idx]; }
/// Adds an offset to each pointer
CUTLASS_HOST_DEVICE
TensorRef & add_pointer_offset(LongIndex delta) {
ptr_ += delta;
return *this;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorRef operator+(TensorCoord const& b) const {
TensorRef result(*this);
result.add_pointer_offset(offset(b));
return result;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorRef& operator+=(TensorCoord const& b) {
add_pointer_offset(offset(b));
return *this;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorRef operator-(TensorCoord const& b) const {
TensorRef result(*this);
result.add_pointer_offset(-offset(b));
return result;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorRef& operator-=(TensorCoord const& b) {
add_pointer_offset(-offset(b));
return *this;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,449 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Introduces TensorRefCollection concept and defines TensorRefBatch and TensorRefArray.
*/
#pragma once
#include "cutlass/tensor_ref.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// TensorRefCollection is a concept for storing a logical collection of TensorRef objects. Classes
// satisfying the TensorRefCollection concept must support the following:
//
// // Define storage type
// typedef typename TensorRefCollection::Storage Storage;
//
// // Define a type for offsets in memory
// typedef typename TensorRefCollection::LongIndex LongIndex;
//
// // Define a ConstIterator type satisfying TensorRefIterator
// typedef typename TensorRefCollection::ConstIterator TensorRefIterator;
//
// // Implement a begin() method.
// TensorRefIterator iterator = collection.begin();
//
//
// TensorRefIterator is a concept for accessing an element in a TensorRefCollection. Classes
// satisfying the TensorRefIterator concept must support the following:
//
// // Define a TensorRef type accessed by the iterator
// typedef typename TensorRefIterator::TensorRef TensorRef;
//
// // Access the TensorRef
// TensorRef ref = *iterator;
//
// // Pre-increment and post-increment
// ++iterator;
// iterator++;
//
// // Pre-decrement and post-decrement
// --iterator;
// iterator--;
//
////////////////////////////////////////////////////////////////////////////////////////////////////
/// This satisfies TensorRefCollection and stores a collection of TensorRef objects that
/// have identical strides. TensorRef objects are separated by a linear stride.
template <
/// Data type of element stored within tensor
typename Storage_,
/// Rank of logical tensor
int Rank_,
/// Maps a Coord<Rank_> in the logical tensor index space to the internal n-D array
typename MapFunc_ = IdentityTensorMapFunc<Rank_>,
/// Rank of internal n-D array
int StorageRank_ = MapFunc_::kStorageRank,
/// Index type used for coordinates
typename Index_ = int,
/// Index type used for offsets and pointer differences
typename LongIndex_ = long long
>
struct TensorRefBatchStrided:
public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> {
//
// Type definitions
//
/// Underlying TensorRef type
typedef TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> Base;
/// Storage type
typedef typename Base::Storage Storage;
/// Rank of the logical tensor
static int const kRank = Rank_;
/// Index type
typedef Index_ Index;
/// Typically, strides in memory can be very large
typedef LongIndex_ LongIndex;
/// Coordinate in logical tensor space
typedef Coord<kRank> TensorCoord;
/// Tensor reference implied by the TensorRefBatchStrided
typedef Base TensorRef;
/// Constant iterator over tensors implied by TensorRefBatchStrided
class ConstIterator {
public:
/// TensorRef returned by the iterator
typedef Base TensorRef;
private:
/// Reference to the parent TensorBatchRef object
TensorRefBatchStrided const &ref_;
/// Offset from the base TensorRef pointer
LongIndex offset_;
public:
/// Constructs a ConstIterator from a parent TensorRefBatchStrided
CUTLASS_HOST_DEVICE
ConstIterator(
TensorRefBatchStrided const &ref,
LongIndex offset = 0): ref_(ref), offset_(offset) { }
/// Obtains a TensorRef pointed to by the iterator
CUTLASS_HOST_DEVICE
TensorRef operator*() const {
TensorRef ref(ref_);
ref.add_pointer_offset(offset_);
return ref;
}
/// Advances the iterator to point to the next tensor
CUTLASS_HOST_DEVICE
ConstIterator &operator++() {
offset_ += ref_.tensor_stride;
return *this;
}
/// Advances the iterator to point to the next tensor
CUTLASS_HOST_DEVICE
ConstIterator operator++(int) {
ConstIterator ret(*this);
offset_ += ref_.tensor_stride;
return ret;
}
/// Returns an iterator advanced by (idx) amount
CUTLASS_HOST_DEVICE
ConstIterator operator+(Index idx) {
return ConstIterator(ref_, offset_ + ref_.tensor_stride * idx);
}
/// Advances this iterator by (idx) and returns a reference to self
CUTLASS_HOST_DEVICE
ConstIterator &operator+=(Index idx) {
offset_ += ref_.tensor_stride * idx;
return *this;
}
/// Moves to the previous tensor
CUTLASS_HOST_DEVICE
ConstIterator &operator--() {
offset_ -= ref_.tensor_stride;
return *this;
}
/// Moves to the previous tensor
CUTLASS_HOST_DEVICE
ConstIterator operator--(int) {
ConstIterator ret(*this);
offset_ -= ref_.tensor_stride;
return ret;
}
/// Returns an iterator moved forward by (idx) amount
CUTLASS_HOST_DEVICE
ConstIterator operator-(Index idx) {
return ConstIterator(ref_, offset_ - ref_.tensor_stride * idx);
}
/// Moves this iterator by (idx) and returns a reference to self
CUTLASS_HOST_DEVICE
ConstIterator &operator-=(Index idx) {
offset_ -= ref_.tensor_stride * idx;
return *this;
}
/// Returns the difference in offset between two iterators
CUTLASS_HOST_DEVICE
LongIndex operator-(ConstIterator const &it) {
return offset_ - it.offset_;
}
};
//
// Data members
//
/// Stride between tensors
LongIndex tensor_stride;
//
// Methods
//
// Default ctor
CUTLASS_HOST_DEVICE
TensorRefBatchStrided(): tensor_stride(0) { }
// Constructs form a tensor reference and
CUTLASS_HOST_DEVICE
TensorRefBatchStrided(TensorRef const &ref, LongIndex _tensor_stride = 0):
TensorRef(ref),
tensor_stride(_tensor_stride) { }
/// Gets the pointer offset
CUTLASS_HOST_DEVICE
LongIndex get_pointer_offset(Index idx) const {
return idx * tensor_stride;
}
// Returns a reference
CUTLASS_HOST_DEVICE
TensorRef at(Index idx = 0) const {
TensorRef ref(*this);
ref.add_pointer_offset(get_pointer_offset(idx));
return ref;
}
/// Returns an iterator
CUTLASS_HOST_DEVICE
ConstIterator begin() {
return ConstIterator(*this);
}
};
/// Helper to construct a TensorRefBatchStrided<> object using type deduction
template <typename TensorRef_>
CUTLASS_HOST_DEVICE
TensorRefBatchStrided<
typename TensorRef_::Storage,
TensorRef_::kRank,
typename TensorRef_::MapFunc,
TensorRef_::kStorageGrank,
typename TensorRef_::Index,
typename TensorRef_::LongIndex
> make_TensorRefBatchStrided(
TensorRef_ const &ref,
typename TensorRef_::LongIndex batch_stride = 0) {
return TensorRefBatchStrided<
typename TensorRef_::Storage,
TensorRef_::kRank,
typename TensorRef_::MapFunc,
TensorRef_::kStorageGrank,
typename TensorRef_::Index,
typename TensorRef_::LongIndex
>(ref, batch_stride);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// This satisfies TensorRefCollection and stores a collection of TensorRef objects. This is a
/// structure of arrays in that the individual members of the TensorRef are held in distinct arrays.
///
/// Note, TensorRef maps a logical coordinate space to an n-D array with rank kStorageRank. It
/// maintains a stride vector of similar rank, but the least significant rank is defined to be 1.
///
/// The least significant stride of 1 is not stored, and therefore the number of stride arrays is
/// kStorageRank - 1.
template <
/// Data type of element stored within tensor
typename Storage_,
/// Rank of logical tensor
int Rank_,
/// Maps a Coord<Rank_> in the logical tensor index space to the internal n-D array
typename MapFunc_ = IdentityTensorMapFunc<Rank_>,
/// Rank of internal n-D array
int StorageRank_ = MapFunc_::kStorageRank,
/// Index type used for coordinates
typename Index_ = int,
/// Index type used for offsets and pointer differences
typename LongIndex_ = long long
>
struct TensorRefArray {
//
// Type definitions
//
/// Element pointed to by the TensorRef
typedef Storage_ Storage;
/// Index type
typedef Index_ Index;
/// Typically, strides in memory can be very large
typedef LongIndex_ LongIndex;
/// Rank of the stride vector
static int const kStorageRank = StorageRank_;
/// TensorRefIterator over TensorRef objects in TensorRefArray
class ConstIterator {
public:
/// Containing class's tensor rev
typedef TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> TensorRef;
private:
/// Reference to the TensorRefArray
TensorRefArray const &ref_;
/// Index into TensorRefArray
int idx_;
public:
/// Constructs a ConstIterator over the TensorRef objects
CUTLASS_HOST_DEVICE
ConstIterator(TensorRefArray const &ref, int idx = 0): ref_(ref), idx_(idx) { }
/// Obtains a TensorRef pointed to by this iterator
CUTLASS_HOST_DEVICE
TensorRef operator*() const {
return ref_.reference(idx_);
}
/// Advances to next TensorRef
CUTLASS_HOST_DEVICE
ConstIterator &operator++() {
++idx_;
return *this;
}
/// Advances to next TensorRef
CUTLASS_HOST_DEVICE
ConstIterator operator++(int) {
ConstIterator ret(*this);
idx_ ++;
return ret;
}
CUTLASS_HOST_DEVICE
ConstIterator operator+(Index idx) {
return ConstIterator(ref_, idx_ + idx);
}
CUTLASS_HOST_DEVICE
ConstIterator &operator+=(Index idx) {
idx_ += idx;
return *this;
}
CUTLASS_HOST_DEVICE
ConstIterator &operator--() {
--idx_;
return *this;
}
/// Advances to next TensorRef
CUTLASS_HOST_DEVICE
ConstIterator operator--(int) {
ConstIterator ret(*this);
--idx_;
return ret;
}
CUTLASS_HOST_DEVICE
ConstIterator &operator-=(Index idx) {
idx_ -= idx;
return *this;
}
CUTLASS_HOST_DEVICE
ConstIterator operator-(Index idx) {
return ConstIterator(ref_, idx_ + idx);
}
};
/// TensorRef type obtained from the TensorRefArray
typedef TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> TensorRef;
//
// Data members
//
/// Base addresses
Storage **pointers;
/// Array of strides
Index *strides[kStorageRank - 1];
//
// Methods
//
// Default ctor
CUTLASS_HOST_DEVICE
TensorRefArray() { }
// Construct from pointers to arrays to strides
CUTLASS_HOST_DEVICE
TensorRefArray(
Storage **_pointers,
Index _strides[kStorageRank - 1]): pointers(_pointers) {
// Copy pointers to strides arrays
for (int i = 0; i < kStorageRank - 1; ++i) {
strides[i] = _strides[i];
}
}
// Returns a TensorRef at the given index in the collection
CUTLASS_HOST_DEVICE
TensorRef at(Index idx = 0) const {
Coord<kStorageRank - 1, Index> stride;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kStorageRank - 1; ++i) {
stride[i] = strides[idx][i];
}
return TensorRef(pointers[idx], stride);
}
/// Returns an TesnorRefIterator over the TensorRef objects in this collection
CUTLASS_HOST_DEVICE
ConstIterator begin() {
return ConstIterator(*this);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,266 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines a structure containing strides and a pointer to tensor data.
TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus,
it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from
data storage and is therefore lightweight and may be embedded in larger tensor objects or
memory structures.
See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to
linear memory.
*/
#pragma once
#include <cmath>
#include "cutlass/cutlass.h"
#include "cutlass/tensor_ref.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines a view into a logical tensor
template <
/// Data type of element stored within tensor
typename Storage_,
/// Rank of logical tensor
int Rank_ = 4,
/// Maps a Coord<Rank_> in the logical tensor index space to the internal n-D array
typename MapFunc_ = IdentityTensorMapFunc<Rank_>,
/// Rank of internal n-D array
int StorageRank_ = MapFunc_::kStorageRank,
/// Index type used for coordinates
typename Index_ = int,
/// Index type used for offsets and pointer differences
typename LongIndex_ = long long
>
class TensorView : public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> {
public:
/// Base tensor reference
typedef TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> Base;
/// Tensor reference to of constant value
typedef TensorRef<
typename platform::remove_const<Storage_>::type const,
Rank_,
MapFunc_,
StorageRank_,
Index_,
LongIndex_> ConstTensorRef;
/// Base tensor reference
typedef Base TensorRef;
/// Storage type
typedef typename Base::Storage Storage;
/// Index type
typedef typename Base::Index Index;
/// Coordinate in logical tensor space
typedef typename TensorRef::TensorCoord TensorCoord;
/// Coordinate in storage n-D array
typedef typename TensorRef::StorageCoord StorageCoord;
/// Stride vector in storage coordinate space
/// Least significant stride is = 1 and not stored
typedef typename TensorRef::StrideVector StrideVector;
/// TensorView of constant value
typedef TensorView<
typename platform::remove_const<Storage>::type const,
Rank_,
MapFunc_,
StorageRank_,
Index_,
LongIndex_> ConstTensorView;
//
// Definitions included for backwards compatibility - to be removed in next major release
//
/// Coordinate in logical tensor space
typedef TensorCoord Coord_t;
/// Logical rank of tensor index space
static int const Rank = Base::kRank;
/// Type used to compute the offset of an element to the base of a tensor
typedef typename Base::LongIndex Offset_t;
/// Base class
typedef TensorRef TensorRef_t;
/// TensorRef to const-valued type
typedef typename TensorRef::ConstTensorRef ConstTensorRef_t;
private:
//
// Data members
//
/// Dimensions of coordinate (independent of stride)
TensorCoord size_;
public:
//
// Device and Host Methods
//
/// Default constructor
CUTLASS_HOST_DEVICE
TensorView() {}
/// Constructs a TensorView from a TensorRef and size
CUTLASS_HOST_DEVICE
TensorView(Base const& _ref, TensorCoord const& _size) : Base(_ref), size_(_size) {}
/// Constructs a TensorView from a pointer, a stride vector, and size
CUTLASS_HOST_DEVICE
TensorView(
Storage *ptr,
StrideVector const &stride,
TensorCoord const& size
):
Base(ptr, stride), size_(size) {}
/// Constructs a TensorView from a pointer, a stride vector, and size
CUTLASS_HOST_DEVICE
TensorView(
Storage *ptr,
StorageCoord const &stride,
TensorCoord const& size
):
Base(ptr, stride), size_(size) {}
/// Updates the reference and size of a Tensor_view object
CUTLASS_HOST_DEVICE
void reset(Base const& _ref = Base(), TensorCoord const& _size = TensorCoord()) {
Base::operator=(_ref);
size_ = _size;
}
/// Accesses the size
CUTLASS_HOST_DEVICE
TensorCoord const& size() const { return size_; }
/// Accesses the size
CUTLASS_HOST_DEVICE
Index size(int dim) const { return size_.at(dim); }
/// Assigns the Tensor_view
CUTLASS_HOST_DEVICE
TensorView& operator=(TensorView const& _tensor) {
Base::operator=(_tensor);
size_ = _tensor.size_;
return *this;
}
/// Determines whether a location is within a tensor
CUTLASS_HOST_DEVICE
bool contains(TensorCoord const& coord) const {
CUTLASS_PRAGMA_UNROLL
for (int dim = 0; dim < Rank_; ++dim) {
if (coord[dim] >= size_[dim]) {
return false;
}
}
return true;
}
/// Returns a TensorRef pointing to the first element of the tensor.
CUTLASS_HOST_DEVICE
TensorRef ref() const {
return TensorRef(*this);
}
/// Returns a TensorRef pointing to the first element of the tensor.
CUTLASS_HOST_DEVICE
ConstTensorRef const_ref() const {
return ConstTensorRef(*this);
}
/// Returns a Tensor_view given location and size quantities
CUTLASS_HOST_DEVICE
TensorView subview(TensorCoord const& location, TensorCoord size) const {
return TensorView((*this) + location, size.clamp(size_ - location));
}
/// Returns the number of scalar elements needed to store tensor
CUTLASS_HOST_DEVICE
size_t capacity() const {
int max_rank = 0;
StorageCoord mapped_size(this->map(size()));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Base::kStorageRank; ++i) {
if (!i ||
this->stride(i) * mapped_size[i] > this->stride(max_rank) * mapped_size[max_rank]) {
max_rank = i;
}
}
return this->stride(max_rank) * mapped_size[max_rank];
}
/// Returns a TensorView offset by a given amount
CUTLASS_HOST_DEVICE
TensorView operator+(TensorCoord const& b) const {
TensorView result(*this);
result.add_pointer_offset(this->offset(b));
return result;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorView& operator+=(TensorCoord const& b) {
this->add_pointer_offset(this->offset(b));
return *this;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorView operator-(TensorCoord const& b) const {
TensorRef result(*this);
result.add_pointer_offset(-this->offset(b));
return result;
}
/// Returns a TensorRef offset by a given amount
CUTLASS_HOST_DEVICE
TensorView& operator-=(TensorCoord const& b) {
this->add_pointer_offset(-this->offset(b));
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,168 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines a fragment based on a Shape<> template.
*/
#pragma once
#include "cutlass/shape.h"
#include "cutlass/fragment.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/zip_tensor_ref.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Class for storing a tile in memory and accessing it through a tensor ref
template <typename Scalar_, typename Shape_>
struct TileAllocation {
//
// Type definitions
//
/// Scalar element
typedef Scalar_ Scalar;
/// The actual storage (may differ from the scalar type)
typedef typename StorageType<sizeof(Scalar)>::Type Storage;
/// Size of the allocation in units of scalars
typedef Shape_ Shape;
/// Strides
typedef typename ShapeStrides<Shape, 1>::Shape Strides;
/// Defines the tensor reference for this allocation
typedef TensorRef<Scalar const, 4> ConstTensorRef;
/// Defines the tensor reference for this allocation
typedef TensorRef<Scalar, 4> TensorRef;
/// View of memory
typedef TensorView<Scalar const, 4> ConstTensorView;
/// View of memory
typedef TensorView<Scalar, 4> TensorView;
//
// Data members
//
/// Storage
Storage storage[Shape::kD][Shape::kH][Shape::kW][Shape::kC];
//
// Methods
//
/// Returns a pointer to the raw data
CUTLASS_DEVICE
Scalar *data() { return reinterpret_cast<Scalar *>(&storage[0][0][0][0]); }
/// Returns a const pointer to the raw data
CUTLASS_DEVICE
Scalar const *data() const { return reinterpret_cast<Scalar const *>(&storage[0][0][0][0]); }
/// Returns a TensorRef object pointing to the data
CUTLASS_DEVICE
TensorRef reference() {
return TensorRef(data(), make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC));
}
/// Returns a TensorRef object pointing to the data
CUTLASS_DEVICE
ConstTensorRef reference() const {
return ConstTensorRef(data(), make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC));
}
/// Returns a TensorView object pointing to the data
CUTLASS_DEVICE
TensorView view() {
return TensorView(
data(),
make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC),
make_Coord(Shape::kD, Shape::kH, Shape::kW, Shape::kC));
}
/// Returns a TensorView object pointing to the data
CUTLASS_DEVICE
ConstTensorView view() const {
return TensorView(
data(),
make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC),
make_Coord(Shape::kD, Shape::kH, Shape::kW, Shape::kC));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Manages a pair of tile allocations as if they are one allocation
template <typename First_, typename Second_>
struct ZipTileAllocation {
//
// Type definitions
//
/// First tensor allocation
typedef First_ First;
/// Second tensor allocation
typedef Second_ Second;
/// Defines the tensor reference for this allocation
typedef ZipTensorRef<typename First::TensorRef, typename Second::TensorRef> TensorRef;
/// Defines the tensor reference for this allocation
typedef ZipTensorRef<typename First::ConstTensorRef, typename Second::ConstTensorRef>
ConstTensorRef;
//
// Data members
//
/// First tensor allocation
First first;
/// Second tensor allocation
Second second;
//
// Methods
//
/// Returns a TensorRef object pointing to the data
CUTLASS_DEVICE
TensorRef reference() { return TensorRef(first.reference(), second.reference()); }
/// Returns a TensorRef object pointing to the data
CUTLASS_DEVICE
ConstTensorRef reference() const { return ConstTensorRef(first.reference(), second.reference()); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,194 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines a coordinate used for the CUTLASS 4-D tile structure.
*/
#pragma once
#include "cutlass/coord.h"
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// TileCoord wraps Coord<4, int> to provide a helper for accessing named dimensions. Classes
/// expecting a coordinate in the rank=4 index space of a CUTLASS tile structure should use TileCoord.
template <typename Index_ = int>
struct TileCoord : public Coord<4, Index_> {
/// Index type
typedef Index_ Index;
/// Underlying Coord<4>
typedef Coord<4, Index> Base;
/// D dimension
static int kD = 0;
/// H dimension
static int kH = 1;
/// W dimension
static int kW = 2;
/// C dimension
static int kC = 3;
//
// Methods
//
/// Default ctor
CUTLASS_HOST_DEVICE
TileCoord() { }
/// Constructs from Coord<3> and infers coord[kC] = 0
CUTLASS_HOST_DEVICE
TileCoord(Coord<3, Index> const &coord):
Base(make_Coord(coord[0], coord[1], coord[2], 0)) { }
/// Constructs from Coord<4>
CUTLASS_HOST_DEVICE
TileCoord(Coord<4, Index> const &coord): Base(coord) { }
/// Constructs from an array of coordinate elements
CUTLASS_HOST_DEVICE
TileCoord(Index coord[4]): Base(coord) { }
/// Helper to construct from a row and column
CUTLASS_HOST_DEVICE
TileCoord(Index d, Index h, Index w, Index c): Base(make_Coord(d, h, w, c)) { }
/// Returns the D element of the coordinate
CUTLASS_HOST_DEVICE
Index const & d() const { return this->at(kD); }
/// Returns the D element of the coordinate
CUTLASS_HOST_DEVICE
Index & d() { return this->at(kD); }
/// Returns the H element of the coordinate
CUTLASS_HOST_DEVICE
Index const & h() const { return this->at(kH); }
/// Returns the H element of the coordinate
CUTLASS_HOST_DEVICE
Index & h() { return this->at(kH); }
/// Returns the W element of the coordinate
CUTLASS_HOST_DEVICE
Index const & w() const { return this->at(kW); }
/// Returns the W element of the coordinate
CUTLASS_HOST_DEVICE
Index & w() { return this->at(kW); }
/// Returns the Celement of the coordinate
CUTLASS_HOST_DEVICE
Index const & c() const { return this->at(kC); }
/// Returns the C element of the coordinate
CUTLASS_HOST_DEVICE
Index & c() { return this->at(kC); }
/// Gets H and W dimensions as a Coord<2>
CUTLASS_HOST_DEVICE
Coord<2> hw() const {
return make_Coord(h(), w());
}
/// Gets H, W, and C dimensions as a Coord<3>
CUTLASS_HOST_DEVICE
Coord<3> hwc() const {
return make_Coord(h(), w(), c());
}
/// Gets D, H, and W dimensions as a Coord<3>
CUTLASS_HOST_DEVICE
Coord<3> dhw() const {
return make_Coord(d(), h(), w());
}
//
// Coord operators
//
/// Element-wise addition
CUTLASS_HOST_DEVICE
TileCoord operator+(Base const& b) const {
return TileCoord(Base::operator+(b));
}
/// Element-wise subtraction
CUTLASS_HOST_DEVICE
TileCoord operator-(Base const& b) const {
return TileCoord(Base::operator-(b));
}
/// Element-wise multiplication
CUTLASS_HOST_DEVICE
TileCoord operator*(Base const& b) const {
return TileCoord(Base::operator*(b));
}
/// Element-wise division
CUTLASS_HOST_DEVICE
TileCoord operator/(Base const& b) const {
return TileCoord(Base::operator/(b));
}
/// In-place addition
CUTLASS_HOST_DEVICE
TileCoord& operator+=(Base const& b) {
Base::operator+=(b);
return *this;
}
/// In-place subtraction
CUTLASS_HOST_DEVICE
TileCoord& operator-=(Base const& b) {
Base::operator-=(b);
return *this;
}
/// In-place multiplication
CUTLASS_HOST_DEVICE
TileCoord& operator*=(Base const& b) {
Base::operator*=(b);
return *this;
}
/// In-place division
CUTLASS_HOST_DEVICE
TileCoord& operator/=(Base const& b) {
Base::operator/=(b);
return *this;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

File diff suppressed because it is too large Load Diff

View File

@ -1,378 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implements the tile stream concept, composing an iterator with a transformation. Offers
split-phase semantics, separating the initiation of an asynchronous memory operation with a
fence forcing it to complete.
*/
#pragma once
// clang-format off
#include "cutlass/convert.h"
#include "cutlass/tile_iterator.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Generic stream for loading and transforming fragments
template <typename Iterator_, typename Transformer_ = Copy<typename Iterator_::Fragment> >
struct TileLoadStream {
//
// Type definitions
//
/// TileLoadIterator
typedef Iterator_ Iterator;
/// Transformer
typedef Transformer_ Transformer;
/// Fragment fetched from source memory
typedef typename Iterator::Fragment Fragment;
/// Output fragment from transformer
typedef typename Transformer::OutputFragment TransformedFragment;
/// Tensor reference expected by the stream
typedef typename Iterator::TensorRef TensorRef;
/// Empty predicate vector struct
struct PredicateVector {};
/// Index type
typedef typename Iterator::Index Index;
/// Parameters object used to construct generic load stream
struct Params {
/// Parameters to the iterator
typename Iterator::Params iterator;
//
// Methods
//
/// Default constructor
CUTLASS_HOST_DEVICE
Params() {}
/// Constructor with iterator params
CUTLASS_HOST_DEVICE
Params(typename Iterator::Params const &_iterator) : iterator(_iterator) {}
};
//
// Data members
//
/// Iterator to load tiles
Iterator iterator;
/// Fragment loaded via iterator
Fragment fetched_fragment;
/// Transformation applied to fragments
Transformer transformer;
/// Transformed fragment from transformer
TransformedFragment transformed_fragment;
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
TileLoadStream(Params const &_params, TensorRef const &_ref)
: iterator(_params.iterator, _ref) {}
/// Ctor
CUTLASS_DEVICE
TileLoadStream(Params const &_params,
Coord<3> const &threadblock_offset = make_Coord(0, 0, 0)
): iterator(_params.iterator, threadblock_offset) { }
/// Loads a tile and increments the iterator
CUTLASS_DEVICE
void copy() { iterator.load_post_increment(fetched_fragment); }
/// Commits the fetched fragment and applies a transformation
CUTLASS_DEVICE
void commit() { transformer.transform(fetched_fragment, transformed_fragment); }
/// Accesses the loaded, transformed fragment
CUTLASS_DEVICE
Fragment &intermediate_fragment() { return fetched_fragment; }
/// Accesses the loaded, transformed fragment
CUTLASS_DEVICE
TransformedFragment &fragment() { return transformed_fragment; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Generic stream for transforming and storing fragments
template <typename Iterator_, typename Transformer_ = Copy<typename Iterator_::Fragment> >
struct TileStoreStream {
//
// Type definitions
//
/// TileLoadIterator
typedef Iterator_ Iterator;
/// Transformer
typedef Transformer_ Transformer;
/// Source fragment
typedef typename Transformer::InputFragment Fragment;
/// Transformed fragment, compatible with Iterator::Fragment
typedef typename Transformer::OutputFragment TransformedFragment;
/// Tensor reference expected by the underlying iterator
typedef typename Iterator::TensorRef TensorRef;
/// Empty predicate vector struct
struct PredicateVector {};
/// Index type
typedef typename Iterator::Index Index;
/// Parameters used to construct the stream
struct Params {
/// Parameters to the iterator
typename Iterator::Params iterator;
//
// Methods
//
/// Default constructor
CUTLASS_HOST_DEVICE
Params() {}
/// Constructor with iterator params
CUTLASS_HOST_DEVICE
Params(typename Iterator::Params const &_iterator) : iterator(_iterator) {}
};
//
// Data members
//
/// Iterator to store tiles
Iterator iterator;
/// Transformation applied to inputs
Transformer transformer;
/// Source fragment
Fragment source_fragment;
/// Transformed fragment from transformer
TransformedFragment transformed_fragment;
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
TileStoreStream(Params const &_params, TensorRef const &_ref)
: iterator(_params.iterator, _ref) {}
/// Ctor
CUTLASS_DEVICE
TileStoreStream(Params const &_params,
Coord<3> const &threadblock_offset = make_Coord(0, 0, 0)
): iterator(_params.iterator, threadblock_offset) { }
/// Stores a fragment and increments the iterator
CUTLASS_DEVICE
void copy() {
transformer.transform(source_fragment, transformed_fragment);
iterator.store_post_increment(transformed_fragment);
}
/// Stores a fragment and increments the iterator
CUTLASS_DEVICE
void copy(Fragment const &frag) {
source_fragment = frag;
copy();
}
/// Commits the store operation
CUTLASS_DEVICE
void commit() {}
/// Accesses the transformed fragment
CUTLASS_DEVICE
Fragment &fragment() { return source_fragment; }
/// Accesses the fragment after trasnforming
CUTLASS_DEVICE
TransformedFragment &intermediate_fragment() { return transformed_fragment; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Generic stream for loading and transforming fragments
template <typename Iterator_,
typename PredicateFunctor_ =
RegularTilePredicateFunctor<typename Iterator_::Traits::Delta>,
typename Transformer_ = Copy<typename Iterator_::Fragment> >
struct PredicatedTileLoadStream : public TileLoadStream<Iterator_, Transformer_> {
//
// Type definitions
//
typedef TileLoadStream<Iterator_, Transformer_> Base;
/// TileLoadIterator
typedef Iterator_ Iterator;
/// Predicate functor
typedef PredicateFunctor_ PredicateFunctor;
/// Transformer
typedef Transformer_ Transformer;
/// Fragment fetched from source memory
typedef typename Base::Fragment Fragment;
/// Output fragment from transformer
typedef typename Base::TransformedFragment TransformedFragment;
/// Parameters object used to construct generic load stream
typedef typename Base::Params Params;
//
// Data members
//
/// Predicates
typename Iterator::PredicateVector predicates;
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
PredicatedTileLoadStream(Params const &_params,
Coord<3> const &bounds,
Coord<3> const &threadblock_offset = make_Coord(0, 0, 0))
: Base(_params, threadblock_offset) {
this->iterator.initialize_predicates(
predicates.begin(), PredicateFunctor(bounds), threadblock_offset);
}
/// Loads a tile and increments the iterator
CUTLASS_DEVICE
void copy() { this->iterator.load_post_increment(this->fetched_fragment, predicates.begin()); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Generic stream for transforming and storing fragments
template <typename Iterator_,
typename PredicateFunctor_ =
RegularTilePredicateFunctor<typename Iterator_::Traits::Delta>,
typename Transformer_ = Copy<typename Iterator_::Fragment> >
struct PredicatedTileStoreStream : public TileStoreStream<Iterator_, Transformer_> {
//
// Type definitions
//
typedef TileStoreStream<Iterator_, Transformer_> Base;
/// TileLoadIterator
typedef Iterator_ Iterator;
/// Predicate functor
typedef PredicateFunctor_ PredicateFunctor;
/// Transformer
typedef Transformer_ Transformer;
/// Fragment fetched from source memory
typedef typename Base::Fragment Fragment;
/// Output fragment from transformer
typedef typename Base::TransformedFragment TransformedFragment;
/// Parameters object used to construct generic load stream
typedef typename Base::Params Params;
//
// Data members
//
/// Predicates
typename Iterator::PredicateVector predicates;
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
PredicatedTileStoreStream(Params const &_params,
Coord<3> const &bounds,
Coord<3> const &threadblock_offset = make_Coord(0, 0, 0))
: Base(_params, threadblock_offset) {
this->iterator.initialize_predicates(
predicates.begin(), PredicateFunctor(bounds), threadblock_offset);
}
/// Stores the fragment and increments the iterator
CUTLASS_DEVICE
void copy() {
this->transformer.transform(this->source_fragment, this->transformed_fragment);
this->iterator.store_post_increment(this->transformed_fragment, predicates.begin());
}
/// Stores the fragment and increments the iterator
CUTLASS_DEVICE
void copy(Fragment const &frag) {
this->source_fragment = frag;
copy();
}
/// Commits the store operation
CUTLASS_DEVICE
void commit() {}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
// clang-format on

View File

@ -1,240 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines tile traits for several tile partitioning arrangements of threads expected to
achieve efficient streaming performance.
*/
#pragma once
#include "cutlass/tile_iterator.h"
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Basic thread offset function computed from a thread shape
template <typename ThreadShape>
struct TiledThreadOffset {
/// Computes the logical coordinate from thread shape
CUTLASS_HOST_DEVICE
Coord<4> operator()() const {
Coord<4> thread_offset;
int index = threadIdx.x;
thread_offset[3] = (index % ThreadShape::kC);
index = (index / ThreadShape::kC);
thread_offset[2] = (index % ThreadShape::kW);
index = (index / ThreadShape::kW);
thread_offset[1] = (index % ThreadShape::kH);
index = (index / ThreadShape::kH);
thread_offset[0] = index;
return thread_offset;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Tiling in which the number of threads is greater than the
/// contiguous dimension of the tile.
template <typename Tile_, int Threads>
struct TileTraitsStrideMajor {
/// Shape of tile
typedef Tile_ Tile;
/// Number of participating threads
static int const kThreads = Threads;
// Static assertions
static_assert(!(ShapeCount<Tile>::kDhw % kThreads),
"Tiling undefined if elements not divisible by threads.");
static_assert(Tile::kW <= kThreads,
"This specialization assumes there are more threads than the contiguous dimension "
"of the tile.");
/// Shape of threads
typedef Shape<1, kThreads / Tile::kW, Tile::kW, 1> ThreadShape;
/// Delta along each dimension
typedef Shape<1, ThreadShape::kH, 1, 1> Delta;
/// Number of iterations
typedef Shape<1, Tile::kH / ThreadShape::kH, 1, 1> Iterations;
/// Computes the initial offset
typedef TiledThreadOffset<ThreadShape> ThreadOffset;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Tiling in which the number of threads is fewer than the tile size
/// in the contiguous dimension.
template <typename Tile_, int Threads>
struct TileTraitsContiguousMajor {
/// Shape of tile
typedef Tile_ Tile;
/// Number of participating threads
static int const kThreads = Threads;
// Static assertions
static_assert(Tile::kW >= kThreads,
"This specialization assumes there are more threads than the contiguous dimension "
"of the tile.");
static_assert(!(ShapeCount<Tile>::kDhw % kThreads),
"Tiling undefined if elements not divisible by threads.");
static_assert(!(Tile::kW % kThreads),
"The contiguous size of the tile must be divisible by the number of threads.");
/// Thread shape
typedef Shape<1, 1, kThreads> ThreadShape;
/// Delta between each thread's access
typedef Shape<1, 1, kThreads> Delta;
/// Number of iterations
typedef Shape<1, Tile::kH, Tile::kW / kThreads> Iterations;
/// Computes the initial offset
typedef TiledThreadOffset<ThreadShape> ThreadOffset;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Tiling in which warps rake across the contiguous dimension
template <typename Tile_, int Threads>
struct TileTraitsWarpRake {
/// Shape of tile
typedef Tile_ Tile;
/// Number of participating threads
static int const kThreads = Threads;
/// Hard-coded warp size
static int const kWarpSize = 32;
/// Number of participating warps
static int const kWarpCount = kThreads / kWarpSize;
// Static assertions
static_assert(!(ShapeCount<Tile>::kDhw % kThreads),
"Tiling undefined if elements not divisible by threads.");
static_assert(!(kThreads % kWarpSize), "Number of threads must be divisible by the warp size.");
static_assert(!(Tile::kW % kWarpSize), "Contiguous dimension must be divisible by the warp size");
/// Warps strip-mined across strided dimension
static int const kWarpsStrided = __NV_STD_MIN(kWarpCount, Tile::kH);
/// Warps stripmined contiguous dimension
static int const kWarpsContiguous = kWarpCount / kWarpsStrided;
/// Arrangement of threads
typedef Shape<1, kWarpsStrided, kWarpsContiguous * kWarpSize> ThreadShape;
/// The same warp rakes along the contiguous dimension
typedef Shape<1, kWarpsStrided, kWarpSize> Delta;
/// Number of iterations
typedef Shape<1, Tile::kH / Delta::kH, Tile::kW / ThreadShape::kW> Iterations;
/// Computes the thread offset in (H, W) based on thread ID
struct ThreadOffset {
/// Basic thread offset function computed from a thread shape
CUTLASS_HOST_DEVICE
Coord<4> operator()() const {
int tid = threadIdx.x;
int warp = (tid / kWarpSize);
int lane = (tid % kWarpSize);
static int const kWarpSpanContiguous = kWarpSize * Iterations::kW;
int warp_w = (warp % kWarpsContiguous);
int warp_h = (warp / kWarpsContiguous);
return make_Coord(0, warp_h, lane + kWarpSpanContiguous * warp_w, 0);
}
};
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Chooses 'best' shape to enable warp raking along contiguous dimension if possible.
template <typename Tile_, int Threads>
struct TileTraitsStandard {
/// Shape of tile
typedef Tile_ Tile;
/// Number of participating threads
static int const kThreads = Threads;
/// Hard-coded warp size
static int const kWarpSize = 32;
/// Number of participating warps
static int const kWarpCount = kThreads / kWarpSize;
/// By default, do not do scalar loads
static int const kAccessSize = 1;
// Static assertions
static_assert(!(ShapeCount<Tile>::kDhw % kThreads),
"Tiling undefined if elements not divisible by threads.");
/// Choose the stride-major contiguous tiling if the contiguous dimension is
/// smaller than the warp size. Otherwise, if it is divisible by the warp size,
/// choose the warp rake arrangement.
typedef typename platform::conditional <
Tile::kW<kWarpSize,
TileTraitsStrideMajor<Tile, Threads>,
typename platform::conditional<!(Tile::kW % kWarpSize),
TileTraitsWarpRake<Tile, Threads>,
TileTraitsContiguousMajor<Tile, Threads> >::type>::
type Traits;
/// Delta between accesses
typedef typename Traits::Delta Delta;
/// Delta between each thread's access
typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;
/// Number of accesses
typedef typename Traits::Iterations Iterations;
/// Thread offset functor
typedef typename Traits::ThreadOffset ThreadOffset;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,457 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cuComplex.h>
#include "cutlass/cutlass.h"
#include <iosfwd>
namespace cutlass {
namespace platform {
//////////////////////////////////////////////////////////////////////////////////////////////////
//
// Accessors for CUDA complex types
//
/// Returns the real part of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
CUTLASS_HOST_DEVICE
float const &real(cuFloatComplex const &z) { return z.x; }
/// Returns the real part of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
CUTLASS_HOST_DEVICE
float &real(cuFloatComplex &z) { return z.x; }
/// Returns the real part of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
CUTLASS_HOST_DEVICE
double const &real(cuDoubleComplex const &z) { return z.x; }
/// Returns the real part of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
CUTLASS_HOST_DEVICE
double &real(cuDoubleComplex &z) { return z.x; }
/// Returns the imaginary part of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
CUTLASS_HOST_DEVICE
float const &imag(cuFloatComplex const &z) { return z.y; }
/// Returns the imaginary part of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
CUTLASS_HOST_DEVICE
float &imag(cuFloatComplex &z) { return z.y; }
/// Returns the imaginary part of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
CUTLASS_HOST_DEVICE
double const &imag(cuDoubleComplex const &z) { return z.y; }
/// Returns the imaginary part of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
CUTLASS_HOST_DEVICE
double &imag(cuDoubleComplex &z) { return z.y; }
//////////////////////////////////////////////////////////////////////////////////////////////////
/// Class for representing and manipulating complex numbers with conversions from built-in CUDA
/// complex types.
template <typename T>
class complex {
public:
/// Type alias for scalar type
typedef T value_type;
private:
//
// Data members
//
/// Real part
T _real;
/// Imaginary part
T _imag;
public:
//
// Methods
//
/// Constructor
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
CUTLASS_HOST_DEVICE
complex(T r = T(0), T i = T(0)) : _real(r), _imag(i) {}
/// Conversion from cuFloatComplex
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
CUTLASS_HOST_DEVICE
complex(cuFloatComplex const &z) : _real(platform::real(z)), _imag(platform::imag(z)) {}
/// Conversion from cuDoubleComplex
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
CUTLASS_HOST_DEVICE
complex(cuDoubleComplex const &z) : _real(platform::real(z)), _imag(platform::imag(z)) {}
/// Accesses the real part of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
CUTLASS_HOST_DEVICE
T const &real() const { return _real; }
/// Accesses the real part of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
CUTLASS_HOST_DEVICE
T &real() { return _real; }
/// Accesses the imaginary part of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
CUTLASS_HOST_DEVICE
T const &imag() const { return _imag; }
/// Accesses the imaginary part of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
CUTLASS_HOST_DEVICE
T &imag() { return _imag; }
/// Converts to cuFloatComplex
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
CUTLASS_HOST_DEVICE
operator cuFloatComplex() const { return make_cuFloatComplex(real(), imag()); }
/// Converts to cuDoubleComplex
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
CUTLASS_HOST_DEVICE
operator cuDoubleComplex() const { return make_cuDoubleComplex(real(), imag()); }
};
//
// Accessors for complex template
//
/// Returns the real part of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE T const &real(complex<T> const &z) {
return z.real();
}
/// Returns the real part of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE T &real(complex<T> &z) {
return z.real();
}
/// Returns the imaginary part of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE T const &imag(complex<T> const &z) {
return z.imag();
}
/// Returns the imaginary part of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE T &imag(complex<T> &z) {
return z.imag();
}
//
// Output operators
//
template <typename T>
std::ostream &operator<<(std::ostream &out, complex<T> const &z) {
T _r = real(z);
T _i = imag(z);
return out << _r << "+i" << _i;
}
//
// Non-member operators defined for complex types
//
/// Equality operator
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE bool operator==(complex<T> const &lhs, complex<T> const &rhs) {
return real(lhs) == (rhs) && imag(lhs) == imag(rhs);
}
/// Inequality operator
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE bool operator!=(complex<T> const &lhs, complex<T> const &rhs) {
return !(lhs == rhs);
}
/// Addition
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> operator+(complex<T> const &lhs, complex<T> const &rhs) {
return complex<T>(real(lhs) + real(rhs), imag(lhs) + imag(rhs));
}
/// Subtraction
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> operator-(complex<T> const &lhs, complex<T> const &rhs) {
return complex<T>(real(lhs) - real(rhs), imag(lhs) - imag(rhs));
}
/// Multiplication
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> operator*(complex<T> const &lhs, complex<T> const &rhs) {
return complex<T>(real(lhs) * real(rhs) - imag(lhs) * imag(rhs),
real(lhs) * imag(rhs) + imag(lhs) * real(rhs));
}
/// Scalar Multiplication
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> operator*(complex<T> const &lhs, T const &s) {
return complex<T>(real(lhs) * s, imag(lhs) * s);
}
/// Scalar Multiplication
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> operator*(T const &s, complex<T> const &rhs) {
return complex<T>(s * real(rhs), s * imag(rhs));
}
/// Division
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> operator/(complex<T> const &lhs, complex<T> const &rhs) {
T d = (real(rhs) * (rhs) + imag(rhs) * imag(rhs));
return complex<T>((real(lhs) * (rhs) + imag(lhs) * imag(rhs)) / d,
(imag(lhs) * (rhs)-real(lhs) * imag(rhs)) / d);
}
/// Scalar Division
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> operator/(complex<T> const &lhs, T const &s) {
return complex<T>(real(lhs) / s, imag(lhs) / s);
}
/// Scalar divided by complex
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> operator/(T const &s, complex<T> const &rhs) {
T d = (real(rhs) * (rhs) + imag(rhs) * imag(rhs));
return complex<T>((s * (rhs)) / d, -(s * imag(rhs)) / d);
}
/// Addition
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> &operator+=(complex<T> &lhs, complex<T> const &rhs) {
lhs = (lhs + rhs);
return lhs;
}
/// Subtraction
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> &operator-=(complex<T> &lhs, complex<T> const &rhs) {
lhs = (lhs - rhs);
return lhs;
}
/// Multiplication
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> &operator*=(complex<T> &lhs, complex<T> const &rhs) {
lhs = (lhs * rhs);
return lhs;
}
/// Scalar multiplication
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> &operator*=(complex<T> &lhs, T s) {
lhs = (lhs * s);
return lhs;
}
/// Division
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> &operator/=(complex<T> &lhs, complex<T> const &rhs) {
lhs = (lhs / rhs);
return lhs;
}
//
// Non-member functions defined for complex numbers
//
/// Returns the magnitude of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE T abs(complex<T> const &z) {
return sqrt(norm(z));
}
/// Returns the magnitude of the complex number
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE T arg(complex<T> const &z) {
return atan2(imag(z), real(z));
}
/// Returns the squared magnitude
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE T norm(complex<T> const &z) {
return real(z) * real(z) + imag(z) * imag(z);
}
/// Returns the complex conjugate
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> conj(complex<T> const &z) {
return complex<T>(real(z), -imag(z));
}
/// Projects the complex number z onto the Riemann sphere
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> proj(complex<T> const &z) {
T d = real(z) * real(z) + imag(z) * imag(z) + T(1);
return complex<T>((T(2) * real(z)) / d, (T(2) * imag(z)) / d);
}
/// Returns a complex number with magnitude r and phase theta
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> polar(T const &r, T const &theta = T()) {
return complex<T>(r * cos(theta), r * sin(theta));
}
/// Computes the complex exponential of z.
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> exp(complex<T> const &z) {
return complex<T>(real(z) * cos(imag(z)), real(z) * sin(imag(z)));
}
/// Computes the complex exponential of z.
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> log(complex<T> const &z) {
return complex<T>(log(abs(z)), arg(z));
}
/// Computes the complex exponential of z.
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> log10(complex<T> const &z) {
return log(z) / T(log(T(10)));
}
/// Computes the square root of complex number z
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> sqrt(complex<T> const &z) {
return sqrt(T(2)) / T(2) *
complex<T>(sqrt(sqrt(norm(z)) + real(z)),
(imag(z) < 0 ? T(-1) : T(1)) * sqrt(sqrt(norm(z)) - real(z)));
}
/// Computes the cosine of complex z.
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> cos(complex<T> const &z) {
return (exp(z) + exp(-z)) / T(2);
}
/// Computes the sin of complex z.
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
// host-only type
template <typename T>
CUTLASS_HOST_DEVICE complex<T> sin(complex<T> const &z) {
return (exp(-z) - exp(z)) * complex<T>(T(0), T(1) / T(2));
}
//////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace platform
} // namespace cutlass

View File

@ -1,165 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
/**
* \file
* \brief Math utilities
*/
#include "cutlass/util/platform.h"
namespace cutlass {
/******************************************************************************
* Static math utilities
******************************************************************************/
/**
* Statically determine if N is a power-of-two
*/
template <int N>
struct is_pow2 : platform::integral_constant<bool, (N & (N - 1)) == 0> {};
/**
* Statically determine log2(N), rounded down
*/
template <int N, int CurrentVal = N, int Count = 0>
struct log2_down {
/// Static logarithm value
enum { value = log2_down<N, (CurrentVal >> 1), Count + 1>::value };
};
// Base case
template <int N, int Count>
struct log2_down<N, 1, Count> {
enum { value = Count };
};
/**
* Statically determine log2(N), rounded up
*/
template <int N, int CurrentVal = N, int Count = 0>
struct log2_up {
/// Static logarithm value
enum { value = log2_up<N, (CurrentVal >> 1), Count + 1>::value };
};
// Base case
template <int N, int Count>
struct log2_up<N, 1, Count> {
enum { value = ((1 << Count) < N) ? Count + 1 : Count };
};
/**
* Statically estimate sqrt(N) to the nearest power-of-two
*/
template <int N>
struct sqrt_est {
enum { value = 1 << (log2_up<N>::value / 2) };
};
/**
* For performing a constant-division with a compile-time assertion that the
* Divisor evenly-divides the Dividend.
*/
template <int Dividend, int Divisor>
struct divide_assert {
enum { value = Dividend / Divisor };
static_assert((Dividend % Divisor == 0), "Not an even multiple");
};
/******************************************************************************
* Rounding
******************************************************************************/
/**
* Round dividend up to the nearest multiple of divisor
*/
template <typename dividend_t, typename divisor_t>
CUTLASS_HOST_DEVICE dividend_t round_nearest(dividend_t dividend, divisor_t divisor) {
return ((dividend + divisor - 1) / divisor) * divisor;
}
/**
* Greatest common divisor
*/
template <typename value_t>
CUTLASS_HOST_DEVICE value_t gcd(value_t a, value_t b) {
for (;;) {
if (a == 0) return b;
b %= a;
if (b == 0) return a;
a %= b;
}
}
/**
* Least common multiple
*/
template <typename value_t>
CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b) {
value_t temp = gcd(a, b);
return temp ? (a / temp * b) : 0;
}
/**
* log2 computation, what's the
* difference between the below codes and
* log2_up/down codes?
*/
template <typename value_t>
CUTLASS_HOST_DEVICE value_t clz(value_t x) {
for (int i = 31; i >= 0; --i) {
if ((1 << i) & x) return 31 - i;
}
return 32;
}
template <typename value_t>
CUTLASS_HOST_DEVICE value_t find_log2(value_t x) {
int a = 31 - clz(x);
a += (x & (x - 1)) != 0; // Round up, add 1 if not a power of 2.
return a;
}
/******************************************************************************
* Min/Max
******************************************************************************/
template <int A, int B>
struct Min {
static int const kValue = (A < B) ? A : B;
};
template <int A, int B>
struct Max {
static int const kValue = (A > B) ? A : B;
};
} // namespace cutlass

View File

@ -1,122 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
/**
* \file
* \brief Debugging and logging functionality
*/
#include <stdio.h>
namespace cutlass {
/******************************************************************************
* Debug and logging macros
******************************************************************************/
/**
* Formats and prints the given message to stdout
*/
#if !defined(CUDA_LOG)
#if !defined(__CUDA_ARCH__)
#define CUDA_LOG(format, ...) printf(format, __VA_ARGS__)
#else
#define CUDA_LOG(format, ...) \
printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \
blockIdx.x, \
blockIdx.y, \
blockIdx.z, \
threadIdx.x, \
threadIdx.y, \
threadIdx.z, \
__VA_ARGS__);
#endif
#endif
/**
* Formats and prints the given message to stdout only if DEBUG is defined
*/
#if !defined(CUDA_LOG_DEBUG)
#ifdef DEBUG
#define CUDA_LOG_DEBUG(format, ...) CUDA_LOG(format, __VA_ARGS__)
#else
#define CUDA_LOG_DEBUG(format, ...)
#endif
#endif
/**
* \brief The corresponding error message is printed to \p stderr (or \p stdout in device code)
* along with the supplied source context.
*
* \return The CUDA error.
*/
__host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl(cudaError_t error,
const char* filename,
int line) {
(void)filename;
(void)line;
if (error) {
#if !defined(__CUDA_ARCH__)
fprintf(
stderr, "CUDA error %d [%s, %d]: %s\n", error, filename, line, cudaGetErrorString(error));
fflush(stderr);
#else
printf("CUDA error %d [%s, %d]\n", error, filename, line);
#endif
}
return error;
}
/**
* \brief Perror macro
*/
#ifndef CUDA_PERROR
#define CUDA_PERROR(e) cuda_perror_impl((cudaError_t)(e), __FILE__, __LINE__)
#endif
/**
* \brief Perror macro with exit
*/
#ifndef CUDA_PERROR_EXIT
#define CUDA_PERROR_EXIT(e) \
if (cuda_perror_impl((cudaError_t)(e), __FILE__, __LINE__)) { \
exit(1); \
}
#endif
/**
* \brief Perror macro only if DEBUG is defined
*/
#ifndef CUDA_PERROR_DEBUG
#ifdef DEBUG
#define CUDA_PERROR_DEBUG(e) CUDA_PERROR(e)
#else
#define CUDA_PERROR_DEBUG(e) (e)
#endif
#endif
} // namespace cutlass

View File

@ -1,47 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief
*/
#pragma once
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Definitions for 1-bit binary and 4-bit integer types
//
struct bin1_t {}; // 1-bit binary type
struct int4_t {}; // 4-bit signed integer type
struct uint4_t {}; // 4-bit unsigned integer type
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,124 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines a pair<>
*/
#pragma once
namespace cutlass {
namespace platform {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Constructs an iterator from a pair of iterators
template <typename T1, typename T2>
struct Pair {
typedef T1 first_type;
typedef T2 second_type;
//
// Data members
//
T1 first;
T1 second;
//
// Methods
//
/// Default constructor
CUTLASS_HOST_DEVICE
Pair() { }
/// Constructs a pair
CUTLASS_HOST_DEVICE
Pair(T1 const &first_, T2 const &second_): first(first_), second(second_) { }
};
/// Constructs a pair and deduces types
template <typename T1, typename T2>
Pair<T1, T2> make_Pair(T1 const &first, T2 const &second) {
return Pair<T1, T2>(first, second);
}
/// Equality
template <typename T1, typename T2>
CUTLASS_HOST_DEVICE
bool operator==(Pair<T1,T2> const &lhs, Pair<T1,T2> const &rhs) {
return (lhs.first == rhs.first) && (lhs.second == rhs.second);
}
/// Inequality
template <typename T1, typename T2>
CUTLASS_HOST_DEVICE
bool operator!=(Pair<T1,T2> const &lhs, Pair<T1,T2> const &rhs) {
return !(lhs == rhs);
}
/// Lexical comparison
template <typename T1, typename T2>
CUTLASS_HOST_DEVICE
bool operator<(Pair<T1,T2> const &lhs, Pair<T1,T2> const &rhs) {
if (lhs.first < rhs.first) {
return true;
}
else if (rhs.first < lhs.first) {
return false;
}
else if (rhs.second < rhs.second) {
return false;
}
return false;
}
/// Lexical comparison
template <typename T1, typename T2>
CUTLASS_HOST_DEVICE
bool operator<=(Pair<T1,T2> const &lhs, Pair<T1,T2> const &rhs) {
return !(rhs < lhs);
}
/// Lexical comparison
template <typename T1, typename T2>
CUTLASS_HOST_DEVICE
bool operator>(Pair<T1,T2> const &lhs, Pair<T1,T2> const &rhs) {
return (rhs < lhs);
}
/// Lexical comparison
template <typename T1, typename T2>
CUTLASS_HOST_DEVICE
bool operator>=(Pair<T1,T2> const &lhs, Pair<T1,T2> const &rhs) {
return !(lhs < rhs);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace platform
} // namespace cutlass

View File

@ -1,40 +0,0 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#ifndef CUTLASS_PERFORMANCE_TUNING_H
#define CUTLASS_PERFORMANCE_TUNING_H
// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler.
#if defined(__CUDA_ARCH__)
#if defined(_MSC_VER)
#define CUTLASS_PRAGMA_UNROLL __pragma("unroll")
#define CUTLASS_PRAGMA_NO_UNROLL __pragma("unroll 1")
#else
#define CUTLASS_PRAGMA_UNROLL _Pragma("unroll")
#define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1")
#endif
#else
#define CUTLASS_PRAGMA_UNROLL
#define CUTLASS_PRAGMA_NO_UNROLL
#endif
#define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL
#endif // CUTLASS_PERFORMANCE_TUNING_H

View File

@ -1,379 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines a 1D vector of elements held in the registers of each thread.
*/
#pragma once
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
#include <cuda_fp16.h>
#endif
#include "cutlass/util/numeric_types.h"
#include "cutlass/util/platform.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <size_t kAlignment_>
struct AlignedStruct {};
template <>
struct __align__(1) AlignedStruct<1>{};
template <>
struct __align__(2) AlignedStruct<2>{};
template <>
struct __align__(4) AlignedStruct<4>{};
template <>
struct __align__(8) AlignedStruct<8>{};
template <>
struct __align__(16) AlignedStruct<16>{};
template <>
struct __align__(32) AlignedStruct<32>{};
template <>
struct __align__(64) AlignedStruct<64>{};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_, int kLanes_>
union Vector {
/// The scalar type.
typedef Scalar_ Scalar;
/// The number of elements in the vector.
enum { kLanes = kLanes_ };
/// The size of the vector.
enum { kVectorSize = kLanes * (int)sizeof(Scalar) };
/// The number of registers needed to store the vector.
enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 };
// Make sure that the vector type makes sense.
static_assert(kVectorSize <= 16, "Vector type is too large");
/// The aligned storage to make sure we have good alignment.
AlignedStruct<kVectorSize> aligned_;
/// The associated array of scalars.
Scalar scalars[kLanes];
/// The data in registers.
uint32_t registers[kRegisters];
/// Accessor to the ith lane.
CUTLASS_HOST_DEVICE Scalar const& operator[](uint32_t i) const { return scalars[i]; }
/// Accessor to the ith lane.
CUTLASS_HOST_DEVICE Scalar& operator[](uint32_t i) { return scalars[i]; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
union Vector<half, 1> {
/// The scalar type.
typedef half Scalar;
/// The number of elements in the vector.
enum { kLanes = 1 };
/// The size of the vector.
enum { kVectorSize = kLanes * (int)sizeof(Scalar) };
/// The number of registers needed to store the vector.
enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 };
// Make sure that the vector type makes sense.
static_assert(kVectorSize <= 16, "Vector type is too large");
/// The aligned storage to make sure we have good alignment.
AlignedStruct<kVectorSize> aligned_;
/// The associated array of scalars.
uint16_t scalars[kLanes];
/// Accessor to the ith lane.
CUTLASS_HOST_DEVICE Scalar const& operator[](uint32_t i) const {
return reinterpret_cast<Scalar const&>(scalars[i]);
}
/// Accessor to the ith lane.
CUTLASS_HOST_DEVICE Scalar& operator[](uint32_t i) {
return reinterpret_cast<Scalar&>(scalars[i]);
}
};
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
template <int kLanes_>
union Vector<half, kLanes_> {
/// The scalar type.
typedef half Scalar;
/// The number of elements in the vector.
enum { kLanes = kLanes_ };
/// The size of the vector.
enum { kVectorSize = kLanes * (int)sizeof(Scalar) };
/// The number of registers needed to store the vector.
enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 };
// Make sure that the vector type makes sense.
static_assert(kVectorSize <= size_t(16), "Vector type is too large");
/// The aligned storage to make sure we have good alignment.
AlignedStruct<kVectorSize> aligned_;
/// The associated array of scalars.
uint16_t scalars[kLanes];
/// The data in registers.
uint32_t registers[kRegisters];
/// Accessor to the ith lane.
CUTLASS_HOST_DEVICE Scalar const& operator[](uint32_t i) const {
return reinterpret_cast<Scalar const&>(scalars[i]);
}
/// Accessor to the ith lane.
CUTLASS_HOST_DEVICE Scalar& operator[](uint32_t i) {
return reinterpret_cast<Scalar&>(scalars[i]);
}
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Vector definition for 1-bit binary datatype
template <int kLanes_>
union Vector<bin1_t, kLanes_> {
/// The scalar type.
typedef bin1_t Scalar;
/// The number of elements in the vector.
enum { kLanes = kLanes_ };
/// The size of the vector.
enum { kVectorSize = kLanes / 8 };
/// The number of registers needed to store the vector.
enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 };
static_assert((kLanes >= 8) && !(kLanes % 8),
"May only construct vectors of bin1_t that are multiples of 8 bits.");
/// The aligned storage to make sure we have good alignment.
AlignedStruct<kVectorSize> aligned_;
/// The data in registers.
uint32_t registers[kRegisters];
/// Default Constructor
CUTLASS_HOST_DEVICE
Vector() {}
/// Constructor to convert from uint32_t type
CUTLASS_HOST_DEVICE Vector(uint32_t value) { registers[0] = value; }
/// Accessor to the ith lane.
CUTLASS_HOST_DEVICE bool operator[](uint32_t i) const {
return ( (registers[i / 32] & (1 << (i % 32))) != 0 );
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Vector definition for 4-bit signed integer datatype
template <int kLanes_>
union Vector<int4_t, kLanes_> {
/// The scalar type.
typedef int4_t Scalar;
/// The number of elements in the vector.
enum { kLanes = kLanes_ };
/// The size of the vector.
enum { kVectorSize = kLanes / 2 };
/// The number of registers needed to store the vector.
enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 };
static_assert((kLanes >= 2) && !(kLanes % 2),
"May only construct vectors of int4_t that are multiples of 8 bits.");
/// The aligned storage to make sure we have good alignment.
AlignedStruct<kVectorSize> aligned_;
/// The data in registers.
uint32_t registers[kRegisters];
/// Default Constructor
CUTLASS_HOST_DEVICE
Vector() {}
/// Constructor to convert from uint32_t type
CUTLASS_HOST_DEVICE Vector(uint32_t value) { registers[0] = value; }
/// Accessor to the ith lane.
CUTLASS_HOST_DEVICE int operator[](uint32_t i) const {
return (registers[i / 8] >> (i % 8 * 4) & 0x0f)
- 16 * (registers[i / 8] >> (i % 8 * 4 + 3) & 0x01);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Vector definition for 4-bit unsigned integer datatype
template <int kLanes_>
union Vector<uint4_t, kLanes_> {
/// The scalar type.
typedef uint4_t Scalar;
/// The number of elements in the vector.
enum { kLanes = kLanes_ };
/// The size of the vector.
enum { kVectorSize = kLanes / 2 };
/// The number of registers needed to store the vector.
enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 };
static_assert((kLanes >= 2) && !(kLanes % 2),
"May only construct vectors of uint4_t that are multiples of 8 bits.");
/// The aligned storage to make sure we have good alignment.
AlignedStruct<kVectorSize> aligned_;
/// The data in registers.
uint32_t registers[kRegisters];
/// Default Constructor
CUTLASS_HOST_DEVICE
Vector() {}
/// Constructor to convert from uint32_t type
CUTLASS_HOST_DEVICE Vector(uint32_t value) { registers[0] = value; }
/// Accessor to the ith lane.
CUTLASS_HOST_DEVICE int operator[](uint32_t i) const {
return registers[i / 8] >> (i % 8 * 4) & 0x0f;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_>
CUTLASS_HOST_DEVICE void make_zero(Scalar_& x) {
x = Scalar_(0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Element_, int kLanes_ = 1>
struct Vectorize {
typedef Vector<Element_, kLanes_> Type;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int kLanes_>
struct Vectorize<Vector<bin1_t, 32>, kLanes_> {
typedef Vector<bin1_t, kLanes_ * 32> Type;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int kLanes_>
struct Vectorize<Vector<int4_t, 8>, kLanes_> {
typedef Vector<int4_t, kLanes_ * 8> Type;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int kLanes_>
struct Vectorize<Vector<uint4_t, 8>, kLanes_> {
typedef Vector<uint4_t, kLanes_ * 8> Type;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Scalar_, int kLanes_>
CUTLASS_HOST_DEVICE void make_zero(Vector<Scalar_, kLanes_>& vec) {
for (int i = 0; i < Vector<Scalar_, kLanes_>::kRegisters; ++i) {
vec.registers[i] = 0;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// cutlass::Extent similar to std::extent but applicable to CUTLASS types
//
/// Returns the extent of a scalar or vector
template <typename T>
struct Extent {
static size_t const kValue = 1;
};
/// Returns the number of lanes of a vector if need be
template <typename T, int Lanes>
struct Extent<Vector<T, Lanes> > {
static size_t const kValue = Lanes;
};
/// Returns the number of lanes of a vector if need be
template <typename T, int Lanes>
struct Extent<Vector<T, Lanes> const> {
static size_t const kValue = Lanes;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Traits describing properties of vectors and scalar-as-vectors
template <typename T>
struct VectorTraits {
/// Scalar type
typedef T Scalar;
/// Number of lanes of vector
static int const kLanes = 1;
/// True if the type is actually a cutlass::Vector, otherwise false
static bool const IsVector = false;
/// Type that is always a vector
typedef Vector<T, 1> Vector;
};
/// Partial specialization for actual cutlass::Vector
template <typename T, int Lanes>
struct VectorTraits<Vector<T, Lanes> > {
/// Scalar type
typedef T Scalar;
/// Number of lanes of vector
static int const kLanes = Lanes;
/// Type is actually a cutlass::Vector
static bool const IsVector = true;
/// Type that is always a Vector
typedef Vector<T, Lanes> Vector;
};
/// Partial specialization for actual cutlass::Vector
template <typename T, int Lanes>
struct VectorTraits<Vector<T, Lanes> const> {
/// Scalar type
typedef T Scalar;
/// Number of lanes of vector
static int const kLanes = Lanes;
/// Type is actually a cutlass::Vector
static bool const IsVector = true;
/// Type that is always a Vector
typedef Vector<T, Lanes> Vector;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,236 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Abstractions for loading and storing matrices using the CUDA WMMA API.
*/
#pragma once
#if defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700)
#define CUTLASS_USE_WMMA_API
#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 10) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750)
#define CUTLASS_USE_SUBBYTE_WMMA
#endif
#include "stdio.h"
#if __CUDACC_VER_MAJOR__ >= 10
#include <mma.h>
#else
#include <crt/mma.h>
#endif
#include "cutlass/fragment.h"
#include "cutlass/matrix_traits.h"
#include "cutlass/shape.h"
#include "cutlass/vector.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Statically maps cutlass::MatrixLayout => nvcuda::wmma layout tags
template <MatrixLayout::Kind kLayout_>
struct WmmaLayout {
typedef nvcuda::wmma::col_major Layout;
};
/// Statically maps cutlass::MatrixLayout => nvcuda::wmma layout tags
template <>
struct WmmaLayout<MatrixLayout::kRowMajor> {
typedef nvcuda::wmma::row_major Layout;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Statically maps cutlass types to nvcuda::wmma datatypes
template <typename Type_>
struct WmmaDataType{
typedef Type_ Type;
};
#ifdef CUTLASS_USE_SUBBYTE_WMMA
/// Statically maps cutlass::Vector<bin1_t, 32> to nvcuda::wmma::experimental::precision::b1
template<>
struct WmmaDataType<Vector<bin1_t, 32> > {
typedef nvcuda::wmma::experimental::precision::b1 Type;
};
/// Statically maps cutlass::Vector<int4_t, 8> to nvcuda::wmma::experimental::precision::s4
template<>
struct WmmaDataType<Vector<int4_t, 8> > {
typedef nvcuda::wmma::experimental::precision::s4 Type;
};
/// Statically maps cutlass::Vector<uint4_t, 8> to nvcuda::wmma::experimental::precision::u4
template<>
struct WmmaDataType<Vector<uint4_t, 8> > {
typedef nvcuda::wmma::experimental::precision::u4 Type;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Adapter to nvcuda::wmma fragment load and store operations
template <GemmOperand::Kind kOperand_,
MatrixLayout::Kind kLayout_,
typename Scalar_,
typename WmmaShape_>
struct WmmaMatrix {};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Adapter to nvcuda::wmma fragment accessors for A operand
template <MatrixLayout::Kind kLayout_, typename Scalar_, typename WmmaShape_>
struct WmmaMatrix<GemmOperand::kA, kLayout_, Scalar_, WmmaShape_>
: public nvcuda::wmma::fragment<
/// The nvcuda::wmma operand name.
nvcuda::wmma::matrix_a,
/// The dimensions.
WmmaShape_::kW,
WmmaShape_::kH,
WmmaShape_::kD,
/// The scalar.
typename WmmaDataType<Scalar_>::Type,
/// The layout.
typename WmmaLayout<kLayout_>::Layout> {
/// This type.
typedef WmmaMatrix<GemmOperand::kA, kLayout_, Scalar_, WmmaShape_> This_;
/// Fill-in the element.
CUTLASS_DEVICE This_& operator=(Scalar_ const& x) {
nvcuda::wmma::fill_fragment(*this, x);
return *this;
}
/// Load from memory.
CUTLASS_DEVICE void load(Scalar_ const* pointer, int const stride) {
nvcuda::wmma::load_matrix_sync(*this, pointer, stride);
}
/// Store to memory.
CUTLASS_DEVICE void store(Scalar_* pointer, int const stride) const {
nvcuda::wmma::store_matrix_sync(pointer, *this, stride);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Adapter to nvcuda::wmma fragment accessors for B operand
template <MatrixLayout::Kind kLayout_, typename Scalar_, typename WmmaShape_>
struct WmmaMatrix<GemmOperand::kB, kLayout_, Scalar_, WmmaShape_>
: public nvcuda::wmma::fragment<
/// The nvcuda::wmma operand name.
nvcuda::wmma::matrix_b,
/// The dimensions.
WmmaShape_::kW,
WmmaShape_::kH,
WmmaShape_::kD,
/// The scalar.
typename WmmaDataType<Scalar_>::Type,
/// The layout.
typename WmmaLayout<kLayout_>::Layout> {
/// This type.
typedef WmmaMatrix<GemmOperand::kB, kLayout_, Scalar_, WmmaShape_> This_;
/// Fill-in the element.
CUTLASS_DEVICE This_& operator=(Scalar_ const& x) {
nvcuda::wmma::fill_fragment(*this, x);
return *this;
}
/// Load from memory.
CUTLASS_DEVICE void load(Scalar_ const* pointer, int const stride) {
nvcuda::wmma::load_matrix_sync(*this, pointer, stride);
}
/// Store to memory.
CUTLASS_DEVICE void store(Scalar_* pointer, int const stride) const {
nvcuda::wmma::store_matrix_sync(pointer, *this, stride);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Adapter to nvcuda::wmma fragment accessors for C operand
template <MatrixLayout::Kind kLayout_, typename Scalar_, typename WmmaShape_>
struct WmmaMatrix<GemmOperand::kC, kLayout_, Scalar_, WmmaShape_>
: public nvcuda::wmma::fragment<
/// The nvcuda::wmma operand name.
nvcuda::wmma::accumulator,
/// The dimensions.
WmmaShape_::kW,
WmmaShape_::kH,
WmmaShape_::kD,
/// The scalar.
Scalar_> {
/// This type.
typedef WmmaMatrix<GemmOperand::kC, kLayout_, Scalar_, WmmaShape_> This_;
/// The layout.
static MatrixLayout::Kind const kLayout = kLayout_;
/// Fill-in the element.
CUTLASS_DEVICE This_& operator=(Scalar_ const& x) {
nvcuda::wmma::fill_fragment(*this, x);
return *this;
}
/// Load from memory.
CUTLASS_DEVICE void load(Scalar_ const* pointer, int const stride) {
bool const kIsRowMajor = kLayout == MatrixLayout::kRowMajor;
nvcuda::wmma::load_matrix_sync(
*this,
pointer,
stride,
kIsRowMajor ? nvcuda::wmma::mem_row_major : nvcuda::wmma::mem_col_major);
}
/// Store to memory.
CUTLASS_DEVICE void store(Scalar_* pointer, int const stride) const {
bool const kIsRowMajor = kLayout == MatrixLayout::kRowMajor;
nvcuda::wmma::store_matrix_sync(
pointer,
*this,
stride,
kIsRowMajor ? nvcuda::wmma::mem_row_major : nvcuda::wmma::mem_col_major);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// WmmaMatrix cannot be used in a Union and thus in cannot be used in our Vector implementation.
// The only use of WmmaMatrix in in combination with Vectorize has kLanes == 1. Due to this it is
// safe to keep the Vector->Scalar conversion for WmmaMatrix.
template <GemmOperand::Kind kOperand_,
MatrixLayout::Kind kLayout_,
typename Scalar_,
typename WmmaShape_>
struct Vectorize<WmmaMatrix<kOperand_, kLayout_, Scalar_, WmmaShape_>, 1> {
typedef WmmaMatrix<kOperand_, kLayout_, Scalar_, WmmaShape_> Type;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
}
#endif // defined CUTLASS_USE_WMMA_API

View File

@ -1,150 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Models a pair of fragments
*/
#pragma once
#include <assert.h>
#include "cutlass/cutlass.h"
#include "cutlass/shape.h"
#include "cutlass/util/cutlass_math.h"
#include "cutlass/vector.h"
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
/**
* @brief A template defining \ref fragment_concept
* @concept{fragment_concept}
*/
template <typename First_, typename Second_>
struct ZipFragment {
/// First fragment object
typedef First_ First;
/// Second fragment object
typedef Second_ Second;
/// This class.
typedef ZipFragment<First, Second> This_;
//
// Data members
//
/// First fragment object
First first;
/// Second fragment object
Second second;
//
// Methods
//
/// Default ctor
CUTLASS_DEVICE
ZipFragment() { }
/// Copy ctor
CUTLASS_DEVICE
ZipFragment(First const &_first, Second const &_second): first(_first), second(_second) { }
/// Clear a fragment.
CUTLASS_DEVICE void clear() {
first.clear();
second.clear();
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to construct a ZipFragment object
template <typename First, typename Second>
CUTLASS_HOST_DEVICE
ZipFragment<First, Second> make_ZipFragment(First const &first, Second const &second) {
return ZipFragment<First, Second>(first, second);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Zips two convert operations
template <typename First_, typename Second_>
struct ZipConvert {
/// First convert operator
typedef First_ First;
/// Second convert operator
typedef Second_ Second;
/// Defines the input zip fragment
typedef ZipFragment<typename First::InputFragment, typename Second::InputFragment> InputFragment;
/// Defines the output zip fragment
typedef ZipFragment<typename First::OutputFragment, typename Second::OutputFragment>
OutputFragment;
//
//
//
/// First transformer
First first;
/// Second transformer
Second second;
//
//
//
/// Ctor.
CUTLASS_DEVICE ZipConvert() {}
/// Ctor.
CUTLASS_DEVICE ZipConvert(First const &_first, Second const &_second): first(_first), second(_second) { }
/// Transform a fragment.
CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
first.transform(src.first, dst.first);
second.transform(src.second, dst.second);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to construct a ZipConvert object
template <typename First, typename Second>
CUTLASS_HOST_DEVICE
ZipConvert<First, Second> make_ZipConvert(First const &first, Second const &second) {
return ZipConvert<First, Second>(first, second);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,77 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines a structure containing a pair of TensorRef-like objects
*/
#pragma once
#include "cutlass/coord.h"
#include "cutlass/tensor_ref.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename First_, typename Second_>
struct ZipTensorRef {
/// First tensor ref
typedef First_ First;
/// Second tensor ref
typedef Second_ Second;
//
// Data members
//
/// First TensorRef
First first;
/// Second TensorRef
Second second;
//
// Methods
//
CUTLASS_HOST_DEVICE
ZipTensorRef() {}
CUTLASS_HOST_DEVICE
ZipTensorRef(First const& _first, Second const& _second) : first(_first), second(_second) {}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Constructs a ZipTensorRef
template <typename First, typename Second>
CUTLASS_HOST_DEVICE
ZipTensorRef<First, Second> make_ZipTensorRef(First const &first, Second const &second) {
return ZipTensorRef<First, Second>(first, second);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@ -1,291 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Constructs an iterator that owns two tile iterator instances
*/
#pragma once
#include "cutlass/coord.h"
#include "cutlass/zip_tensor_ref.h"
#include "cutlass/zip_fragment.h"
#include "cutlass/util/pair.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Constructs an iterator from a pair of iterators
template <typename First_, typename Second_>
class ZipTileIterator {
public:
/// First iterator type
typedef First_ First;
/// Second iterator type
typedef Second_ Second;
/// Params object
struct Params {
/// Parameters of first iterator
typename First::Params first;
/// Parameters of second iterator
typename Second::Params second;
/// Constructs a parameters object
CUTLASS_HOST_DEVICE
Params() {}
/// Constructs a parameters object
CUTLASS_HOST_DEVICE
Params(typename First::Params const &_first, typename Second::Params const &_second)
: first(_first), second(_second) {}
};
/// Fragment type
typedef ZipFragment<typename First::Fragment, typename Second::Fragment> Fragment;
/// Predicate vector
typedef typename First::PredicateVector PredicateVector;
/// Index type
typedef platform::Pair<typename First::Index, typename Second::Index> Index;
/// Long index type
typedef platform::Pair<typename First::LongIndex, typename Second::LongIndex> LongIndex;
/// Tensor reference
typedef ZipTensorRef<
typename First::TensorRef,
typename Second::TensorRef> TensorRef;
//
// Data members
//
/// First iterator
First first;
/// Second iterator
Second second;
//
// Methods
//
/// Default constructor
CUTLASS_DEVICE
ZipTileIterator() {}
/// Constructs a zip iterator from params
CUTLASS_DEVICE
ZipTileIterator(Params const &_params, Coord<3> const &threadblock_offset = make_Coord(0, 0, 0))
: first(_params.first, threadblock_offset), second(_params.second, threadblock_offset) {}
/// Constructs a zip iterator from iterator instances
CUTLASS_DEVICE
ZipTileIterator(First const &_first, Second const &_second) : first(_first), second(_second) {}
/// Constructs a zip iterator from iterator instances
CUTLASS_DEVICE
ZipTileIterator(TensorRef const &ref) : first(ref.first), second(ref.second) {}
/// Constructs a zip iterator from iterator instances
CUTLASS_DEVICE
ZipTileIterator(Params const &_params, TensorRef const &ref):
first(_params.first, ref.first), second(_params.second, ref.second) {}
//
// Predicate initialization
//
/// Initializes a predicate vector using a RegularTilePredicateFunctor
template <
/// Predicate iterator
typename PredicateIterator>
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
Coord<3> const &bounds,
Coord<3> const &block_offset = make_Coord(0,
0,
0)) {
first.initialize_predicates(predicate_it, bounds, block_offset);
}
/// Initializes a predicate vector using an arbitrary predicate functor
template <
/// Predicate iterator
typename PredicateIterator,
/// Functor computing predicates
typename PredicateFunctor>
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
PredicateFunctor const &functor,
Coord<3> const &block_offset) {
first.initialize_predicates(predicate_it, functor, block_offset);
}
//
// No predicates
//
/// Loads a fragment and increments without predicates
template <typename Fragment>
CUTLASS_DEVICE void load_post_increment(Fragment &fragment) {
first.load_post_increment(fragment.first);
second.load_post_increment(fragment.second);
}
/// Loads a fragment and increments without predicates
template <typename Fragment>
CUTLASS_DEVICE void load_post_increment(Fragment &fragment,
Coord<4> const &offset) {
first.load_post_increment(fragment.first, offset);
second.load_post_increment(fragment.second, offset);
}
/// Loads a fragment without predicates
template <typename Fragment>
CUTLASS_DEVICE void load(Fragment &fragment) const {
first.load(fragment.first);
second.load(fragment.second);
}
/// Loads a fragment without predicates
template <typename Fragment>
CUTLASS_DEVICE void load(Fragment &fragment,
Coord<4> const &offset) const {
first.load(fragment.first, offset);
second.load(fragment.second, offset);
}
/// Stores a fragment and increments without predicates
template <typename Fragment>
CUTLASS_DEVICE void store_post_increment(Fragment const &fragment) {
first.store_post_increment(fragment.first);
second.store_post_increment(fragment.second);
}
/// Stores a fragment and increments without predicates
template <typename Fragment>
CUTLASS_DEVICE void store_post_increment(Fragment const &fragment,
Coord<4> const &offset) {
first.store_post_increment(fragment.first, offset);
second.store_post_increment(fragment.second, offset);
}
/// Stores a fragment without predicates
template <typename Fragment>
CUTLASS_DEVICE void store(Fragment const &fragment) const {
first.store(fragment.first);
second.store(fragment.second);
}
/// Stores a fragment without predicates
template <typename Fragment>
CUTLASS_DEVICE void store(Fragment const &fragment,
Coord<4> const &offset) const {
first.store(fragment.first, offset);
second.store(fragment.second, offset);
}
//
// With predication
//
/// Loads a fragment and increments, using predicates
template <typename Fragment, typename PredicateIterator>
CUTLASS_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
first.load_post_increment(fragment.first, pred_it);
second.load_post_increment(fragment.second, pred_it);
}
/// Loads a fragment with predicates
template <typename Fragment, typename PredicateIterator>
CUTLASS_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const {
first.load(fragment.first, pred_it);
second.load(fragment.second, pred_it);
}
/// Loads a fragment and increments, using predicates
template <typename Fragment, typename PredicateIterator>
CUTLASS_DEVICE void store_post_increment(Fragment const &fragment, PredicateIterator pred_it) {
first.store_post_increment(fragment.first, pred_it);
second.store_post_increment(fragment.second, pred_it);
}
/// Loads a fragment with predicates
template <typename Fragment, typename PredicateIterator>
CUTLASS_DEVICE void store(Fragment const &fragment, PredicateIterator pred_it) const {
first.store(fragment.first, pred_it);
second.store(fragment.second, pred_it);
}
//
// Advances the iterators
//
/// Increments store iterator to next tile
CUTLASS_DEVICE ZipTileIterator &increment(int count = 1) {
first.increment(count);
second.increment(count);
return *this;
}
/// Increments to next tile
CUTLASS_DEVICE ZipTileIterator &operator++() { return increment(); }
CUTLASS_DEVICE ZipTileIterator &operator+=(int count) { return increment(count); }
/// Adds a vector offset to the underlying iterators
CUTLASS_DEVICE ZipTileIterator &operator+=(Coord<3> const &offset) {
first += offset;
second += offset;
return *this;
}
/// Increments store iterator to previous tile
CUTLASS_DEVICE ZipTileIterator &decrement(int count = 1) {
first.decrement(count);
second.decrement(count);
return *this;
}
/// Increments to subsequent tile
CUTLASS_DEVICE ZipTileIterator &operator--() { return decrement(); }
/// Decrements to previous tile
CUTLASS_DEVICE ZipTileIterator &operator-=(int count) { return decrement(count); }
/// Adds an offset to both iterators
CUTLASS_DEVICE void add_pointer_offset(LongIndex offset) {
first.add_pointer_offset(offset.first);
second.add_pointer_offset(offset.second);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namspace cutlass

View File

@ -1 +0,0 @@
theme: jekyll-theme-minimal

View File

@ -0,0 +1,145 @@
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml">
<head>
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
<meta name="generator" content="Doxygen 1.8.11"/>
<title>CUTLASS: aligned_buffer.h File Reference</title>
<link href="tabs.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript" src="jquery.js"></script>
<script type="text/javascript" src="dynsections.js"></script>
<link href="search/search.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript" src="search/searchdata.js"></script>
<script type="text/javascript" src="search/search.js"></script>
<script type="text/javascript">
$(document).ready(function() { init_search(); });
</script>
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
extensions: ["tex2jax.js"],
jax: ["input/TeX","output/HTML-CSS"],
});
</script><script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
<link href="doxygen.css" rel="stylesheet" type="text/css" />
</head>
<body>
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
<div id="titlearea">
<table cellspacing="0" cellpadding="0">
<tbody>
<tr style="height: 56px;">
<td id="projectlogo"><img alt="Logo" src="cutlass-logo-small.png"/></td>
<td id="projectalign" style="padding-left: 0.5em;">
<div id="projectname">CUTLASS
</div>
<div id="projectbrief">CUDA Templates for Linear Algebra Subroutines and Solvers</div>
</td>
</tr>
</tbody>
</table>
</div>
<!-- end header part -->
<!-- Generated by Doxygen 1.8.11 -->
<script type="text/javascript">
var searchBox = new SearchBox("searchBox", "search",false,'Search');
</script>
<div id="navrow1" class="tabs">
<ul class="tablist">
<li><a href="index.html"><span>Main&#160;Page</span></a></li>
<li><a href="modules.html"><span>Modules</span></a></li>
<li><a href="namespaces.html"><span>Namespaces</span></a></li>
<li><a href="annotated.html"><span>Classes</span></a></li>
<li class="current"><a href="files.html"><span>Files</span></a></li>
<li>
<div id="MSearchBox" class="MSearchBoxInactive">
<span class="left">
<img id="MSearchSelect" src="search/mag_sel.png"
onmouseover="return searchBox.OnSearchSelectShow()"
onmouseout="return searchBox.OnSearchSelectHide()"
alt=""/>
<input type="text" id="MSearchField" value="Search" accesskey="S"
onfocus="searchBox.OnSearchFieldFocus(true)"
onblur="searchBox.OnSearchFieldFocus(false)"
onkeyup="searchBox.OnSearchFieldChange(event)"/>
</span><span class="right">
<a id="MSearchClose" href="javascript:searchBox.CloseResultsWindow()"><img id="MSearchCloseImg" border="0" src="search/close.png" alt=""/></a>
</span>
</div>
</li>
</ul>
</div>
<div id="navrow2" class="tabs2">
<ul class="tablist">
<li><a href="files.html"><span>File&#160;List</span></a></li>
<li><a href="globals.html"><span>File&#160;Members</span></a></li>
</ul>
</div>
<!-- window showing the filter options -->
<div id="MSearchSelectWindow"
onmouseover="return searchBox.OnSearchSelectShow()"
onmouseout="return searchBox.OnSearchSelectHide()"
onkeydown="return searchBox.OnSearchSelectKey(event)">
</div>
<!-- iframe showing the search results (closed by default) -->
<div id="MSearchResultsWindow">
<iframe src="javascript:void(0)" frameborder="0"
name="MSearchResults" id="MSearchResults">
</iframe>
</div>
<div id="nav-path" class="navpath">
<ul>
<li class="navelem"><a class="el" href="dir_d44c64559bbebec7f509842c48db8b23.html">include</a></li><li class="navelem"><a class="el" href="dir_6baf2bb612a2f0daa69af3101ede80a1.html">cutlass</a></li> </ul>
</div>
</div><!-- top -->
<div class="header">
<div class="summary">
<a href="#nested-classes">Classes</a> &#124;
<a href="#namespaces">Namespaces</a> </div>
<div class="headertitle">
<div class="title">aligned_buffer.h File Reference</div> </div>
</div><!--header-->
<div class="contents">
<p>AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared memory.
<a href="#details">More...</a></p>
<div class="textblock"><code>#include &quot;<a class="el" href="cutlass_8h_source.html">cutlass/cutlass.h</a>&quot;</code><br />
<code>#include &quot;<a class="el" href="array_8h_source.html">cutlass/array.h</a>&quot;</code><br />
</div><div class="textblock"><div class="dynheader">
Include dependency graph for aligned_buffer.h:</div>
<div class="dyncontent">
<div class="center"><img src="aligned__buffer_8h__incl.png" border="0" usemap="#aligned__buffer_8h" alt=""/></div>
<map name="aligned__buffer_8h" id="aligned__buffer_8h">
</map>
</div>
</div><div class="textblock"><div class="dynheader">
This graph shows which files directly or indirectly include this file:</div>
<div class="dyncontent">
<div class="center"><img src="aligned__buffer_8h__dep__incl.png" border="0" usemap="#aligned__buffer_8hdep" alt=""/></div>
<map name="aligned__buffer_8hdep" id="aligned__buffer_8hdep">
</map>
</div>
</div>
<p><a href="aligned__buffer_8h_source.html">Go to the source code of this file.</a></p>
<table class="memberdecls">
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="nested-classes"></a>
Classes</h2></td></tr>
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1AlignedBuffer.html">cutlass::AlignedBuffer&lt; T, N, Align &gt;</a></td></tr>
<tr class="memdesc:"><td class="mdescLeft">&#160;</td><td class="mdescRight">Modifies semantics of cutlass::Array&lt;&gt; to provide guaranteed alignment. <a href="structcutlass_1_1AlignedBuffer.html#details">More...</a><br /></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
</table><table class="memberdecls">
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="namespaces"></a>
Namespaces</h2></td></tr>
<tr class="memitem:namespacecutlass"><td class="memItemLeft" align="right" valign="top"> &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass.html">cutlass</a></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
</table>
</div><!-- contents -->
<!-- start footer part -->
<hr class="footer"/><address class="footer"><small>
Generated by &#160;<a href="http://www.doxygen.org/index.html">
<img class="footer" src="doxygen.png" alt="doxygen"/>
</a> 1.8.11
</small></address>
</body>
</html>

View File

@ -0,0 +1 @@
6cbc6b81ede44b5f08afd4f4519d56d1

View File

@ -0,0 +1 @@
b26c62930ff7668b89f2ee6624e0be3a

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

156
docs/arch_2mma_8h.html Normal file
View File

@ -0,0 +1,156 @@
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml">
<head>
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
<meta name="generator" content="Doxygen 1.8.11"/>
<title>CUTLASS: mma.h File Reference</title>
<link href="tabs.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript" src="jquery.js"></script>
<script type="text/javascript" src="dynsections.js"></script>
<link href="search/search.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript" src="search/searchdata.js"></script>
<script type="text/javascript" src="search/search.js"></script>
<script type="text/javascript">
$(document).ready(function() { init_search(); });
</script>
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
extensions: ["tex2jax.js"],
jax: ["input/TeX","output/HTML-CSS"],
});
</script><script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
<link href="doxygen.css" rel="stylesheet" type="text/css" />
</head>
<body>
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
<div id="titlearea">
<table cellspacing="0" cellpadding="0">
<tbody>
<tr style="height: 56px;">
<td id="projectlogo"><img alt="Logo" src="cutlass-logo-small.png"/></td>
<td id="projectalign" style="padding-left: 0.5em;">
<div id="projectname">CUTLASS
</div>
<div id="projectbrief">CUDA Templates for Linear Algebra Subroutines and Solvers</div>
</td>
</tr>
</tbody>
</table>
</div>
<!-- end header part -->
<!-- Generated by Doxygen 1.8.11 -->
<script type="text/javascript">
var searchBox = new SearchBox("searchBox", "search",false,'Search');
</script>
<div id="navrow1" class="tabs">
<ul class="tablist">
<li><a href="index.html"><span>Main&#160;Page</span></a></li>
<li><a href="modules.html"><span>Modules</span></a></li>
<li><a href="namespaces.html"><span>Namespaces</span></a></li>
<li><a href="annotated.html"><span>Classes</span></a></li>
<li class="current"><a href="files.html"><span>Files</span></a></li>
<li>
<div id="MSearchBox" class="MSearchBoxInactive">
<span class="left">
<img id="MSearchSelect" src="search/mag_sel.png"
onmouseover="return searchBox.OnSearchSelectShow()"
onmouseout="return searchBox.OnSearchSelectHide()"
alt=""/>
<input type="text" id="MSearchField" value="Search" accesskey="S"
onfocus="searchBox.OnSearchFieldFocus(true)"
onblur="searchBox.OnSearchFieldFocus(false)"
onkeyup="searchBox.OnSearchFieldChange(event)"/>
</span><span class="right">
<a id="MSearchClose" href="javascript:searchBox.CloseResultsWindow()"><img id="MSearchCloseImg" border="0" src="search/close.png" alt=""/></a>
</span>
</div>
</li>
</ul>
</div>
<div id="navrow2" class="tabs2">
<ul class="tablist">
<li><a href="files.html"><span>File&#160;List</span></a></li>
<li><a href="globals.html"><span>File&#160;Members</span></a></li>
</ul>
</div>
<!-- window showing the filter options -->
<div id="MSearchSelectWindow"
onmouseover="return searchBox.OnSearchSelectShow()"
onmouseout="return searchBox.OnSearchSelectHide()"
onkeydown="return searchBox.OnSearchSelectKey(event)">
</div>
<!-- iframe showing the search results (closed by default) -->
<div id="MSearchResultsWindow">
<iframe src="javascript:void(0)" frameborder="0"
name="MSearchResults" id="MSearchResults">
</iframe>
</div>
<div id="nav-path" class="navpath">
<ul>
<li class="navelem"><a class="el" href="dir_d44c64559bbebec7f509842c48db8b23.html">include</a></li><li class="navelem"><a class="el" href="dir_6baf2bb612a2f0daa69af3101ede80a1.html">cutlass</a></li><li class="navelem"><a class="el" href="dir_048c1df36ab9c2efbb0733edba6291c9.html">arch</a></li> </ul>
</div>
</div><!-- top -->
<div class="header">
<div class="summary">
<a href="#nested-classes">Classes</a> &#124;
<a href="#namespaces">Namespaces</a> </div>
<div class="headertitle">
<div class="title">arch/mma.h File Reference</div> </div>
</div><!--header-->
<div class="contents">
<p>Templates exposing architecture support for multiply-add operations.
<a href="#details">More...</a></p>
<div class="textblock"><code>#include &quot;<a class="el" href="array_8h_source.html">cutlass/array.h</a>&quot;</code><br />
<code>#include &quot;<a class="el" href="numeric__types_8h_source.html">cutlass/numeric_types.h</a>&quot;</code><br />
<code>#include &quot;<a class="el" href="include_2cutlass_2gemm_2gemm_8h_source.html">cutlass/gemm/gemm.h</a>&quot;</code><br />
<code>#include &quot;<a class="el" href="arch_2mma__sm50_8h_source.html">cutlass/arch/mma_sm50.h</a>&quot;</code><br />
<code>#include &quot;<a class="el" href="arch_2mma__sm60_8h_source.html">cutlass/arch/mma_sm60.h</a>&quot;</code><br />
<code>#include &quot;<a class="el" href="arch_2mma__sm61_8h_source.html">cutlass/arch/mma_sm61.h</a>&quot;</code><br />
<code>#include &quot;<a class="el" href="mma__sm70_8h_source.html">cutlass/arch/mma_sm70.h</a>&quot;</code><br />
<code>#include &quot;<a class="el" href="mma__sm75_8h_source.html">cutlass/arch/mma_sm75.h</a>&quot;</code><br />
</div><div class="textblock"><div class="dynheader">
Include dependency graph for arch/mma.h:</div>
<div class="dyncontent">
<div class="center"><img src="arch_2mma_8h__incl.png" border="0" usemap="#mma_8h" alt=""/></div>
<map name="mma_8h" id="mma_8h">
</map>
</div>
</div><div class="textblock"><div class="dynheader">
This graph shows which files directly or indirectly include this file:</div>
<div class="dyncontent">
<div class="center"><img src="arch_2mma_8h__dep__incl.png" border="0" usemap="#mma_8hdep" alt=""/></div>
<map name="mma_8hdep" id="mma_8hdep">
</map>
</div>
</div>
<p><a href="arch_2mma_8h_source.html">Go to the source code of this file.</a></p>
<table class="memberdecls">
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="nested-classes"></a>
Classes</h2></td></tr>
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma.html">cutlass::arch::Mma&lt; Shape_, kThreads_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator &gt;</a></td></tr>
<tr class="memdesc:"><td class="mdescLeft">&#160;</td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma.html#details">More...</a><br /></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01ElementAb6e65b2cf5ede7f41cb070a767158dee.html">cutlass::arch::Mma&lt; gemm::GemmShape&lt; 1, 1, 1 &gt;, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator &gt;</a></td></tr>
<tr class="memdesc:"><td class="mdescLeft">&#160;</td><td class="mdescRight">Matrix multiply-add operation - specialized for 1x1x1x1 matrix multiply operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01ElementAb6e65b2cf5ede7f41cb070a767158dee.html#details">More...</a><br /></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
</table><table class="memberdecls">
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="namespaces"></a>
Namespaces</h2></td></tr>
<tr class="memitem:namespacecutlass"><td class="memItemLeft" align="right" valign="top"> &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass.html">cutlass</a></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:namespacecutlass_1_1arch"><td class="memItemLeft" align="right" valign="top"> &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass_1_1arch.html">cutlass::arch</a></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
</table>
</div><!-- contents -->
<!-- start footer part -->
<hr class="footer"/><address class="footer"><small>
Generated by &#160;<a href="http://www.doxygen.org/index.html">
<img class="footer" src="doxygen.png" alt="doxygen"/>
</a> 1.8.11
</small></address>
</body>
</html>

View File

@ -0,0 +1 @@
7d16b59e6ba0442b8a275a213d5da3a6

View File

@ -0,0 +1 @@
d1fff3f9d55a262110aa6a456caa91e0

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,176 @@
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml">
<head>
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
<meta name="generator" content="Doxygen 1.8.11"/>
<title>CUTLASS: mma_sm50.h File Reference</title>
<link href="tabs.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript" src="jquery.js"></script>
<script type="text/javascript" src="dynsections.js"></script>
<link href="search/search.css" rel="stylesheet" type="text/css"/>
<script type="text/javascript" src="search/searchdata.js"></script>
<script type="text/javascript" src="search/search.js"></script>
<script type="text/javascript">
$(document).ready(function() { init_search(); });
</script>
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
extensions: ["tex2jax.js"],
jax: ["input/TeX","output/HTML-CSS"],
});
</script><script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
<link href="doxygen.css" rel="stylesheet" type="text/css" />
</head>
<body>
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
<div id="titlearea">
<table cellspacing="0" cellpadding="0">
<tbody>
<tr style="height: 56px;">
<td id="projectlogo"><img alt="Logo" src="cutlass-logo-small.png"/></td>
<td id="projectalign" style="padding-left: 0.5em;">
<div id="projectname">CUTLASS
</div>
<div id="projectbrief">CUDA Templates for Linear Algebra Subroutines and Solvers</div>
</td>
</tr>
</tbody>
</table>
</div>
<!-- end header part -->
<!-- Generated by Doxygen 1.8.11 -->
<script type="text/javascript">
var searchBox = new SearchBox("searchBox", "search",false,'Search');
</script>
<div id="navrow1" class="tabs">
<ul class="tablist">
<li><a href="index.html"><span>Main&#160;Page</span></a></li>
<li><a href="modules.html"><span>Modules</span></a></li>
<li><a href="namespaces.html"><span>Namespaces</span></a></li>
<li><a href="annotated.html"><span>Classes</span></a></li>
<li class="current"><a href="files.html"><span>Files</span></a></li>
<li>
<div id="MSearchBox" class="MSearchBoxInactive">
<span class="left">
<img id="MSearchSelect" src="search/mag_sel.png"
onmouseover="return searchBox.OnSearchSelectShow()"
onmouseout="return searchBox.OnSearchSelectHide()"
alt=""/>
<input type="text" id="MSearchField" value="Search" accesskey="S"
onfocus="searchBox.OnSearchFieldFocus(true)"
onblur="searchBox.OnSearchFieldFocus(false)"
onkeyup="searchBox.OnSearchFieldChange(event)"/>
</span><span class="right">
<a id="MSearchClose" href="javascript:searchBox.CloseResultsWindow()"><img id="MSearchCloseImg" border="0" src="search/close.png" alt=""/></a>
</span>
</div>
</li>
</ul>
</div>
<div id="navrow2" class="tabs2">
<ul class="tablist">
<li><a href="files.html"><span>File&#160;List</span></a></li>
<li><a href="globals.html"><span>File&#160;Members</span></a></li>
</ul>
</div>
<!-- window showing the filter options -->
<div id="MSearchSelectWindow"
onmouseover="return searchBox.OnSearchSelectShow()"
onmouseout="return searchBox.OnSearchSelectHide()"
onkeydown="return searchBox.OnSearchSelectKey(event)">
</div>
<!-- iframe showing the search results (closed by default) -->
<div id="MSearchResultsWindow">
<iframe src="javascript:void(0)" frameborder="0"
name="MSearchResults" id="MSearchResults">
</iframe>
</div>
<div id="nav-path" class="navpath">
<ul>
<li class="navelem"><a class="el" href="dir_d44c64559bbebec7f509842c48db8b23.html">include</a></li><li class="navelem"><a class="el" href="dir_6baf2bb612a2f0daa69af3101ede80a1.html">cutlass</a></li><li class="navelem"><a class="el" href="dir_048c1df36ab9c2efbb0733edba6291c9.html">arch</a></li> </ul>
</div>
</div><!-- top -->
<div class="header">
<div class="summary">
<a href="#nested-classes">Classes</a> &#124;
<a href="#namespaces">Namespaces</a> </div>
<div class="headertitle">
<div class="title">arch/mma_sm50.h File Reference</div> </div>
</div><!--header-->
<div class="contents">
<p>Matrix multiply.
<a href="#details">More...</a></p>
<div class="textblock"><code>#include &quot;<a class="el" href="arch_2mma_8h_source.html">cutlass/arch/mma.h</a>&quot;</code><br />
<code>#include &quot;<a class="el" href="complex_8h_source.html">cutlass/complex.h</a>&quot;</code><br />
<code>#include &quot;<a class="el" href="layout_2matrix_8h_source.html">cutlass/layout/matrix.h</a>&quot;</code><br />
<code>#include &quot;<a class="el" href="include_2cutlass_2gemm_2gemm_8h_source.html">cutlass/gemm/gemm.h</a>&quot;</code><br />
</div><div class="textblock"><div class="dynheader">
Include dependency graph for arch/mma_sm50.h:</div>
<div class="dyncontent">
<div class="center"><img src="arch_2mma__sm50_8h__incl.png" border="0" usemap="#mma__sm50_8h" alt=""/></div>
<map name="mma__sm50_8h" id="mma__sm50_8h">
</map>
</div>
</div><div class="textblock"><div class="dynheader">
This graph shows which files directly or indirectly include this file:</div>
<div class="dyncontent">
<div class="center"><img src="arch_2mma__sm50_8h__dep__incl.png" border="0" usemap="#mma__sm50_8hdep" alt=""/></div>
<map name="mma__sm50_8hdep" id="mma__sm50_8hdep">
</map>
</div>
</div>
<p><a href="arch_2mma__sm50_8h_source.html">Go to the source code of this file.</a></p>
<table class="memberdecls">
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="nested-classes"></a>
Classes</h2></td></tr>
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01float_004bb3fd76ca2af7b3210676fa9644d95b.html">cutlass::arch::Mma&lt; gemm::GemmShape&lt; 1, 1, 1 &gt;, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd &gt;</a></td></tr>
<tr class="memdesc:"><td class="mdescLeft">&#160;</td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01float_004bb3fd76ca2af7b3210676fa9644d95b.html#details">More...</a><br /></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01double_0aa57e6a2e6b5da37d10688bf99419a23.html">cutlass::arch::Mma&lt; gemm::GemmShape&lt; 1, 1, 1 &gt;, 1, double, LayoutA, double, LayoutB, double, LayoutC, OpMultiplyAdd &gt;</a></td></tr>
<tr class="memdesc:"><td class="mdescLeft">&#160;</td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01double_0aa57e6a2e6b5da37d10688bf99419a23.html#details">More...</a><br /></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01int_00_00b2dff9ce8caad9aff5bc6a355539161.html">cutlass::arch::Mma&lt; gemm::GemmShape&lt; 1, 1, 1 &gt;, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd &gt;</a></td></tr>
<tr class="memdesc:"><td class="mdescLeft">&#160;</td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01int_00_00b2dff9ce8caad9aff5bc6a355539161.html#details">More...</a><br /></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01complex_76f9d24016e1b4167b16f4d7628c9546.html">cutlass::arch::Mma&lt; gemm::GemmShape&lt; 1, 1, 1 &gt;, 1, complex&lt; float &gt;, LayoutA, complex&lt; float &gt;, LayoutB, complex&lt; float &gt;, LayoutC, OpMultiplyAdd &gt;</a></td></tr>
<tr class="memdesc:"><td class="mdescLeft">&#160;</td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01complex_76f9d24016e1b4167b16f4d7628c9546.html#details">More...</a><br /></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01complex_f1c9d2ee842455cd0c5b71d56108d468.html">cutlass::arch::Mma&lt; gemm::GemmShape&lt; 1, 1, 1 &gt;, 1, complex&lt; float &gt;, LayoutA, float, LayoutB, complex&lt; float &gt;, LayoutC, OpMultiplyAdd &gt;</a></td></tr>
<tr class="memdesc:"><td class="mdescLeft">&#160;</td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01complex_f1c9d2ee842455cd0c5b71d56108d468.html#details">More...</a><br /></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01float_00e3e12e263df6506b8cf06c3f4d478b8e.html">cutlass::arch::Mma&lt; gemm::GemmShape&lt; 1, 1, 1 &gt;, 1, float, LayoutA, complex&lt; float &gt;, LayoutB, complex&lt; float &gt;, LayoutC, OpMultiplyAdd &gt;</a></td></tr>
<tr class="memdesc:"><td class="mdescLeft">&#160;</td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01float_00e3e12e263df6506b8cf06c3f4d478b8e.html#details">More...</a><br /></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01complex_30fa42e1ad201df010637cd22fc070a1.html">cutlass::arch::Mma&lt; gemm::GemmShape&lt; 1, 1, 1 &gt;, 1, complex&lt; double &gt;, LayoutA, complex&lt; double &gt;, LayoutB, complex&lt; double &gt;, LayoutC, OpMultiplyAdd &gt;</a></td></tr>
<tr class="memdesc:"><td class="mdescLeft">&#160;</td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01complex_30fa42e1ad201df010637cd22fc070a1.html#details">More...</a><br /></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01complex_48b3a43bc03fff93a111ac01abe7e40d.html">cutlass::arch::Mma&lt; gemm::GemmShape&lt; 1, 1, 1 &gt;, 1, complex&lt; double &gt;, LayoutA, double, LayoutB, complex&lt; double &gt;, LayoutC, OpMultiplyAdd &gt;</a></td></tr>
<tr class="memdesc:"><td class="mdescLeft">&#160;</td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01complex_48b3a43bc03fff93a111ac01abe7e40d.html#details">More...</a><br /></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01double_070b94670e040ed5855e5b42d5ca8a443.html">cutlass::arch::Mma&lt; gemm::GemmShape&lt; 1, 1, 1 &gt;, 1, double, LayoutA, complex&lt; double &gt;, LayoutB, complex&lt; double &gt;, LayoutC, OpMultiplyAdd &gt;</a></td></tr>
<tr class="memdesc:"><td class="mdescLeft">&#160;</td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01double_070b94670e040ed5855e5b42d5ca8a443.html#details">More...</a><br /></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01half__t_4f30ee91f7bb3844ff7579c68d078818.html">cutlass::arch::Mma&lt; gemm::GemmShape&lt; 1, 1, 1 &gt;, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd &gt;</a></td></tr>
<tr class="memdesc:"><td class="mdescLeft">&#160;</td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01half__t_4f30ee91f7bb3844ff7579c68d078818.html#details">More...</a><br /></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
</table><table class="memberdecls">
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="namespaces"></a>
Namespaces</h2></td></tr>
<tr class="memitem:namespacecutlass"><td class="memItemLeft" align="right" valign="top"> &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass.html">cutlass</a></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:namespacecutlass_1_1arch"><td class="memItemLeft" align="right" valign="top"> &#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass_1_1arch.html">cutlass::arch</a></td></tr>
<tr class="separator:"><td class="memSeparator" colspan="2">&#160;</td></tr>
</table>
</div><!-- contents -->
<!-- start footer part -->
<hr class="footer"/><address class="footer"><small>
Generated by &#160;<a href="http://www.doxygen.org/index.html">
<img class="footer" src="doxygen.png" alt="doxygen"/>
</a> 1.8.11
</small></address>
</body>
</html>

View File

@ -0,0 +1 @@
988e6466c703c4e63c9a889b8c3c54b5

View File

@ -0,0 +1 @@
03f1613fdffbd6e7575de0d2967d08bf

File diff suppressed because one or more lines are too long

Some files were not shown because too many files have changed in this diff Show More