Compare commits
36 Commits
| SHA1 |
|---|
| 1ab1027954 |
| 86931fef85 |
| e33d90b361 |
| 96dab34ad9 |
| 7c0cd26d13 |
| 45ecbc885b |
| 8aca98f9a7 |
| f4d9c8f755 |
| fb335f6a5f |
| b5cab177a9 |
| eb41735933 |
| fb8b3a98b7 |
| d9d357877f |
| e18292db46 |
| fe3438a3c1 |
| 877bdcace6 |
| 19a9d64e3c |
| 80e6f7c860 |
| 822b0952cd |
| ed2ed4d667 |
| 4db423c40f |
| b2bc0d3b79 |
| 74df0331f2 |
| 2332df492e |
| cfe4b933ef |
| 6877595a5e |
| 69e3709da4 |
| d419094c28 |
| 1a7ac522f8 |
| bf6eec53eb |
| 206e38dac5 |
| d85f6a1cec |
| 0826572c4c |
| 77d1e0ca81 |
| d7137f9c0a |
| 461f417b9d |
3
.gitmodules
vendored
@@ -1,3 +0,0 @@
[submodule "tools/external/googletest"]
path = tools/external/googletest
url = https://github.com/google/googletest.git

127
CHANGELOG.md
Normal file
@@ -0,0 +1,127 @@
# NVIDIA CUTLASS Changelog

# CUTLASS 2.x

## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08)
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
  * Fast Tensor Core operations:
    * Maximum performance via [`mma.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
    * Tensor Float 32, BFloat16, and double-precision data types
    * Mixed integer data types (int8, int4, bin1)
  * Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution)
  * Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745) (free registration required)
* Features:
  * SDK examples showing GEMM fused with bias+relu and fused GEMM+GEMM
  * Complex-valued GEMMs targeting NVIDIA Ampere Tensor Cores in double-precision and Tensor Float 32
  * Gaussian complex GEMMs using 3m complex multiply algorithm
  * Universal GEMM kernel supporting two batch modes and two algorithms for parallel reductions
* Policy updates:
  * [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit) needed to enable NVIDIA Ampere Architecture features
  * Disabled F16C by default for compatibility - enable on cmake command line with `-DCUTLASS_ENABLE_F16C=ON`

## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06)
* BLAS-style host-side API added to [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
  * API to launch compiled kernel instances for GEMM and planar complex GEMM
* Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores
  * Computes complex matrix products on matrices stored as disjoint real and imaginary parts
  * [SDK Examples of Planar Complex GEMMs](/examples/10_planar_complex/planar_complex.cu)
* Minor enhancements and bug fixes

## [2.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.0.0) (2019-11-19)
* Substantially refactored for
  * Better performance, particularly for native Turing Tensor Cores
  * Robust and durable templates spanning the design space
  * Encapsulated functionality embodying modern C++11 programming techniques
  * Optimized containers and data types for efficient, generic, portable device code
* Updates to:
  * [Quick start guide](/media/docs/quickstart.md)
  * [Documentation](/README.md#documentation)
  * [Utilities](/media/docs/utilities.md)
  * [CUTLASS Profiler](/media/docs/profiler.md)
* Native Turing Tensor Cores
  * Efficient GEMM kernels targeting Turing Tensor Cores
  * Mixed-precision floating point, 8-bit integer, 4-bit integer, and binarized operands
* Coverage of existing CUTLASS functionality
  * GEMM kernels targeting CUDA and Tensor Cores in NVIDIA GPUs
  * Volta Tensor Cores through native mma.sync and through WMMA API
  * Optimizations such as parallel reductions, threadblock rasterization, and intra-threadblock reductions
  * Batched GEMM operations
  * Complex-valued GEMMs
* **Note: a host compiler supporting C++11 or greater is required.**

# CUTLASS 1.x

## [1.3.2](https://github.com/NVIDIA/cutlass/releases/tag/v1.3.2) (2019-07-09)
* Performance improvement for Volta Tensor Cores TN and TT layouts.

## [1.3.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.3.1) (2019-04-09)
* Corrected NVRTC unit tests.

## [1.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.3.0) (2019-03-20)
* Efficient GEMM kernel targeting Volta Tensor Cores via `mma.sync` instruction added in CUDA 10.1.

## [1.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.2.0) (2018-10-26)
* Parallelized reductions across threadblocks ("Split-K")
* Improved IGEMM performance
* Batched strided WMMA GEMMs

## [1.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.1.0) (2018-09-19)
* Turing Features
  * WMMA GEMM targeting TensorCores - INT8, INT4, 1-bit
* Batched Strided GEMM
* Threadblock rasterization strategies
  * Improved performance for adverse problem sizes and data layouts
* Extended CUTLASS Core components
  * Tensor views support arbitrary matrix and tensor layouts
  * Zip iterators for structuring multiple data streams
* Enhanced CUTLASS utilities
  * Reference code for tensor operations in host and device code
  * Added HostMatrix<> for simplified matrix creation
* Examples
  * Basic GEMM, tensor views, CUTLASS utilities, batched GEMM, WMMA GEMM

## [1.0.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.0.1) (2018-06-11)

* Intra-threadblock reduction added for small threadblock tile sizes
  * sgemm_64x128x16, sgemm_128x128x16, sgemm_128x64x16, sgemm_128x32x16, sgemm_64x64x16, sgemm_64x32x16
  * igemm_32x32x128
* GEMM _K_ residue handled during prologue prior to mainloop
* Replaced Google Test copy with submodule. Use `git submodule update --init --recursive`

## [1.0.0](https://github.com/NVIDIA/cutlass/commit/2028ebe120aab22bfd0b2baf8902d4c9627eb33f) (2018-05-16)

* Substantial rewrite to accommodate new architecture
* Kernels: SGEMM, DGEMM, IGEMM, HGEMM, WMMA GEMM
* Unit and performance tests

## [0.0.1](https://github.com/NVIDIA/cutlass/commit/d08ba8ac46e2fa3f745e070c390182edb56b2e91) (2017-12-04)

* Initial release


## Copyright

Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.

```
Redistribution and use in source and binary forms, with or without modification, are permitted
provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this list of
conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of
conditions and the following disclaimer in the documentation and/or other materials
provided with the distribution.
* Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
to endorse or promote products derived from this software without specific prior written
permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

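The 2.2.0 policy note in the changelog says F16C is disabled by default and can be re-enabled on the cmake command line. As a rough sketch, the same setting could instead be preloaded from an initial-cache script passed with `cmake -C`. The file name `cutlass-cache.cmake` and the choice of architecture `80` are assumptions for illustration; `CUTLASS_ENABLE_F16C` and `CUTLASS_NVCC_ARCHS` are the cache variables declared in this changeset's CMakeLists.txt.

```cmake
# cutlass-cache.cmake (hypothetical file name)
# Preloads cache entries before CUTLASS's CMakeLists.txt runs:
#   cmake -C cutlass-cache.cmake <path-to-cutlass-source>

# Re-enable F16C host conversions, which 2.2.0 turns off by default.
set(CUTLASS_ENABLE_F16C ON CACHE BOOL "Enable F16C x86 extensions in host code.")

# Assumption: build only for SM 8.0 (NVIDIA Ampere), which needs the CUDA 11 Toolkit.
set(CUTLASS_NVCC_ARCHS 80 CACHE STRING "The SM architectures requested.")
```

Because the project's own `set(... CACHE ...)` calls do not use `FORCE` for these two variables, values that are already in the cache should take precedence over the defaults.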
490
CMakeLists.txt
Normal file → Executable file
@@ -1,4 +1,4 @@
|
||||
# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
@@ -20,48 +20,120 @@
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cmake_minimum_required(VERSION 3.3.0)
|
||||
cmake_minimum_required(VERSION 3.12.4 FATAL_ERROR)
|
||||
|
||||
set(CUTLASS_LANGUAGES CXX)
|
||||
|
||||
# CMake 3.9.0 has native support for CUDA without the need of the CUDA package. Use it!
|
||||
if(WIN32 AND NOT ${CMAKE_VERSION} VERSION_LESS "3.9.0")
|
||||
list(APPEND CUTLASS_LANGUAGES CUDA)
|
||||
set(CUTLASS_NATIVE_CUDA TRUE)
|
||||
|
||||
macro(cutlass_add_executable)
|
||||
add_executable(${ARGN})
|
||||
endmacro()
|
||||
if(cutlass_LOADED)
|
||||
# If CUTLASS has been previously fetched and loaded, don't do it again.
|
||||
return()
|
||||
else()
|
||||
# FindCUDA fails to detect VS 2017 due to a changed directory format of the toolkits.
|
||||
# For this configuration we need CMake >= 3.9.0 to use the native CUDA support.
|
||||
if (WIN32 AND MSVC_VERSION GREATER 1800)
|
||||
message(FATAL_ERROR "Please upgrade CMake to version >= 3.9.0 to support Visual Studio 2017 or higher")
|
||||
endif()
|
||||
|
||||
# Fall back to the FindCUDA version to create an executable with CUDA files
|
||||
macro(cutlass_add_executable)
|
||||
cuda_add_executable(${ARGN})
|
||||
endmacro()
|
||||
set(cutlass_LOADED ON)
|
||||
set(CUTLASS_DIR ${CMAKE_CURRENT_SOURCE_DIR} CACHE PATH "CUTLASS Repository Directory")
|
||||
endif()
|
||||
|
||||
project(CUTLASS ${CUTLASS_LANGUAGES})
|
||||
message(STATUS "CMake Version: ${CMAKE_VERSION}")
|
||||
|
||||
project(CUTLASS VERSION 2.2.0 LANGUAGES CXX)
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)
|
||||
|
||||
find_package(Doxygen QUIET)
|
||||
|
||||
#
|
||||
# CUTLASS 2.x requires C++11
|
||||
#
|
||||
set(CMAKE_CXX_STANDARD 11)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
|
||||
if(CUTLASS_NATIVE_CUDA)
|
||||
set(CMAKE_CUDA_STANDARD 11)
|
||||
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
|
||||
else()
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --std=c++11)
|
||||
endif()
|
||||
|
||||
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
|
||||
set(CMAKE_INSTALL_PREFIX install CACHE PATH "Default installation location." FORCE)
|
||||
endif()
|
||||
|
||||
message(STATUS "Default Install Location: ${CMAKE_INSTALL_PREFIX}")
|
||||
|
||||
set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library")
|
||||
|
||||
if(CUTLASS_ENABLE_HEADERS_ONLY)
|
||||
set(CUTLASS_ENABLE_EXAMPLES_INIT OFF)
|
||||
set(CUTLASS_ENABLE_TOOLS_INIT OFF)
|
||||
else()
|
||||
set(CUTLASS_ENABLE_EXAMPLES_INIT ON)
|
||||
set(CUTLASS_ENABLE_TOOLS_INIT ON)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable CUTLASS Examples")
|
||||
set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools")
|
||||
|
||||
if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME})
|
||||
set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_TOOLS_INIT})
|
||||
else()
|
||||
set(CUTLASS_ENABLE_TESTS_INIT OFF)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests")
|
||||
|
||||
if (CUTLASS_ENABLE_TESTS)
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_NVCC_ARCHS_SUPPORTED "")
|
||||
if (NOT CUDA_VERSION VERSION_LESS 7.5)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 53)
|
||||
endif()
|
||||
if (NOT CUDA_VERSION VERSION_LESS 8.0)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 60 61)
|
||||
endif()
|
||||
if (NOT CUDA_VERSION VERSION_LESS 9.0)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 70)
|
||||
endif()
|
||||
if (NOT CUDA_VERSION VERSION_LESS 9.2)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 72)
|
||||
endif()
|
||||
if (NOT CUDA_VERSION VERSION_LESS 10.0)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 75)
|
||||
endif()
|
||||
if (NOT CUDA_VERSION VERSION_LESS 11.0)
|
||||
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 80)
|
||||
endif()
|
||||
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
|
||||
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")
|
||||
|
||||
# Special policy introduced in CMake 3.13
|
||||
if (POLICY CMP0076)
|
||||
cmake_policy(SET CMP0076 NEW)
|
||||
endif()
|
||||
|
||||
# check if the configuration is supported
|
||||
if( NOT CMAKE_SIZEOF_VOID_P EQUAL 8 )
|
||||
message(FATAL_ERROR "CUTLASS requires a 64-bit compiler!")
|
||||
endif()
|
||||
|
||||
find_package(CUDA)
|
||||
find_package(Doxygen QUIET)
|
||||
include(GNUInstallDirs)
|
||||
|
||||
link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Configure CMake variables
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
message(STATUS "CUDA Compilation Architectures: ${CUTLASS_NVCC_ARCHS_ENABLED}")
|
||||
|
||||
# By default we want to build in Release mode to ensure that we're getting best performance
|
||||
if (NOT (CMAKE_BUILD_TYPE OR CONFIGURATION_TYPES))
|
||||
# By default we want to build in Release mode to ensure that we're getting best performance.
|
||||
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose build level" FORCE)
|
||||
# We do support Debug or Release builds
|
||||
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release")
|
||||
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "RelWithDebInfo" "Release")
|
||||
endif()
|
||||
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CUTLASS_LIBRARY_DEBUG_POSTFIX ".debug" CACHE STRING "Default postfix value for debug libraries")
|
||||
|
||||
if(WIN32)
|
||||
# On Windows we link against the shared (DLL) runtime. Change gtest settings to match this.
|
||||
set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib even when Google Test is built as static lib" FORCE)
|
||||
@@ -69,96 +141,281 @@ endif()
|
||||
|
||||
if (WIN32)
|
||||
# Enable more warnings and treat as errors
|
||||
string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX")
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3 -Xcompiler=/WX)
|
||||
|
||||
# Disable warning on Unicode characters
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/wd4819)
|
||||
|
||||
# Disable excess x86 floating point precision that can lead to results being labeled incorrectly
|
||||
string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict")
|
||||
|
||||
# Verbose option
|
||||
if (${CUTLASS_NVCC_VERBOSE})
|
||||
string(APPEND NVCC_FLAGS " -v")
|
||||
endif()
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/fp:strict)
|
||||
endif(WIN32)
|
||||
|
||||
# Configure CUDA options
|
||||
set(CUTLASS_NVCC_ARCHS "50;60;61;70" CACHE STRING "The SM architectures to build code for.")
|
||||
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
|
||||
|
||||
foreach(ARCH ${CUTLASS_NVCC_ARCHS})
|
||||
string(APPEND NVCC_FLAGS " -gencode arch=compute_${ARCH},code=sm_${ARCH}")
|
||||
endforeach()
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_KEEP)
|
||||
string(APPEND NVCC_FLAGS " -keep")
|
||||
if (${CUTLASS_NVCC_VERBOSE})
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -v)
|
||||
endif()
|
||||
|
||||
if (WIN32 AND CUTLASS_NATIVE_CUDA)
|
||||
string(APPEND NVCC_FLAGS_RELEASE " -lineinfo")
|
||||
set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.")
|
||||
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
|
||||
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
|
||||
set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.")
|
||||
|
||||
#
|
||||
# CUTLASS generator cmake configuration
|
||||
#
|
||||
set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of operation name filters. Default '' means all operations are enabled.")
|
||||
set(CUTLASS_LIBRARY_KERNELS "" CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.")
|
||||
|
||||
|
||||
# Test Levels L0, L1, L2
|
||||
set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")
|
||||
set_property(CACHE CUTLASS_TEST_LEVEL PROPERTY STRINGS 0 1 2)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
|
||||
|
||||
#
|
||||
# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
|
||||
#
|
||||
|
||||
if (CUDA_VERSION VERSION_LESS 10.1)
|
||||
set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT OFF)
|
||||
else()
|
||||
string(APPEND NVCC_FLAGS " -lineinfo")
|
||||
set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT ON)
|
||||
endif()
|
||||
|
||||
if (UNIX)
|
||||
string(APPEND NVCC_FLAGS " -Xcompiler -Wconversion")
|
||||
set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL
|
||||
"Enable PTX mma instruction for collective matrix multiply operations.")
|
||||
|
||||
#
|
||||
# NOTE: running with asan and CUDA requires the following environment variable:
|
||||
#
|
||||
# ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0
|
||||
#
|
||||
# without the above environment setting, an error like the following may be generated:
|
||||
#
|
||||
# *** Error: Could not detect active GPU device ID [out of memory]
|
||||
# ...
|
||||
# ==9149==ERROR: LeakSanitizer: detected memory leaks
|
||||
# ...
|
||||
#
|
||||
if(ENABLE_ASAN) # https://github.com/google/sanitizers/wiki/AddressSanitizer
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --compiler-options=-fsanitize=address --compiler-options=-fno-omit-frame-pointer)
|
||||
string(APPEND CMAKE_EXE_LINKER_FLAGS " -fsanitize=address")
|
||||
endif()
|
||||
|
||||
string(APPEND NVCC_FLAGS_DEBUG " -g")
|
||||
string(APPEND NVCC_FLAGS_RELEASE " -O3")
|
||||
###################################################################################################
|
||||
#
|
||||
# Configure CUDA build options
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
# define NDEBUG for release mode to disable assertions
|
||||
string(APPEND NVCC_FLAGS_RELEASE " -DNDEBUG")
|
||||
|
||||
if (CUTLASS_NATIVE_CUDA)
|
||||
set(CMAKE_CUDA_FLAGS "${NVCC_FLAGS}")
|
||||
set(CMAKE_CUDA_FLAGS_DEBUG "${NVCC_FLAGS_DEBUG}")
|
||||
set(CMAKE_CUDA_FLAGS_RELEASE "${NVCC_FLAGS_RELEASE}")
|
||||
else()
|
||||
set(CUDA_NVCC_FLAGS ${NVCC_FLAGS})
|
||||
set(CUDA_NVCC_FLAGS_DEBUG ${NVCC_FLAGS_DEBUG})
|
||||
set(CUDA_NVCC_FLAGS_RELEASE ${NVCC_FLAGS_RELEASE})
|
||||
if(CUTLASS_NVCC_EMBED_PTX)
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-include-ptx=all)
|
||||
endif()
|
||||
|
||||
if (CUTLASS_ENABLE_TENSOR_CORE_MMA)
|
||||
list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1)
|
||||
endif()
|
||||
|
||||
if (NOT MSVC AND CUTLASS_NVCC_KEEP)
|
||||
# MSVC flow handles caching already, but for other generators we handle it here.
|
||||
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files")
|
||||
file(MAKE_DIRECTORY ${CUTLASS_NVCC_KEEP_DIR})
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep) # --keep-dir may not work with nvcc for some directories.
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -save-temps=${CUTLASS_NVCC_KEEP_DIR})
|
||||
endif()
|
||||
|
||||
if (CUTLASS_ENABLE_F16C AND NOT CMAKE_CROSSCOMPILING)
|
||||
list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_F16C=1)
|
||||
if ((CMAKE_CXX_COMPILER_ID MATCHES "GNU") OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang"))
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-mf16c)
|
||||
elseif((CMAKE_CXX_COMPILER_ID MATCHES "MSVC"))
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/arch:AVX2)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-Wconversion>)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$<BOOL:${UNIX}>:-Xcompiler=-fno-strict-aliasing>)
|
||||
|
||||
# Don't leak lineinfo in release builds
|
||||
if (NOT CMAKE_BUILD_TYPE MATCHES "Release")
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -gmlt)
|
||||
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -lineinfo)
|
||||
endif()
|
||||
|
||||
if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
if( NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang" )
|
||||
message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" )
|
||||
endif()
|
||||
|
||||
if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0)
|
||||
message(FATAL_ERROR "Clang 7.0+ required for GPU compilation")
|
||||
endif()
|
||||
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-path=${CUDA_TOOLKIT_ROOT_DIR})
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -pragma-unroll-threshold=100000)
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -unroll-threshold=5000)
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unused-command-line-argument)
|
||||
|
||||
string(REPLACE "." ";" CUDA_VERSION_PARTS ${CMAKE_CUDA_COMPILER_VERSION})
|
||||
list(GET CUDA_VERSION_PARTS 0 CUDA_VERSION_MAJOR)
|
||||
list(GET CUDA_VERSION_PARTS 1 CUDA_VERSION_MINOR)
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -D__CUDACC_VER_MAJOR__=${CUDA_VERSION_MAJOR} -D__CUDACC_VER_MINOR__=${CUDA_VERSION_MINOR})
|
||||
|
||||
|
||||
# needed for libcublasLt.so in case it's installed in the same location as libcudart.so
|
||||
# dynamic linker can find it if linker sets RPATH (forced by --disable-new-dtags)
|
||||
# Otherwise linker uses RUNPATH and that does not propagate to loaded libs.
|
||||
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wl,--disable-new-dtags)
|
||||
|
||||
link_libraries(nvidia::cudart)
|
||||
endif()
|
||||
|
||||
function(cutlass_apply_cuda_gencode_flags TARGET)
|
||||
|
||||
set(NVCC_FLAGS)
|
||||
set(CLANG_FLAGS)
|
||||
foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED})
|
||||
list(APPEND CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH})
|
||||
set(CODES)
|
||||
if(CUTLASS_NVCC_EMBED_CUBIN)
|
||||
list(APPEND CODES sm_${ARCH})
|
||||
endif()
|
||||
if(CUTLASS_NVCC_EMBED_PTX)
|
||||
list(APPEND CODES compute_${ARCH})
|
||||
endif()
|
||||
list(JOIN CODES "," CODES_STR)
|
||||
list(APPEND NVCC_FLAGS -gencode=arch=compute_${ARCH},code=[${CODES_STR}])
|
||||
endforeach()
|
||||
|
||||
if (CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
target_compile_options(
|
||||
${TARGET}
|
||||
PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CXX>:${CLANG_FLAGS}>
|
||||
)
|
||||
else()
|
||||
target_compile_options(
|
||||
${TARGET}
|
||||
PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:${NVCC_FLAGS}>
|
||||
)
|
||||
endif()
|
||||
|
||||
endfunction()
|
||||
|
||||
function(cutlass_apply_standard_compile_options TARGET)
|
||||
|
||||
if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
set(CUDA_COMPILE_LANGUAGE CXX)
|
||||
set(_FLAGS ${CUTLASS_CUDA_FLAGS} ${CUTLASS_CUDA_CLANG_FLAGS})
|
||||
set(_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} ${CUTLASS_CUDA_CLANG_FLAGS_RELEASE})
|
||||
set(_FLAGS_RELWITHDEBINFO ${CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${CUTLASS_CUDA_CLANG_FLAGS_RELWITHDEBINFO})
|
||||
set(_FLAGS_DEBUG ${CUTLASS_CUDA_FLAGS_DEBUG} ${CUTLASS_CUDA_CLANG_FLAGS_DEBUG})
|
||||
else()
|
||||
set(CUDA_COMPILE_LANGUAGE CUDA)
|
||||
set(_FLAGS ${CUTLASS_CUDA_FLAGS} ${CUTLASS_CUDA_NVCC_FLAGS})
|
||||
set(_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} ${CUTLASS_CUDA_NVCC_FLAGS_RELEASE})
|
||||
set(_FLAGS_RELWITHDEBINFO ${CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${CUTLASS_CUDA_NVCC_FLAGS_RELWITHDEBINFO})
|
||||
set(_FLAGS_DEBUG ${CUTLASS_CUDA_FLAGS_DEBUG} ${CUTLASS_CUDA_NVCC_FLAGS_DEBUG})
|
||||
endif()
|
||||
|
||||
target_compile_options(
|
||||
${TARGET}
|
||||
PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:${CUDA_COMPILE_LANGUAGE}>:${_FLAGS}>
|
||||
$<$<COMPILE_LANGUAGE:${CUDA_COMPILE_LANGUAGE}>:$<$<CONFIG:RELEASE>:${_FLAGS_RELEASE}>>
|
||||
$<$<COMPILE_LANGUAGE:${CUDA_COMPILE_LANGUAGE}>:$<$<CONFIG:RELWITHDEBINFO>:${_FLAGS_RELWITHDEBINFO}>>
|
||||
$<$<COMPILE_LANGUAGE:${CUDA_COMPILE_LANGUAGE}>:$<$<CONFIG:DEBUG>:${_FLAGS_DEBUG}>>
|
||||
)
|
||||
|
||||
endfunction()
|
||||
|
||||
#
|
||||
# The following items should eventually be pushed into cutlass/CMakeLists.txt
|
||||
#
|
||||
|
||||
# GLOB for CUTLASS header files. Should we use a static list instead?
|
||||
file(GLOB CUTLASS_GEMM RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/gemm/*.h)
|
||||
file(GLOB CUTLASS_UTIL RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/util/*.h)
|
||||
file(GLOB CUTLASS_DEVICE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/device/*.h)
|
||||
file(GLOB CUTLASS_CORE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/*.h)
|
||||
file(GLOB_RECURSE CUTLASS_INCLUDE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} include/cutlass/*.h)
|
||||
file(GLOB_RECURSE CUTLASS_CUTLASS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include include/cutlass/*.h)
|
||||
file(GLOB_RECURSE CUTLASS_NVRTC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/test test/unit/nvrtc/kernel/*.h)
|
||||
|
||||
source_group("cutlass\\gemm" FILES ${CUTLASS_GEMM})
|
||||
source_group("cutlass\\util" FILES ${CUTLASS_UTIL})
|
||||
source_group("cutlass\\device" FILES ${CUTLASS_DEVICE})
|
||||
source_group("cutlass" FILES ${CUTLASS_CORE})
|
||||
###################################################################################################
|
||||
#
|
||||
# Define build targets
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
source_group(TREE ${CMAKE_CURRENT_SOURCE_DIR}/include REGULAR_EXPRESSION ".*\.h")
|
||||
|
||||
add_library(CUTLASS INTERFACE)
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
|
||||
target_sources(CUTLASS INTERFACE
|
||||
${CUTLASS_GEMM}
|
||||
${CUTLASS_UTIL}
|
||||
${CUTLASS_DEVICE}
|
||||
${CUTLASS_CORE}
|
||||
)
|
||||
add_library(nvidia::cutlass::cutlass ALIAS CUTLASS)
|
||||
set_target_properties(CUTLASS PROPERTIES EXPORT_NAME cutlass)
|
||||
|
||||
target_include_directories(CUTLASS INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
set(CUTLASS_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library")
|
||||
|
||||
set(CUTLASS_GENERATOR_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/library/)
|
||||
|
||||
# The following utility directory is needed even if the tools build is disabled, so it exists here.
|
||||
set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/util/include CACHE INTERNAL "")
|
||||
|
||||
include_directories(${CUTLASS_INCLUDE_DIR})
|
||||
|
||||
target_compile_features(CUTLASS INTERFACE cxx_std_11)
|
||||
|
||||
if (NOT DEFINED CUTLASS_REVISION)
|
||||
|
||||
find_package(Git QUIET)
|
||||
|
||||
execute_process(
|
||||
COMMAND ${GIT_EXECUTABLE} rev-parse --short HEAD
|
||||
RESULT_VARIABLE CUTLASS_REVISION_RESULT
|
||||
OUTPUT_VARIABLE CUTLASS_REVISION
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
|
||||
if (CUTLASS_REVISION_RESULT)
|
||||
message(STATUS "CUTLASS Revision: Unable to detect, Git returned code ${CUTLASS_REVISION_RESULT}.")
|
||||
else()
|
||||
message(STATUS "CUTLASS Revision: ${CUTLASS_REVISION}")
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
configure_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cmake/version.h.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version.h
|
||||
@ONLY)
|
||||
|
||||
target_include_directories(
|
||||
CUTLASS
|
||||
INTERFACE
|
||||
$<INSTALL_INTERFACE:include>
|
||||
$<BUILD_INTERFACE:${CUTLASS_INCLUDE_DIR}>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/include>
|
||||
$<BUILD_INTERFACE:${CUDA_TOOLKIT_ROOT_DIR}/include>
|
||||
)
|
||||
|
||||
install(
|
||||
DIRECTORY
|
||||
${CUTLASS_INCLUDE_DIR}/
|
||||
${CMAKE_CURRENT_BINARY_DIR}/include/
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
||||
)
|
||||
|
||||
install(
|
||||
TARGETS CUTLASS
|
||||
EXPORT NvidiaCutlass
|
||||
PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
|
||||
)
|
||||
|
||||
################################################################################
|
||||
|
||||
# Create a custom target to ensure that the CUTLASS sources are visible in an IDE
|
||||
add_custom_target(cutlass_ide SOURCES
|
||||
${CUTLASS_GEMM}
|
||||
${CUTLASS_UTIL}
|
||||
${CUTLASS_DEVICE}
|
||||
${CUTLASS_CORE}
|
||||
)
|
||||
# Doxygen is available. Generate documentation
|
||||
if (DOXYGEN_FOUND)
|
||||
# DOT is available. Enable graph generation in the documentation
|
||||
if (DOXYGEN_DOT_EXECUTABLE)
|
||||
set(CUTLASS_ENABLE_DOXYGEN_DOT ON CACHE BOOL "Use dot to generate graphs in the doxygen documentation.")
|
||||
set(CUTLASS_ENABLE_DOXYGEN_DOT ON CACHE BOOL "Use dot to generate graphs in the doxygen documentation.")
|
||||
else()
|
||||
set(CUTLASS_ENABLE_DOXYGEN_DOT OFF CACHE BOOL "Use dot to generate graphs in the doxygen documentation." FORCE)
|
||||
set(CUTLASS_ENABLE_DOXYGEN_DOT OFF CACHE BOOL "Use dot to generate graphs in the doxygen documentation." FORCE)
|
||||
endif()
|
||||
|
||||
if (CUTLASS_ENABLE_DOXYGEN_DOT)
|
||||
@@ -177,6 +434,55 @@ if (DOXYGEN_FOUND)
|
||||
)
|
||||
endif()
|
||||
|
||||
if(NOT WIN32)
|
||||
# Add common library search paths so executables and libraries can load and run
|
||||
# without LD_LIBRARY_PATH being set.
|
||||
link_libraries(
|
||||
"-Wl,-rpath,'$ORIGIN'"
|
||||
"-Wl,-rpath,'$ORIGIN/../lib64'"
|
||||
"-Wl,-rpath,'$ORIGIN/../lib'"
|
||||
"-Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/lib64'"
|
||||
"-Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/lib'"
|
||||
)
|
||||
endif()
|
||||
|
||||
#add_subdirectory(examples/gemm)
|
||||
add_subdirectory(tools)
|
||||
################################################################################
|
||||
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cuBLAS.cmake)
|
||||
|
||||
if (CUTLASS_ENABLE_CUBLAS)
|
||||
target_compile_definitions(CUTLASS INTERFACE CUTLASS_ENABLE_CUBLAS=1)
|
||||
endif()
|
||||
|
||||
################################################################################
|
||||
|
||||
if(CUTLASS_ENABLE_TOOLS)
|
||||
add_subdirectory(tools)
|
||||
endif()
|
||||
if(CUTLASS_ENABLE_EXAMPLES)
|
||||
add_subdirectory(examples)
|
||||
endif()
|
||||
|
||||
if(CUTLASS_ENABLE_TESTS)
|
||||
include(CTest)
|
||||
enable_testing()
|
||||
add_subdirectory(test)
|
||||
endif()
|
||||
|
||||
################################################################################
|
||||
|
||||
install(
|
||||
FILES ${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassConfig.cmake
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/
|
||||
)
|
||||
|
||||
install(
|
||||
EXPORT NvidiaCutlass
|
||||
NAMESPACE nvidia::cutlass::
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/
|
||||
FILE NvidiaCutlassTargets.cmake
|
||||
)
|
||||
|
||||
################################################################################
|
||||
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassPackageConfig.cmake)
|
||||
|
||||
57
CONTRIBUTORS.md
Normal file
@ -0,0 +1,57 @@
|
||||

|
||||
|
||||
[README](/README.md#documentation) > **Contributors**
|
||||
|
||||
# CUTLASS Developers and Contributors
|
||||
|
||||
This is the official list of CUTLASS developers and contributors.
|
||||
|
||||
## DEVELOPERS
|
||||
Andrew Kerr
|
||||
Haicheng Wu
|
||||
Manish Gupta
|
||||
Dustyn Blasig
|
||||
Pradeep Ramani
|
||||
Naila Farooqui
|
||||
Piotr Majcher
|
||||
Paul Springer
|
||||
Jin Wang
|
||||
Scott Yokim
|
||||
Markus Hohnerbach
|
||||
Aditya Atluri
|
||||
David Tanner
|
||||
|
||||
## CONTRIBUTORS
|
||||
Timothy Costa
|
||||
Julien Demouth
|
||||
Brian Fahs
|
||||
Michael Goldfarb
|
||||
Mostafa Hagog
|
||||
Fei Hu
|
||||
Alan Kaatz
|
||||
Tina Li
|
||||
Timmy Liu
|
||||
Duane Merrill
|
||||
Kevin Siu
|
||||
Markus Tavenrath
|
||||
John Tran
|
||||
Vicki Wang
|
||||
Junkai Wu
|
||||
Fung Xie
|
||||
Albert Xu
|
||||
Jack Yang
|
||||
Xiuxia Zhang
|
||||
Nick Zhao
|
||||
|
||||
## ACKNOWLEDGEMENTS
|
||||
|
||||
Girish Bharambe
|
||||
Cris Cecka
|
||||
Luke Durant
|
||||
Olivier Giroux
|
||||
Stephen Jones
|
||||
Rishkul Kulkarni
|
||||
Bryce Lelbach
|
||||
Joel McCormack
|
||||
Kyrylo Perelygin
|
||||
|
||||
349
CUDA.cmake
Normal file
@ -0,0 +1,349 @@
|
||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
if(CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
set(CUTLASS_NATIVE_CUDA_INIT ON)
|
||||
elseif(CMAKE_VERSION VERSION_LESS 3.12.4)
|
||||
set(CUTLASS_NATIVE_CUDA_INIT OFF)
|
||||
else()
|
||||
set(CUTLASS_NATIVE_CUDA_INIT ON)
|
||||
endif()
|
||||
|
||||
set(CUTLASS_NATIVE_CUDA ${CUTLASS_NATIVE_CUDA_INIT} CACHE BOOL "Utilize the CMake native CUDA flow")
|
||||
|
||||
if(NOT DEFINED ENV{CUDACXX} AND NOT DEFINED ENV{CUDA_BIN_PATH} AND DEFINED ENV{CUDA_PATH})
|
||||
# For backward compatibility, allow use of CUDA_PATH.
|
||||
set(ENV{CUDACXX} $ENV{CUDA_PATH}/bin/nvcc)
|
||||
endif()
|
||||
|
||||
if(CUTLASS_NATIVE_CUDA)
|
||||
|
||||
enable_language(CUDA)
|
||||
|
||||
if(NOT CUDA_VERSION)
|
||||
set(CUDA_VERSION ${CMAKE_CUDA_COMPILER_VERSION})
|
||||
endif()
|
||||
if(NOT CUDA_TOOLKIT_ROOT_DIR)
|
||||
get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CMAKE_CUDA_COMPILER}/../.." ABSOLUTE)
|
||||
endif()
|
||||
|
||||
else()
|
||||
|
||||
find_package(CUDA REQUIRED)
|
||||
# We work around missing variables in the native flow by also finding the CUDA toolkit the old way.
|
||||
|
||||
if(NOT CMAKE_CUDA_COMPILER_VERSION)
|
||||
set(CMAKE_CUDA_COMPILER_VERSION ${CUDA_VERSION})
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
if (CUDA_VERSION VERSION_LESS 9.2)
|
||||
message(FATAL_ERROR "CUDA 9.2+ Required, Found ${CUDA_VERSION}.")
|
||||
endif()
|
||||
if(NOT CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "[Cc]lang")
|
||||
set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc)
|
||||
message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}")
|
||||
endif()
|
||||
|
||||
find_library(
|
||||
CUDART_LIBRARY cudart
|
||||
PATHS
|
||||
${CUDA_TOOLKIT_ROOT_DIR}
|
||||
PATH_SUFFIXES
|
||||
lib/x64
|
||||
lib64
|
||||
lib
|
||||
NO_DEFAULT_PATH
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# in the CUDA toolkit we're building against.
|
||||
)
|
||||
|
||||
if(NOT TARGET cudart AND CUDART_LIBRARY)
|
||||
|
||||
message(STATUS "CUDART: ${CUDART_LIBRARY}")
|
||||
|
||||
if(WIN32)
|
||||
add_library(cudart STATIC IMPORTED GLOBAL)
|
||||
# Even though we're linking against a .dll, in Windows you statically link against
|
||||
# the .lib file found under lib/x64. The .dll will be loaded at runtime automatically
|
||||
# from the PATH search.
|
||||
else()
|
||||
add_library(cudart SHARED IMPORTED GLOBAL)
|
||||
endif()
|
||||
|
||||
add_library(nvidia::cudart ALIAS cudart)
|
||||
|
||||
set_property(
|
||||
TARGET cudart
|
||||
PROPERTY IMPORTED_LOCATION
|
||||
${CUDART_LIBRARY}
|
||||
)
|
||||
|
||||
elseif(TARGET cudart)
|
||||
|
||||
message(STATUS "CUDART: Already Found")
|
||||
|
||||
else()
|
||||
|
||||
message(STATUS "CUDART: Not Found")
|
||||
|
||||
endif()
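# Usage sketch (not part of the original file; `my_app` is a hypothetical target):
# link against the imported CUDA runtime only if the block above created it.
#
#   if(TARGET nvidia::cudart)
#     target_link_libraries(my_app PRIVATE nvidia::cudart)
#   endif()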
|
||||
|
||||
find_library(
|
||||
CUDA_DRIVER_LIBRARY cuda
|
||||
PATHS
|
||||
${CUDA_TOOLKIT_ROOT_DIR}
|
||||
PATH_SUFFIXES
|
||||
lib/x64
|
||||
lib64
|
||||
lib
|
||||
lib64/stubs
|
||||
lib/stubs
|
||||
NO_DEFAULT_PATH
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# in the CUDA toolkit we're building against.
|
||||
)
|
||||
|
||||
if(NOT TARGET cuda_driver AND CUDA_DRIVER_LIBRARY)
|
||||
|
||||
message(STATUS "CUDA Driver: ${CUDA_DRIVER_LIBRARY}")
|
||||
|
||||
if(WIN32)
|
||||
add_library(cuda_driver STATIC IMPORTED GLOBAL)
|
||||
# Even though we're linking against a .dll, in Windows you statically link against
|
||||
# the .lib file found under lib/x64. The .dll will be loaded at runtime automatically
|
||||
# from the PATH search.
|
||||
else()
|
||||
add_library(cuda_driver SHARED IMPORTED GLOBAL)
|
||||
endif()
|
||||
|
||||
add_library(nvidia::cuda_driver ALIAS cuda_driver)
|
||||
|
||||
set_property(
|
||||
TARGET cuda_driver
|
||||
PROPERTY IMPORTED_LOCATION
|
||||
${CUDA_DRIVER_LIBRARY}
|
||||
)
|
||||
|
||||
elseif(TARGET cuda_driver)
|
||||
|
||||
message(STATUS "CUDA Driver: Already Found")
|
||||
|
||||
else()
|
||||
|
||||
message(STATUS "CUDA Driver: Not Found")
|
||||
|
||||
endif()
|
||||
|
||||
find_library(
|
||||
NVRTC_LIBRARY nvrtc
|
||||
PATHS
|
||||
${CUDA_TOOLKIT_ROOT_DIR}
|
||||
PATH_SUFFIXES
|
||||
lib/x64
|
||||
lib64
|
||||
lib
|
||||
NO_DEFAULT_PATH
|
||||
# We aren't going to search any system paths. We want to find the runtime
|
||||
# in the CUDA toolkit we're building against.
|
||||
)
|
||||
|
||||
if(NOT TARGET nvrtc AND NVRTC_LIBRARY)
|
||||
|
||||
message(STATUS "NVRTC: ${NVRTC_LIBRARY}")
|
||||
|
||||
if(WIN32)
|
||||
add_library(nvrtc STATIC IMPORTED GLOBAL)
|
||||
# Even though we're linking against a .dll, in Windows you statically link against
|
||||
# the .lib file found under lib/x64. The .dll will be loaded at runtime automatically
|
||||
# from the PATH search.
|
||||
else()
|
||||
add_library(nvrtc SHARED IMPORTED GLOBAL)
|
||||
endif()
|
||||
|
||||
add_library(nvidia::nvrtc ALIAS nvrtc)
|
||||
|
||||
set_property(
|
||||
TARGET nvrtc
|
||||
PROPERTY IMPORTED_LOCATION
|
||||
${NVRTC_LIBRARY}
|
||||
)
|
||||
|
||||
elseif(TARGET nvrtc)
|
||||
|
||||
message(STATUS "NVRTC: Already Found")
|
||||
|
||||
else()
|
||||
|
||||
message(STATUS "NVRTC: Not Found")
|
||||
|
||||
endif()
|
||||
|
||||
include_directories(SYSTEM ${CUDA_INCLUDE_DIRS})
|
||||
# Some platforms (e.g. Visual Studio) don't add the CUDA include directories to the system include
|
||||
# paths by default, so we add it explicitly here.
|
||||
|
||||
function(cutlass_correct_source_file_language_property)
|
||||
if(CUDA_COMPILER MATCHES "clang")
|
||||
foreach(File ${ARGN})
|
||||
if(File MATCHES ".*\.cu$")
|
||||
set_source_files_properties(${File} PROPERTIES LANGUAGE CXX)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
set(CUTLASS_UNITY_BUILD_ENABLED OFF CACHE BOOL "Enable combined source compilation")
|
||||
set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files")
|
||||
|
||||
function(cutlass_unify_source_files TARGET_ARGS_VAR)

  set(options)
  set(oneValueArgs BATCH_SOURCES BATCH_SIZE)
  set(multiValueArgs)
  cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

  if (NOT DEFINED TARGET_ARGS_VAR)
    message(FATAL_ERROR "TARGET_ARGS_VAR parameter is required")
  endif()

  if (__BATCH_SOURCES AND NOT DEFINED __BATCH_SIZE)
    set(__BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE})
  endif()

  if (CUTLASS_UNITY_BUILD_ENABLED AND DEFINED __BATCH_SIZE AND __BATCH_SIZE GREATER 1)

    set(CUDA_FILE_ARGS)
    set(TARGET_SOURCE_ARGS)

    foreach(ARG ${__UNPARSED_ARGUMENTS})
      if(${ARG} MATCHES ".*\.cu$")
        list(APPEND CUDA_FILE_ARGS ${ARG})
      else()
        list(APPEND TARGET_SOURCE_ARGS ${ARG})
      endif()
    endforeach()

    list(LENGTH CUDA_FILE_ARGS NUM_CUDA_FILE_ARGS)
    while(NUM_CUDA_FILE_ARGS GREATER 0)
      list(SUBLIST CUDA_FILE_ARGS 0 ${__BATCH_SIZE} CUDA_FILE_BATCH)
      string(SHA256 CUDA_FILE_BATCH_HASH "${CUDA_FILE_BATCH}")
      string(SUBSTRING ${CUDA_FILE_BATCH_HASH} 0 12 CUDA_FILE_BATCH_HASH)
      set(BATCH_FILE ${CMAKE_CURRENT_BINARY_DIR}/${NAME}.unity.${CUDA_FILE_BATCH_HASH}.cu)
      message(STATUS "Generating ${BATCH_FILE}")
      file(WRITE ${BATCH_FILE} "// Unity File - Auto Generated!\n")
      foreach(CUDA_FILE ${CUDA_FILE_BATCH})
        get_filename_component(CUDA_FILE_ABS_PATH ${CUDA_FILE} ABSOLUTE)
        file(APPEND ${BATCH_FILE} "#include \"${CUDA_FILE_ABS_PATH}\"\n")
      endforeach()
      list(APPEND TARGET_SOURCE_ARGS ${BATCH_FILE})
      if (NUM_CUDA_FILE_ARGS LESS_EQUAL __BATCH_SIZE)
        break()
      endif()
      list(SUBLIST CUDA_FILE_ARGS ${__BATCH_SIZE} -1 CUDA_FILE_ARGS)
      list(LENGTH CUDA_FILE_ARGS NUM_CUDA_FILE_ARGS)
    endwhile()

  else()

    set(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})

  endif()

  set(${TARGET_ARGS_VAR} ${TARGET_SOURCE_ARGS} PARENT_SCOPE)

endfunction()
|
||||
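# Usage sketch (not part of the original file; the target and file names are
# hypothetical). The BATCH_SOURCES and BATCH_SIZE keywords are not parsed by
# cutlass_add_library itself, so they are forwarded to cutlass_unify_source_files.
# With -DCUTLASS_UNITY_BUILD_ENABLED=ON, the five .cu files below would be grouped
# into two generated unity sources (four files, then one), while main.cpp is
# passed through unchanged. With the option OFF, all six files are compiled as-is.
#
#   cutlass_add_library(
#     my_kernels
#     BATCH_SOURCES ON
#     BATCH_SIZE 4
#     a.cu b.cu c.cu d.cu e.cu main.cpp
#     )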
|
||||
function(cutlass_add_library NAME)
|
||||
|
||||
set(options)
|
||||
set(oneValueArgs EXPORT_NAME)
|
||||
set(multiValueArgs)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
|
||||
|
||||
if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang")
|
||||
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
|
||||
add_library(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
else()
|
||||
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
|
||||
cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
endif()
|
||||
|
||||
cutlass_apply_standard_compile_options(${NAME})
|
||||
cutlass_apply_cuda_gencode_flags(${NAME})
|
||||
|
||||
target_compile_features(
|
||||
${NAME}
|
||||
INTERFACE
|
||||
cxx_std_11
|
||||
)
|
||||
|
||||
if(__EXPORT_NAME)
|
||||
add_library(nvidia::cutlass::${__EXPORT_NAME} ALIAS ${NAME})
|
||||
set_target_properties(${NAME} PROPERTIES EXPORT_NAME ${__EXPORT_NAME})
|
||||
endif()
|
||||
|
||||
endfunction()
|
||||
|
||||
function(cutlass_add_executable NAME)
|
||||
|
||||
set(options)
|
||||
set(oneValueArgs)
|
||||
set(multiValueArgs)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
|
||||
|
||||
if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang")
|
||||
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
|
||||
add_executable(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
else()
|
||||
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
|
||||
cuda_add_executable(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
endif()
|
||||
|
||||
cutlass_apply_standard_compile_options(${NAME})
|
||||
cutlass_apply_cuda_gencode_flags(${NAME})
|
||||
|
||||
target_compile_features(
|
||||
${NAME}
|
||||
INTERFACE
|
||||
cxx_std_11
|
||||
)
|
||||
|
||||
endfunction()
|
||||
|
||||
function(cutlass_target_sources NAME)
|
||||
|
||||
set(options)
|
||||
set(oneValueArgs)
|
||||
set(multiValueArgs)
|
||||
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS})
|
||||
cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS})
|
||||
target_sources(${NAME} ${TARGET_SOURCE_ARGS})
|
||||
|
||||
endfunction()
|
||||
24
Doxyfile
@ -32,7 +32,7 @@ DOXYFILE_ENCODING = UTF-8
|
||||
# title of most generated pages and in a few other places.
|
||||
# The default value is: My Project.
|
||||
|
||||
PROJECT_NAME = "Cutlass"
|
||||
PROJECT_NAME = "CUTLASS"
|
||||
|
||||
# The PROJECT_NUMBER tag can be used to enter a project or revision number. This
|
||||
# could be handy for archiving the generated documentation or if some version
|
||||
@ -51,14 +51,14 @@ PROJECT_BRIEF = "CUDA Templates for Linear Algebra Subroutines and Solv
|
||||
# and the maximum width should not exceed 200 pixels. Doxygen will copy the logo
|
||||
# to the output directory.
|
||||
|
||||
PROJECT_LOGO =
|
||||
PROJECT_LOGO = media/images/cutlass-logo-small.png
|
||||
|
||||
# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path
|
||||
# into which the generated documentation will be written. If a relative path is
|
||||
# entered, it will be relative to the location where doxygen was started. If
|
||||
# left blank the current directory will be used.
|
||||
|
||||
OUTPUT_DIRECTORY = docs
|
||||
OUTPUT_DIRECTORY = doxygen
|
||||
|
||||
# If the CREATE_SUBDIRS tag is set to YES, then doxygen will create 4096 sub-
|
||||
# directories (in 2 levels) under the output directory of each output format and
|
||||
@ -206,7 +206,7 @@ SEPARATE_MEMBER_PAGES = NO
|
||||
# uses this value to replace tabs by spaces in code fragments.
|
||||
# Minimum value: 1, maximum value: 16, default value: 4.
|
||||
|
||||
TAB_SIZE = 4
|
||||
TAB_SIZE = 2
|
||||
|
||||
# This tag can be used to specify a number of aliases that act as commands in
|
||||
# the documentation. An alias has the form:
|
||||
@ -297,7 +297,7 @@ AUTOLINK_SUPPORT = YES
|
||||
# diagrams that involve STL classes more complete and accurate.
|
||||
# The default value is: NO.
|
||||
|
||||
BUILTIN_STL_SUPPORT = NO
|
||||
BUILTIN_STL_SUPPORT = YES
|
||||
|
||||
# If you use Microsoft's C++/CLI language, you should set this option to YES to
|
||||
# enable parsing support.
|
||||
@ -734,7 +734,9 @@ WARN_LOGFILE =
|
||||
# spaces.
|
||||
# Note: If this tag is empty the current directory is searched.
|
||||
|
||||
INPUT = cutlass
|
||||
INPUT = include/cutlass tools/util/include/cutlass/ tools/library/include/cutlass/
|
||||
|
||||
INPUT += media/docs/doxygen_mainpage.md
|
||||
|
||||
# This tag can be used to specify the character encoding of the source files
|
||||
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
|
||||
@ -870,7 +872,7 @@ FILTER_SOURCE_PATTERNS =
|
||||
# (index.html). This can be useful if you have a project on for instance GitHub
|
||||
# and want to reuse the introduction page also for the doxygen output.
|
||||
|
||||
USE_MDFILE_AS_MAINPAGE =
|
||||
USE_MDFILE_AS_MAINPAGE = media/docs/doxygen_mainpage.md
|
||||
|
||||
#---------------------------------------------------------------------------
|
||||
# Configuration options related to source browsing
|
||||
@ -999,7 +1001,7 @@ GENERATE_HTML = YES
|
||||
# The default directory is: html.
|
||||
# This tag requires that the tag GENERATE_HTML is set to YES.
|
||||
|
||||
HTML_OUTPUT = generated-html
|
||||
HTML_OUTPUT =
|
||||
|
||||
# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each
|
||||
# generated HTML page (for example: .htm, .php, .asp).
|
||||
@ -1080,7 +1082,7 @@ HTML_EXTRA_FILES =
|
||||
# Minimum value: 0, maximum value: 359, default value: 220.
|
||||
# This tag requires that the tag GENERATE_HTML is set to YES.
|
||||
|
||||
HTML_COLORSTYLE_HUE = 82
|
||||
HTML_COLORSTYLE_HUE = 100
|
||||
|
||||
# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors
|
||||
# in the HTML output. For a value of 0 the output will use grayscales only. A
|
||||
@ -1088,7 +1090,7 @@ HTML_COLORSTYLE_HUE = 82
|
||||
# Minimum value: 0, maximum value: 255, default value: 100.
|
||||
# This tag requires that the tag GENERATE_HTML is set to YES.
|
||||
|
||||
HTML_COLORSTYLE_SAT = 100
|
||||
HTML_COLORSTYLE_SAT = 50
|
||||
|
||||
# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the
|
||||
# luminance component of the colors in the HTML output. Values below 100
|
||||
@ -1107,7 +1109,7 @@ HTML_COLORSTYLE_GAMMA = 80
|
||||
# The default value is: YES.
|
||||
# This tag requires that the tag GENERATE_HTML is set to YES.
|
||||
|
||||
HTML_TIMESTAMP = YES
|
||||
HTML_TIMESTAMP = NO
|
||||
|
||||
# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML
|
||||
# documentation will contain sections that can be hidden and shown after the
|
||||
|
||||
23
LICENSE.TXT
@ -1,23 +0,0 @@
|
||||
Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
* Neither the name of the NVIDIA CORPORATION nor the
|
||||
names of its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
23
LICENSE.txt
Normal file
@ -0,0 +1,23 @@
|
||||
Copyright (c) 2017 - 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
* Neither the name of the NVIDIA CORPORATION nor the
|
||||
names of its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
330
README.md
@ -1,10 +1,10 @@
|
||||

|
||||
|
||||
# CUTLASS 1.0
|
||||
# CUTLASS 2.2
|
||||
|
||||
_CUTLASS 1.0.1 - June 2018_
|
||||
_CUTLASS 2.2 - June 2020_
|
||||
|
||||
CUTLASS 1.0 is a collection of CUDA C++ template abstractions for implementing
|
||||
CUTLASS is a collection of CUDA C++ template abstractions for implementing
|
||||
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
|
||||
It incorporates strategies for hierarchical decomposition and data movement similar
|
||||
to those used to implement cuBLAS. CUTLASS decomposes these "moving parts" into
|
||||
@ -16,20 +16,46 @@ and applications.
|
||||
|
||||
To support a wide variety of applications, CUTLASS provides extensive support for
|
||||
mixed-precision computations, providing specialized data-movement and
|
||||
multiply-accumulate abstractions for 8-bit integer, half-precision floating
|
||||
point (FP16), single-precision floating point (FP32), and double-precision floating
|
||||
point (FP64) types. Furthermore, CUTLASS demonstrates CUDA's WMMA API for targeting
|
||||
the programmable, high-throughput _Tensor Cores_ provided by NVIDIA's Volta architecture
|
||||
and beyond.
|
||||
multiply-accumulate abstractions for half-precision floating
|
||||
point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
|
||||
single-precision floating point (FP32), double-precision floating
|
||||
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
|
||||
|
||||
CUTLASS 1.0 has changed substantially from our preview release described in
|
||||
the [CUTLASS Parallel For All](https://devblogs.nvidia.com/parallelforall/cutlass-linear-algebra-cuda)
|
||||
post. We have decomposed the structure of the GEMM computation into deeper, structured
|
||||
primitives for loading data, computing predicate masks, streaming data at each level of
|
||||
the GEMM hierarchy, and updating the output matrix.
|
||||
Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply operations
|
||||
targeting the programmable, high-throughput _Tensor Cores_ implemented by
|
||||
NVIDIA's Volta, Turing, and Ampere architectures.
|
||||
|
||||
CUTLASS 1.0 is described in the [Doxygen documentation](https://nvidia.github.io/cutlass)
|
||||
and our talk at the [GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
|
||||
See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.
|
||||
|
||||
See the [functionality listing](media/docs/functionality.md) for the list of operations
|
||||
supported at each level of the execution model hierarchy.
|
||||
|
||||
# What's New in CUTLASS 2.2
|
||||
|
||||
CUTLASS 2.2 is a significant update to CUTLASS adding:
|
||||
|
||||
- Coverage of [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
|
||||
- Tensor Core-accelerated GEMMs targeting Tensor Float 32, BFloat16, and double-precision data types
|
||||
- Deep software pipelines using asynchronous copy
|
||||
- Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745)
|
||||
- Intended to be compiled with [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit)
|
||||
|
||||
# What's New in CUTLASS 2.1
|
||||
|
||||
CUTLASS 2.1 is a minor update to CUTLASS 2.0 adding:
|
||||
|
||||
- [Planar complex GEMM kernels](/examples/10_planar_complex/planar_complex.cu) targeting Volta and Turing Tensor Cores
|
||||
- BLAS-style API to launch kernels compiled into the [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
|
||||
|
||||
# What's New in CUTLASS 2.0
|
||||
|
||||
CUTLASS 2.0 is a substantial refactoring from the previous version, intended to offer:
|
||||
|
||||
- Better performance over 1.x, particularly for kernels targeting Turing Tensor Cores
|
||||
- Robust and durable templates that reliably span the design space
|
||||
- Encapsulated functionality that may be reusable in other contexts
|
||||
|
||||
**See the [CHANGELOG](CHANGELOG.md) for more details.**
|
||||
|
||||
# Performance
|
||||
|
||||
@ -38,164 +64,256 @@ and our talk at the [GPU Technology Conference 2018](http://on-demand.gputechcon
|
||||
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
|
||||
they exhibit performance comparable to cuBLAS for scalar GEMM
|
||||
computations. The above figure shows CUTLASS performance relative to cuBLAS
|
||||
for large matrix dimensions (M=10240, N=K=4096) running on an NVIDIA Titan V GPU
|
||||
when compiled with CUDA 9.2.
|
||||
for large matrix dimensions on an NVIDIA GeForce 2080 Ti, an NVIDIA A100, and an NVIDIA TitanV
|
||||
using CUDA 11.0 Toolkit. Tensor Core operations are implemented using CUDA's
|
||||
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).
|
||||
|
||||
# Compatibility
|
||||
|
||||
CUTLASS requires CUDA 9 and performs best with [CUDA 9.2 Toolkit](https://developer.nvidia.com/cuda-toolkit) or later.
|
||||
CUTLASS requires a C++11 host compiler and
|
||||
performs best when built with the [CUDA 11.0 Toolkit](https://developer.nvidia.com/cuda-toolkit).
|
||||
It is compatible with CUDA 9.2, CUDA 10.0, CUDA 10.1, and CUDA 10.2.
|
||||
|
||||
We have tested the following environments.
|
||||
|
||||
|**Operating System** | **Compiler** |
|-----------------|----------|
| Windows 10 | Microsoft Visual Studio 2015|
| | Microsoft Visual Studio 2017|
| Ubuntu 14.04 | GCC 4.8.2 |
| Ubuntu 16.04 | GCC 5.4.0 |
| Ubuntu 18.04 | GCC 7.5.0 |
|
||||
|
||||
Additionally, CUTLASS may be built with clang.
|
||||
See [these instructions](media/docs/quickstart.md#clang) for more details.
|
||||
|
||||
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
|
||||
any Maxwell-, Pascal-, or Volta-architecture NVIDIA GPU.
|
||||
any Maxwell-, Pascal-, Volta-, Turing-, or NVIDIA Ampere-architecture NVIDIA GPU.
|
||||
|
||||
|**GPU**|
|---|
|NVIDIA GeForce 1080|
|NVIDIA TitanXP|
|NVIDIA Tesla P100|
|NVIDIA Tesla V100|
|NVIDIA TitanV|

|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**|
|---|---|---|---|
|NVIDIA Tesla P100|6.0|9.2| |
|NVIDIA GeForce 1080|6.1|9.2| |
|NVIDIA TitanXP|6.1|9.2| |
|NVIDIA Tesla V100|7.0|9.2|10.1|
|NVIDIA TitanV|7.0|9.2|10.1|
|NVIDIA GeForce RTX 2080 TI, 2080, 2070|7.5|10.0|10.2|
|NVIDIA Tesla T4|7.5|10.0|10.2|
|NVIDIA A100|8.0|11.0|11.0|
|
||||
|
||||
# Documentation
|
||||
|
||||
CUTLASS 2.2 is described in the following documents and the accompanying
|
||||
[Doxygen documentation](https://nvidia.github.io/cutlass).
|
||||
|
||||
- [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS
|
||||
- [Functionality](/media/docs/functionality.md) - summarizes functionality available in CUTLASS
|
||||
- [Efficient GEMM in CUDA](media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA
|
||||
- [GEMM API](media/docs/gemm_api.md) - describes the CUTLASS GEMM model and C++ template concepts
|
||||
- [Code Organization](media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project
|
||||
- [Terminology](media/docs/terminology.md) - describes terms used in the code
|
||||
- [Programming Guidelines](media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
|
||||
- [Fundamental types](media/docs/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
|
||||
- [Layouts](media/docs/layout.md) - describes layouts of matrices and tensors in memory
|
||||
- [Tile Iterators](media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
|
||||
- [CUTLASS Profiler](media/docs/profiler.md) - command-line driven profiling application
|
||||
- [CUTLASS Utilities](media/docs/utilities.md) - additional templates used to facilitate rapid development
|
||||
|
||||
We have also described the structure of an efficient GEMM in our talk at the
|
||||
[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
|
||||
|
||||
# Building CUTLASS
|
||||
|
||||
CUTLASS is a header-only template library and does not need to be built to be used by other
|
||||
projects. However, we distribute extensive unit tests and utility programs to demonstrate
|
||||
CUTLASS. These instructions are for building those test programs.
|
||||
projects. Client applications should target CUTLASS's `include/` directory in their include
|
||||
paths.
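
As a minimal sketch of what targeting the `include/` directory might look like from a client project's own `CMakeLists.txt`: the project name `my_gemm_app`, the source file `main.cu`, and the `CUTLASS_DIR` variable are hypothetical and would need to be adapted to the actual checkout location.

```cmake
cmake_minimum_required(VERSION 3.12)
project(my_gemm_app LANGUAGES CXX CUDA)

# Hypothetical cache variable pointing at a CUTLASS source checkout.
set(CUTLASS_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/cutlass"
    CACHE PATH "Path to the CUTLASS source tree")

add_executable(my_gemm_app main.cu)

# CUTLASS is header-only, so adding its include/ directory is sufficient.
target_include_directories(my_gemm_app PRIVATE "${CUTLASS_DIR}/include")

# CUTLASS requires at least C++11 for both host and device code.
set_target_properties(my_gemm_app PROPERTIES
  CXX_STANDARD 11
  CXX_STANDARD_REQUIRED ON
  CUDA_STANDARD 11
  CUDA_STANDARD_REQUIRED ON
)
```

Nothing from CUTLASS itself is compiled or linked in this arrangement; only the headers are made visible to the client target.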
|
||||
|
||||
CUTLASS's unit tests depend on Google Test which exists as a git submodule. You can fetch
|
||||
submodules as follows.
|
||||
CUTLASS unit tests, examples, and utilities can be built with CMake starting with version 3.12.
|
||||
Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
|
||||
on your system.
|
||||
|
||||
```
|
||||
$ git submodule update --init --recursive
|
||||
$ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc
|
||||
```
|
||||
|
||||
CUTLASS can be built with CMake starting with version 3.10. By default CUTLASS will build kernels
|
||||
for CUDA architecture versions 5.0, 6.0, 6.1 and 7.0. To reduce compile time you can specify
|
||||
Create a build directory within the CUTLASS project, then run CMake. By default CUTLASS will build kernels
|
||||
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, and 8.0. To reduce compile time you can specify
|
||||
the architectures to build CUTLASS for by changing the CMake configuration setting
|
||||
`CUTLASS_NVCC_ARCHS`.
|
||||
|
||||
Create a build directory within the CUTLASS project, then run CMake once.
|
||||
|
||||
```
|
||||
$ mkdir build && cd build
|
||||
$ cmake ..
|
||||
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS=75 # compiles for NVIDIA's Turing GPU architecture
|
||||
```
|
||||
|
||||
Compile the CUTLASS project by running Make. Include the -j argument to compile sources in
|
||||
parallel and speed up the build process.
|
||||
From the `build/` directory, compile and run the CUTLASS unit tests by building the target `test_unit` with make.
|
||||
|
||||
The unit tests are organized as several binaries mirroring the top-level namespaces of CUTLASS,
|
||||
and they may be executed in parallel via make's `-j` command line argument.
|
||||
|
||||
```
|
||||
$ make -j12
|
||||
...
|
||||
$
|
||||
```
|
||||
|
||||
Verify CUTLASS has been built correctly by running the unit tests from the build/ directory.
|
||||
|
||||
```
|
||||
$ ./tools/test/unit/cutlass_unit_test
|
||||
$ make test_unit -j
|
||||
...
|
||||
...
|
||||
...
|
||||
[----------] Global test environment tear-down
|
||||
[==========] 481 tests from 24 test cases ran. (5954 ms total)
|
||||
[ PASSED ] 481 tests.
|
||||
[==========] 946 tests from 57 test cases ran. (10812 ms total)
|
||||
[ PASSED ] 946 tests.
|
||||
```
|
||||
|
||||
All tests should pass, though the exact number of tests may vary over time.
|
||||
All tests should pass on supported platforms, though the exact number of tests may vary over time.
|
||||
|
||||
|
||||
# Project Structure
|
||||
|
||||
CUTLASS is arranged as a header-only library with several example test programs
|
||||
that demonstrate instantiating a GEMM task within a CUDA kernel. The Doxygen documentation
|
||||
provides a complete list of files, classes, and template concepts defined in the CUTLASS
|
||||
project. A brief summary is described below.
|
||||
CUTLASS is arranged as a header-only library along with Utilities, Tools, Examples, and unit tests.
|
||||
[Doxygen documentation](https://nvidia.github.io/cutlass) provides a complete list of files, classes,
|
||||
and template concepts defined in the CUTLASS project.
|
||||
|
||||
The CUTLASS library is defined in the cutlass/ directory and consists of CUDA C++ template
|
||||
classes and other definitions for implementing efficient GPU GEMM kernels. A set of core
|
||||
classes and templates define basic primitives that are then applied to compute GEMM via
|
||||
templates in the cutlass/gemm directory.
|
||||
A detailed explanation of the source code organization may be found in the
|
||||
[CUTLASS documentation](media/docs/code_organization.md), but several main components are summarized below.
|
||||
|
||||
## CUTLASS Template Library
|
||||
|
||||
```
|
||||
cutlass/
|
||||
gemm/
|
||||
util/
|
||||
<core API components>
|
||||
include/ # client applications should target this directory in their build's include paths
|
||||
|
||||
cutlass/ # CUDA Templates for Linear Algebra Subroutines and Solvers - headers only
|
||||
|
||||
arch/ # direct exposure of architecture features (including instruction-level GEMMs)
|
||||
|
||||
gemm/ # code specialized for general matrix product computations
|
||||
|
||||
layout/ # layout definitions for matrices, tensors, and other mathematical objects in memory
|
||||
|
||||
platform/ # CUDA-capable Standard Library components
|
||||
|
||||
reduction/ # bandwidth-limited reduction kernels that do not fit the "gemm" model
|
||||
|
||||
transform/ # code specialized for layout, type, and domain transformations
|
||||
|
||||
* # core vocabulary types, containers, and basic numeric operations
|
||||
```
|
||||
|
||||
Several tools and test programs are also distributed with the CUTLASS library. They are
|
||||
contained in the following directories.
|
||||
### CUTLASS SDK Examples
|
||||
|
||||
[CUTLASS SDK examples](/examples) apply CUTLASS templates to implement basic computations.
|
||||
|
||||
```
|
||||
examples/
|
||||
00_basic_gemm/ # launches a basic GEMM with single precision inputs and outputs
|
||||
|
||||
01_cutlass_utilities/ # demonstrates CUTLASS Utilities for allocating and initializing tensors
|
||||
|
||||
02_dump_reg_smem/ # debugging utilities for printing register and shared memory contents
|
||||
|
||||
03_visualize_layout/ # utility for visualizing all layout functions in CUTLASS
|
||||
|
||||
04_tile_iterator/ # example demonstrating an iterator over tiles in memory
|
||||
|
||||
05_batched_gemm/ # example demonstrating CUTLASS's batched strided GEMM operation
|
||||
|
||||
06_splitK_gemm/ # example demonstrating CUTLASS's Split-K parallel reduction kernel
|
||||
|
||||
07_volta_tensorop_gemm/ # example demonstrating mixed precision GEMM using Volta Tensor Cores
|
||||
|
||||
08_turing_tensorop_gemm/ # example demonstrating integer GEMM using Turing Tensor Cores
|
||||
|
||||
10_planar_complex/ # example demonstrating planar complex GEMM kernels
|
||||
|
||||
11_planar_complex_array/ # example demonstrating planar complex kernels with batch-specific problem sizes
|
||||
|
||||
12_gemm_bias_relu/ # example demonstrating GEMM fused with bias and relu
|
||||
|
||||
13_fused_two_gemms/ # example demonstrating two GEMMs fused in one kernel
|
||||
```
|
||||
|
||||
### Tools
|
||||
```
|
||||
tools/
|
||||
test/
|
||||
unit/
|
||||
core/
|
||||
gemm/
|
||||
perf/
|
||||
util/
|
||||
<utilities>
|
||||
library/ # CUTLASS Instance Library - contains instantiations of all supported CUTLASS templates
|
||||
include/
|
||||
cutlass/
|
||||
library/
|
||||
|
||||
profiler/ # CUTLASS Profiler - command-line utility for executing operations in the
|
||||
# CUTLASS Library
|
||||
|
||||
util/ # CUTLASS Utilities - contains numerous helper classes for
|
||||
include/ # managing tensors in device memory, reference
|
||||
cutlass/ # implementations for GEMM, random initialization
|
||||
util/ # of tensors, and I/O.
|
||||
```
|
||||
|
||||
### Test
|
||||
|
||||
The `test/unit/` directory consists of unit tests implemented with Google Test that demonstrate
|
||||
basic usage of Core API components and complete tests of the CUTLASS GEMM computations.
|
||||
|
||||
Instructions for building and running the Unit tests are described in the [Quickstart guide](media/docs/quickstart.md).
|
||||
|
||||
# Performance Profiling
|
||||
|
||||
The `test/perf/` directory contains a command-line utility for launching each of the GEMM kernels.
|
||||
Its usage is shown below.
|
||||
|
||||
Program usage:
|
||||
The `tools/profiler/` directory contains a command-line utility for launching each of the GEMM kernels.
|
||||
It can be built as follows:
|
||||
|
||||
```
|
||||
cutlass_perf_test [options]
|
||||
|
||||
--help
|
||||
--append=<true|false*> If true, appends output to existing CSV file. If false, overwrites.
|
||||
--alpha=<alpha> Value for alpha to be used in GEMM experiments
|
||||
--beta=<beta> Value for beta to be used in GEMM experiments
|
||||
--dist=<distribution> Describes the random distribution of each of the input matrix operands.
|
||||
--execution_mode=<mode> Specifies execution mode: profile, verify, single
|
||||
--output=<filename.csv> Writes summary of profiling to specified .csv file
|
||||
--iterations=<timing iterations> maximum number of iterations to execute when profiling
|
||||
--m=<height>[:max height[:step]] Height of GEMM problem (number of rows of C). May specify a range with optional step size.
|
||||
--n=<width>[:max width[:step]] Width of GEMM problem (number of columns of C). May specify a range with optional step size.
|
||||
--k=<depth>[:max depth[:step]] Size of inner dimension of A and B. May specify a range with optional step size.
|
||||
--kernels=<{s|d|h|i|wmma_}gemm_{nn,nt,tn,tt}> Select GEMM datatype and layout to use for tests
|
||||
--peak=<bool> If true, only reports peak performance per kernel after profiling specified problem space.
|
||||
--save_workspace={*never,incorrect,always} Specifies when to save the GEMM inputs and results to the filesystem.
|
||||
--seed=<seed> Random seed used by the random number generator in initializing input matrices.
|
||||
--tags=<column:tag,...> Inserts leading columns in output table and uniform values for each column.
|
||||
|
||||
|
||||
Example usage:
|
||||
|
||||
# Runs one problem size for all kernels
|
||||
$ ./tools/test/perf/cutlass_perf_test --m=10240 --n=1024 --k=1024
|
||||
|
||||
# Varies GEMM K dimension for SGEMM and IGEMM with column-major multiplicands
|
||||
$ ./tools/test/perf/cutlass_perf_test --m=10240 --n=4096 --k=1024:8192:128 --kernels=sgemm_nn,igemm_nn
|
||||
$ make cutlass_profiler -j
|
||||
```
|
||||
|
||||
To limit compilation time, only one tile size is instantiated for each data type, math instruction, and layout.
|
||||
To instantiate all, set the following CMake variable when running CMake from an empty `build/` directory.
|
||||
```
|
||||
$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=all
|
||||
...
|
||||
$ make cutlass_profiler -j
|
||||
```
|
||||
|
||||
An example command line for profiling SGEMM kernels is as follows:
|
||||
```
|
||||
$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=3456 --n=4096 --k=4096
|
||||
|
||||
=============================
|
||||
Problem ID: 1
|
||||
|
||||
Provider: CUTLASS
|
||||
OperationKind: gemm
|
||||
Operation: cutlass_simt_sgemm_128x128_8x2_nn_align1
|
||||
|
||||
Status: Success
|
||||
Verification: ON
|
||||
Disposition: Passed
|
||||
|
||||
cuBLAS: Passed
|
||||
|
||||
Arguments: --m=3456 --n=4096 --k=4096 --A=f32:column --B=f32:column --C=f32:column --alpha=1 --beta=0 --split_k_slices=1 \
|
||||
--batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 --stages=2 --warps_m=4 \
|
||||
--warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 --max_cc=1024
|
||||
|
||||
Bytes: 180355072 bytes
|
||||
FLOPs: 115992428544 flops
|
||||
|
||||
Runtime: 6.73655 ms
|
||||
Memory: 24.934 GiB/s
|
||||
|
||||
Math: 17218.4 GFLOP/s
|
||||
```
|
||||
|
||||
[Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md)
|
||||
|
||||
|
||||
# About
|
||||
|
||||
CUTLASS is released by NVIDIA Corporation as Open Source software under the
|
||||
3-clause "New" BSD license.
|
||||
CUTLASS is released by NVIDIA Corporation as Open Source software under the
|
||||
[3-clause "New" BSD license](LICENSE.txt).
|
||||
|
||||
# Contributors
|
||||
|
||||
The official list of CUTLASS developers and contributors is available here: [CONTRIBUTORS](CONTRIBUTORS.md).
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
|
||||
@ -13,8 +13,8 @@ function(FILE_TO_C_STRING FILENAME VARIABLE_NAME OUTPUT_STRING ZERO_TERMINATED)
|
||||
set(${OUTPUT_STRING} "${HEX_OUTPUT}" PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
message("Create header file for ${FILE_IN}")
|
||||
message("Create header file for ${FILE_OUT}")
|
||||
# message("Create header file for ${FILE_IN}")
|
||||
# message("Create header file for ${FILE_OUT}")
|
||||
file_to_c_string(${FILE_IN} ${VARIABLE_NAME} OUTPUT_STRING ZERO_TERMINATED)
|
||||
|
||||
set(RESULT "#pragma once\n")
|
||||
47
changelog.md
@ -1,47 +0,0 @@
|
||||
# NVIDIA CUTLASS Changelog
|
||||
|
||||
## [1.0.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.0.1) (2018-06-11)
|
||||
|
||||
* Intra-threadblock reduction added for small threadblock tile sizes
|
||||
* sgemm_64x128x16, sgemm_128x128x16, sgemm_128x64x16, sgemm_128x32x16, sgemm_64x64x16, sgemm_64x32x16
|
||||
* igemm_32x32x128
|
||||
* GEMM _K_ residue handled during prologue prior to mainloop
|
||||
* Replaced Google Test copy with submodule. Use `git submodule init`
|
||||
|
||||
## [1.0.0](https://github.com/NVIDIA/cutlass/commit/2028ebe120aab22bfd0b2baf8902d4c9627eb33f) (2018-05-16)
|
||||
|
||||
* Substantial rewrite to accommodate new architecture
|
||||
* Kernels: SGEMM, DGEMM, IGEMM, HGEMM, WMMA GEMM
|
||||
* Unit and performance tests
|
||||
|
||||
## [0.0.1](https://github.com/NVIDIA/cutlass/commit/d08ba8ac46e2fa3f745e070c390182edb56b2e91) (2017-12-04)
|
||||
|
||||
* Initial release
|
||||
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright notice, this list of
|
||||
conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
conditions and the following disclaimer in the documentation and/or other materials
|
||||
provided with the distribution.
|
||||
* Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
to endorse or promote products derived from this software without specific prior written
|
||||
permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
@ -1,17 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
function formatFiles {
|
||||
for f in `find "$1" -type f -name "*.$2"` ; do
|
||||
COMMAND="clang-format -i $f"
|
||||
echo $COMMAND
|
||||
$COMMAND
|
||||
done
|
||||
}
|
||||
|
||||
formatFiles "cutlass" "h"
|
||||
formatFiles "tools/test" "h"
|
||||
formatFiles "tools/test" "cpp"
|
||||
formatFiles "tools/util" "h"
|
||||
|
||||
7
cmake/NvidiaCutlassConfig.cmake
Normal file
@ -0,0 +1,7 @@
|
||||
get_filename_component(NvidiaCutlass_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH)
|
||||
|
||||
include(CMakeFindDependencyMacro)
|
||||
|
||||
if(NOT TARGET nvidia::cutlass::CUTLASS)
|
||||
include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake")
|
||||
endif()
|
||||
14
cmake/NvidiaCutlassPackageConfig.cmake
Normal file
@ -0,0 +1,14 @@
|
||||
set(CPACK_PACKAGE_NAME NvidiaCutlass)
|
||||
set(CPACK_PACKAGE_VENDOR NVIDIA)
|
||||
set(CPACK_PACKAGE_CONTACT info@nvidia.com)
|
||||
set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "CUTLASS CUDA C++ Template Linear Algebra Library")
|
||||
set(CPACK_PACKAGE_INSTALL_DIRECTORY ${CPACK_PACKAGE_NAME})
|
||||
set(CPACK_PACKAGE_VERSION_MAJOR ${PROJECT_VERSION_MAJOR})
|
||||
set(CPACK_PACKAGE_VERSION_MINOR ${PROJECT_VERSION_MINOR})
|
||||
set(CPACK_PACKAGE_VERSION_PATCH ${PROJECT_VERSION_PATCH})
|
||||
set(CPACK_VERBATIM_VARIABLES YES)
|
||||
# set(CPACK_PACKAGE_DESCRIPTION_FILE ${CMAKE_CURRENT_LIST_DIR}/Description.txt)
|
||||
# set(CPACK_RESOURCE_FILE_WELCOME ${CMAKE_CURRENT_LIST_DIR}/Welcome.txt)
|
||||
# set(CPACK_RESOURCE_FILE_LICENSE ${CMAKE_CURRENT_LIST_DIR}/License.txt)
|
||||
# set(CPACK_RESOURCE_FILE_README ${CMAKE_CURRENT_LIST_DIR}/Readme.txt)
|
||||
include(CPack)
|
||||
23
cmake/googletest.cmake
Normal file
@ -0,0 +1,23 @@
|
||||
include(FetchContent)
|
||||
|
||||
set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against")
|
||||
|
||||
if(GOOGLETEST_DIR)
|
||||
set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override")
|
||||
endif()
|
||||
|
||||
FetchContent_Declare(
|
||||
googletest
|
||||
GIT_REPOSITORY https://github.com/google/googletest.git
|
||||
GIT_TAG 0fe9660
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(googletest)
|
||||
|
||||
if(NOT googletest_POPULATED)
|
||||
FetchContent_Populate(googletest)
|
||||
if (MSVC)
|
||||
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
|
||||
endif()
|
||||
add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL)
|
||||
endif()
|
||||
43
cmake/nop.cu
Normal file
@ -0,0 +1,43 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Basic CUDA file for testing compiler flags.
|
||||
*/
|
||||
|
||||
__device__ int inner()
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
__global__ void test()
|
||||
{
|
||||
inner();
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
test<<<1,1>>>();
|
||||
return 0;
|
||||
}
|
||||
38
cmake/version.h.in
Normal file
@ -0,0 +1,38 @@
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#define CUTLASS_MAJOR @CUTLASS_VERSION_MAJOR@
|
||||
#define CUTLASS_MINOR @CUTLASS_VERSION_MINOR@
|
||||
#define CUTLASS_PATCH @CUTLASS_VERSION_PATCH@
|
||||
#define CUTLASS_BUILD @CUTLASS_VERSION_BUILD@
|
||||
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
inline uint32_t getVersion() {
|
||||
return CUTLASS_VERSION;
|
||||
}
|
||||
inline uint32_t getVersionMajor() {
|
||||
return CUTLASS_MAJOR;
|
||||
}
|
||||
inline uint32_t getVersionMinor() {
|
||||
return CUTLASS_MINOR;
|
||||
}
|
||||
inline uint32_t getVersionPatch() {
|
||||
return CUTLASS_PATCH;
|
||||
}
|
||||
inline uint32_t getVersionBuild() {
|
||||
return CUTLASS_BUILD + 0;
|
||||
}
|
||||
inline std::string getVersionString() {
|
||||
std::string version = "@CUTLASS_VERSION@";
|
||||
if (getVersionBuild()) {
|
||||
version += "." + std::to_string(getVersionBuild());
|
||||
}
|
||||
return version;
|
||||
}
|
||||
inline std::string getGitRevision() {
|
||||
return "@CUTLASS_REVISION@";
|
||||
}
|
||||
|
||||
} // namespace cutlass
|
||||
125
cuBLAS.cmake
Normal file
@ -0,0 +1,125 @@
|
||||
|
||||
message(STATUS "Configuring cublas ...")
|
||||
|
||||
if((DEFINED CUTLASS_ENABLE_CUBLAS AND NOT CUTLASS_ENABLE_CUBLAS) OR
|
||||
(DEFINED CUBLAS_ENABLED AND NOT CUBLAS_ENABLED))
|
||||
|
||||
# Don't add cuBLAS if it's defined and false, assume it's not found.
|
||||
|
||||
set(CUBLAS_FOUND OFF)
|
||||
message(STATUS "cuBLAS Disabled.")
|
||||
|
||||
elseif(NOT TARGET cublas)
|
||||
|
||||
find_path(
|
||||
_CUBLAS_INCLUDE_DIR
|
||||
NAMES cublas.h
|
||||
HINTS
|
||||
${CUBLAS_INCLUDE_PATH}
|
||||
ENV CUBLAS_INCLUDE_PATH
|
||||
${CUBLAS_PATH}
|
||||
ENV CUBLAS_PATH
|
||||
${CUDA_TOOLKIT_ROOT_DIR}
|
||||
PATH_SUFFIXES
|
||||
include
|
||||
)
|
||||
|
||||
find_library(
|
||||
_CUBLAS_LIBRARY
|
||||
NAMES cublas
|
||||
HINTS
|
||||
${CUBLAS_LIBRARY_PATH}
|
||||
ENV CUBLAS_LIBRARY_PATH
|
||||
${_CUBLAS_INCLUDE_DIR}/..
|
||||
${CUBLAS_PATH}
|
||||
ENV CUBLAS_PATH
|
||||
${CUDA_TOOLKIT_ROOT_DIR}
|
||||
PATH_SUFFIXES
|
||||
lib64
|
||||
lib/x64
|
||||
lib
|
||||
)
|
||||
|
||||
if(_CUBLAS_INCLUDE_DIR AND _CUBLAS_LIBRARY)
|
||||
|
||||
message(STATUS "cuBLAS: ${_CUBLAS_LIBRARY}")
|
||||
message(STATUS "cuBLAS: ${_CUBLAS_INCLUDE_DIR}")
|
||||
|
||||
set(CUBLAS_FOUND ON CACHE INTERNAL "cublas Library Found")
|
||||
set(CUBLAS_LIBRARY ${_CUBLAS_LIBRARY})
|
||||
set(CUBLAS_INCLUDE_DIR ${_CUBLAS_INCLUDE_DIR})
|
||||
|
||||
else()
|
||||
|
||||
message(STATUS "cublas not found.")
|
||||
set(CUBLAS_FOUND OFF CACHE INTERNAL "cublas Library Found")
|
||||
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
set(CUTLASS_ENABLE_CUBLAS ${CUBLAS_FOUND} CACHE BOOL "Enable CUTLASS to build with cuBLAS library.")
|
||||
|
||||
if(CUTLASS_ENABLE_CUBLAS AND NOT CUBLAS_FOUND)
|
||||
message(FATAL_ERROR "CUTLASS_ENABLE_CUBLAS enabled but cuBLAS library could not be found.")
|
||||
endif()
|
||||
|
||||
if(CUTLASS_ENABLE_CUBLAS AND NOT TARGET cublas)
|
||||
|
||||
if(WIN32)
|
||||
add_library(cublas STATIC IMPORTED GLOBAL)
|
||||
else()
|
||||
add_library(cublas SHARED IMPORTED GLOBAL)
|
||||
endif()
|
||||
|
||||
add_library(nvidia::cublas ALIAS cublas)
|
||||
|
||||
set_property(
|
||||
TARGET cublas
|
||||
PROPERTY IMPORTED_LOCATION
|
||||
${CUBLAS_LIBRARY})
|
||||
|
||||
target_include_directories(
|
||||
cublas
|
||||
INTERFACE
|
||||
$<INSTALL_INTERFACE:include>
|
||||
$<BUILD_INTERFACE:${CUBLAS_INCLUDE_DIR}>)
|
||||
|
||||
find_library(
|
||||
_CUBLASLT_LIBRARY
|
||||
NAMES cublasLt
|
||||
HINTS
|
||||
${CUBLAS_LIBRARY_PATH}
|
||||
ENV CUBLAS_LIBRARY_PATH
|
||||
${_CUBLAS_INCLUDE_DIR}/..
|
||||
${CUBLAS_PATH}
|
||||
ENV CUBLAS_PATH
|
||||
${CUDA_TOOLKIT_ROOT_DIR}
|
||||
PATH_SUFFIXES
|
||||
lib64
|
||||
lib/x64
|
||||
lib
|
||||
)
|
||||
|
||||
if(_CUBLASLT_LIBRARY AND NOT TARGET cublasLt)
|
||||
|
||||
if(WIN32)
|
||||
add_library(cublasLt STATIC IMPORTED GLOBAL)
|
||||
else()
|
||||
add_library(cublasLt SHARED IMPORTED GLOBAL)
|
||||
endif()
|
||||
|
||||
set_property(
|
||||
TARGET cublasLt
|
||||
PROPERTY IMPORTED_LOCATION
|
||||
${_CUBLASLT_LIBRARY})
|
||||
|
||||
add_library(nvidia::cublasLt ALIAS cublasLt)
|
||||
|
||||
target_link_libraries(cublas INTERFACE cublasLt)
|
||||
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
message(STATUS "Configuring cuBLAS ... done.")
|
||||
@ -1,102 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief Defines conversion operations among Fragments of different base type.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename InputFragment_, typename OutputFragment_>
|
||||
struct Convert {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename InputScalar_, typename OutputScalar_, int kScalars_>
|
||||
struct Convert<Fragment<InputScalar_, kScalars_>, Fragment<OutputScalar_, kScalars_> > {
|
||||
/// The input fragment.
|
||||
typedef Fragment<InputScalar_, kScalars_> InputFragment;
|
||||
/// The output fragment.
|
||||
typedef Fragment<OutputScalar_, kScalars_> OutputFragment;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE Convert() {}
|
||||
|
||||
/// Transform a fragment.
|
||||
CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
|
||||
transform(src, 0, dst);
|
||||
}
|
||||
|
||||
/// Transform a fragment.
|
||||
template <typename Fragment_>
|
||||
CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
|
||||
for (int i = 0; i < kScalars_; ++i) {
|
||||
dst[i] = static_cast<OutputScalar_>(src[i + offset]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Fragment_>
|
||||
struct Copy {
|
||||
/// The input fragment.
|
||||
typedef Fragment_ InputFragment;
|
||||
/// The output fragment.
|
||||
typedef Fragment_ OutputFragment;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE Copy() {}
|
||||
|
||||
/// Transform a fragment.
|
||||
CUTLASS_DEVICE void transform(Fragment_ const& src, Fragment_& dst) { transform(src, 0, dst); }
|
||||
|
||||
/// Transform a fragment.
|
||||
template <typename InputFragment_>
|
||||
CUTLASS_DEVICE void transform(InputFragment_ const& src, int offset, Fragment_& dst) {
|
||||
if (sizeof(typename Fragment_::Element) == 8) {
|
||||
uint64_t const* src_ptr = reinterpret_cast<uint64_t const*>(&src[offset]);
|
||||
uint64_t* dst_ptr = reinterpret_cast<uint64_t*>(&dst[0]);
|
||||
for (int i = 0; i < sizeof(Fragment_) / 8; ++i) {
|
||||
dst_ptr[i] = src_ptr[i];
|
||||
}
|
||||
} else {
|
||||
uint32_t const* src_ptr = reinterpret_cast<uint32_t const*>(&src[offset]);
|
||||
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&dst[0]);
|
||||
for (int i = 0; i < sizeof(Fragment_) / 4; ++i) {
|
||||
dst_ptr[i] = src_ptr[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
287
cutlass/coord.h
@ -1,287 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief A Coord is a coordinate of arbitrary rank into a tensor or matrix
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Describes identity elements
|
||||
struct Identity {
|
||||
/// Enumeration describing identity elements. Value assignments are significant.
|
||||
/// Feel free to add or multiply by these, respectively.
|
||||
enum Kind { Additive = 0, Multiplicative = 1 };
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Statically-sized array specifying Coords within a tensor
|
||||
template <int N_>
|
||||
struct Coord {
|
||||
//
|
||||
// Type and constant definitions
|
||||
//
|
||||
|
||||
static int const N = N_;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Indices
|
||||
int idx[N];
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor initializes uniformly
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord(int value = 0) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
idx[i] = value;
|
||||
}
|
||||
}
|
||||
|
||||
/// Constructs from an array of integers
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord(int _idx[]) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
idx[i] = _idx[i];
|
||||
}
|
||||
}
|
||||
|
||||
/// Element-wise addition
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord operator+(Coord const& b) const {
|
||||
Coord c;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c.idx[i] = idx[i] + b.idx[i];
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
/// Element-wise subtraction
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord operator-(Coord const& b) const {
|
||||
Coord c;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c.idx[i] = idx[i] - b.idx[i];
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
/// Element-wise multiplication
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord operator*(Coord const& b) const {
|
||||
Coord c;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c.idx[i] = idx[i] * b.idx[i];
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
/// Element-wise division
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord operator/(Coord const& b) const {
|
||||
Coord c;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c.idx[i] = idx[i] / b.idx[i];
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
/// In-place addition
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& operator+=(Coord const& b) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
idx[i] += b.idx[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place subtraction
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& operator-=(Coord const& b) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
idx[i] -= b.idx[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place multiplication
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& operator*=(Coord const& b) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
idx[i] *= b.idx[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place division
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& operator/=(Coord const& b) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
idx[i] /= b.idx[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Member access operator
|
||||
CUTLASS_HOST_DEVICE int& operator[](int dim) { return idx[dim]; }
|
||||
|
||||
/// Member access operator
|
||||
CUTLASS_HOST_DEVICE int const& operator[](int dim) const { return idx[dim]; }
|
||||
|
||||
/// Computes the dot product of two Coord instances, accumulating into the initial value `sum`
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE T dot(Coord const& b, T sum) const {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
sum += idx[i] * b.idx[i];
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
/// Computes the dot product of two Coord instances
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE T dot(Coord const& b) const {
|
||||
T sum = T(0);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
sum += idx[i] * b.idx[i];
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
/// Gets the index of a given Coord element
|
||||
template <int Dim>
|
||||
CUTLASS_HOST_DEVICE int& at() {
|
||||
return idx[Dim];
|
||||
}
|
||||
|
||||
/// Access via index; may limit unrolling potential
|
||||
CUTLASS_HOST_DEVICE
|
||||
int& at(int dim) { return idx[dim]; }
|
||||
|
||||
/// Gets the index of a given Coord element
|
||||
template <int Dim>
|
||||
CUTLASS_HOST_DEVICE int const& at() const {
|
||||
return idx[Dim];
|
||||
}
|
||||
|
||||
/// Access via index; may limit unrolling potential
|
||||
CUTLASS_HOST_DEVICE
|
||||
int const& at(int dim) const { return idx[dim]; }
|
||||
|
||||
/// Determines if two Coord<> objects are equal
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator==(Coord<N> const& b) const {
|
||||
bool equal = true;
|
||||
for (int i = 0; equal && i < N; ++i) {
|
||||
equal = (idx[i] == b.idx[i]);
|
||||
}
|
||||
return equal;
|
||||
}
|
||||
|
||||
/// Not equal
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator!=(Coord<N> const& b) const { return !(*this == b); }
|
||||
|
||||
/// Clamps a coordinate to a range specified by maximum and minimum values
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& clamp(Coord<N> const& max, Coord<N> const& min = Coord<N>()) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
idx[i] = __NV_STD_MAX(__NV_STD_MIN(idx[i], max.idx[i]), min.idx[i]);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Returns the product of all elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int count() const {
|
||||
int product = idx[0];
|
||||
for (int i = 1; i < N; ++i) {
|
||||
product *= idx[i];
|
||||
}
|
||||
return product;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to make a 1-element coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<1> make_Coord(int _0) {
|
||||
int values[1] = {_0};
|
||||
return Coord<1>(values);
|
||||
}
|
||||
|
||||
/// Helper to make a 2-element coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> make_Coord(int _0, int _1) {
|
||||
int values[2] = {_0, _1};
|
||||
return Coord<2>(values);
|
||||
}
|
||||
|
||||
/// Helper to make a 3-element coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<3> make_Coord(int _0, int _1, int _2) {
|
||||
int values[3] = {_0, _1, _2};
|
||||
return Coord<3>(values);
|
||||
}
|
||||
|
||||
/// Helper to make a 4-element coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> make_Coord(int _0, int _1, int _2, int _3) {
|
||||
int values[4] = {_0, _1, _2, _3};
|
||||
return Coord<4>(values);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Extracts the (h, w) components, elements 1 and 2, of a 3-element coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> get_Coord_hw(Coord<3> const& coord) { return make_Coord(coord[1], coord[2]); }
|
||||
|
||||
/// Extracts the (h, w) components, elements 1 and 2, of a 4-element coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> get_Coord_hw(Coord<4> const& coord) { return make_Coord(coord[1], coord[2]); }
|
||||
|
||||
/// Extracts the (h, w, c) components, elements 1 through 3, of a 4-element coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<3> get_Coord_hwc(Coord<4> const& coord) { return make_Coord(coord[1], coord[2], coord[3]); }
|
||||
|
||||
/// Extracts the (d, h, w) components, elements 0 through 2, of a 4-element coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<3> get_Coord_dhw(Coord<4> const& coord) { return make_Coord(coord[0], coord[1], coord[2]); }
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
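The element-wise operators, `dot`, and `count` shown above can be tried on the host without CUDA. The following is a minimal sketch under stated assumptions, not the CUTLASS header: `MiniCoord`, `make_mini_coord`, and `row_major_offset` are hypothetical names that mirror only a few members of `Coord<N>`, and the `CUTLASS_HOST_DEVICE` decoration is omitted.

```cpp
#include <cassert>

// Hypothetical stand-in for cutlass::Coord<N>; mirrors only the members used below.
template <int N>
struct MiniCoord {
  int idx[N];

  // Uniform initialization, as in Coord(int value = 0).
  explicit MiniCoord(int value = 0) {
    for (int i = 0; i < N; ++i) idx[i] = value;
  }

  // Element-wise addition.
  MiniCoord operator+(MiniCoord const& b) const {
    MiniCoord c;
    for (int i = 0; i < N; ++i) c.idx[i] = idx[i] + b.idx[i];
    return c;
  }

  // Dot product; when b holds strides this yields a linear offset.
  int dot(MiniCoord const& b) const {
    int sum = 0;
    for (int i = 0; i < N; ++i) sum += idx[i] * b.idx[i];
    return sum;
  }

  // Product of all elements, i.e. the number of points in an extent.
  int count() const {
    int product = idx[0];
    for (int i = 1; i < N; ++i) product *= idx[i];
    return product;
  }
};

// Analogue of make_Coord(int, int).
inline MiniCoord<2> make_mini_coord(int a, int b) {
  MiniCoord<2> c;
  c.idx[0] = a;
  c.idx[1] = b;
  return c;
}

// Linear offset of (row, col) in a row-major matrix with `columns` columns.
inline int row_major_offset(int row, int col, int columns) {
  assert(col < columns);
  return make_mini_coord(row, col).dot(make_mini_coord(columns, 1));
}
```

Taking the dot product of a coordinate with a stride vector is one common use of `dot`; `count` applied to an extent such as (4, 8) gives the number of elements it covers.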
@ -1,44 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
/*! \file
|
||||
\brief Helpers for printing cutlass/core objects
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <typeinfo>
|
||||
|
||||
#include <cutlass/coord.h>
|
||||
|
||||
template <int Rank>
|
||||
std::ostream& operator<<(std::ostream& out, cutlass::Coord<Rank> const& coord) {
|
||||
for (int i = 0; i < Rank; ++i) {
|
||||
out << (i ? ", " : "") << coord.idx[i];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
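The stream operator above prints the indices separated by `", "` with no surrounding brackets. The following is a small host-only sketch of that formatting, assuming a hypothetical `PrintableCoord` aggregate in place of `cutlass::Coord` and a hypothetical `format_coord` helper. It includes `<ostream>` because streaming values needs the complete `std::ostream` type.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>

// Hypothetical stand-in for cutlass::Coord<Rank>: just the index array.
template <int Rank>
struct PrintableCoord {
  int idx[Rank];
};

// Same formatting as the operator above: "i0, i1, ..., iN-1".
template <int Rank>
std::ostream& operator<<(std::ostream& out, PrintableCoord<Rank> const& coord) {
  for (int i = 0; i < Rank; ++i) {
    out << (i ? ", " : "") << coord.idx[i];
  }
  return out;
}

// Convenience wrapper (hypothetical) that renders a coordinate to a string.
template <int Rank>
std::string format_coord(PrintableCoord<Rank> const& coord) {
  std::ostringstream stream;
  stream << coord;
  assert(!stream.fail());
  return stream.str();
}
```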
@ -1,73 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Basic include for CUTLASS macros
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define CUTLASS_MAJOR 1
|
||||
#define CUTLASS_MINOR 0
|
||||
#define CUTLASS_PATCH 1
|
||||
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
|
||||
|
||||
#ifdef __NVCC__
|
||||
#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__
|
||||
#define CUTLASS_DEVICE __forceinline__ __device__
|
||||
#elif defined(__CUDACC_RTC__)
|
||||
#define CUTLASS_HOST_DEVICE __forceinline__ __device__
|
||||
#define CUTLASS_DEVICE __forceinline__ __device__
|
||||
#else
|
||||
#define CUTLASS_HOST_DEVICE
|
||||
// CUTLASS_DEVICE is an error if not compiling device code
|
||||
#endif
|
||||
|
||||
// CUTLASS_PRAGMA_UNROLL expands to an unroll pragma when compiling device code; otherwise it is empty
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#if defined(_MSC_VER)
|
||||
#define CUTLASS_PRAGMA_UNROLL __pragma("unroll")
|
||||
#define CUTLASS_PRAGMA_NO_UNROLL __pragma("unroll 1")
|
||||
#else
|
||||
#define CUTLASS_PRAGMA_UNROLL _Pragma("unroll")
|
||||
#define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1")
|
||||
#endif
|
||||
#else
|
||||
#define CUTLASS_PRAGMA_UNROLL
|
||||
#define CUTLASS_PRAGMA_NO_UNROLL
|
||||
#endif
|
||||
|
||||
#define CUTLASS_ASSERT(x) assert(x)
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/// NVIDIA GPU Warp size
|
||||
static const int kWarpSize = 32;
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
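`CUTLASS_VERSION` packs the three components into one decimal integer, so version 1.0.1 becomes 1*100 + 0*10 + 1 = 101. The sketch below reproduces that arithmetic with hypothetical `SKETCH_*` macros and helper functions (none of them are part of CUTLASS) and spells out the assumption the encoding relies on: minor and patch must each stay below 10 for the value to decode unambiguously.

```cpp
#include <cassert>

// Hypothetical copies of the version macros, renamed to avoid clashing with CUTLASS.
#define SKETCH_MAJOR 1
#define SKETCH_MINOR 0
#define SKETCH_PATCH 1
#define SKETCH_VERSION ((SKETCH_MAJOR)*100 + (SKETCH_MINOR)*10 + SKETCH_PATCH)

static_assert(SKETCH_VERSION == 101, "1.0.1 encodes as 101");

// Runtime form of the same encoding.
inline int encode_version(int major, int minor, int patch) {
  // The decimal packing only round-trips when minor and patch are single digits.
  assert(minor >= 0 && minor < 10 && patch >= 0 && patch < 10);
  return major * 100 + minor * 10 + patch;
}

// Inverse helpers, valid under the single-digit assumption above.
inline int version_major(int version) { return version / 100; }
inline int version_minor(int version) { return (version / 10) % 10; }
inline int version_patch(int version) { return version % 10; }
```

A comparison such as `#if CUTLASS_VERSION >= 101` then works as a plain integer test in the preprocessor.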
@ -1,278 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines Fragment, a statically-sized array for storing parts of matrices within a
|
||||
thread's registers.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <cutlass/shape.h>
|
||||
#include <cutlass/util/cutlass_math.h>
|
||||
#include <cutlass/vector.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup fragment_concept Fragment Concept
|
||||
@{
|
||||
|
||||
\ref fragment_concept is a statically sized array for storing parts of tiles held by individual CUDA
|
||||
threads.
|
||||
|
||||
@par \ref fragment_concept
|
||||
Types satisfying \ref fragment_concept define the following members
|
||||
- <b>Element</b> - type of each element held within the fragment
|
||||
- <b>kElements</b> - number of elements stored by the fragment
|
||||
- <b>clear()</b> - overwrites the fragment storage with zeros
|
||||
- <b>Element & operator[](int i)</b> - by-reference access of the ith element
|
||||
- <b>Element const & operator[](int i) const</b> - const by-reference access of the ith element
|
||||
@}
|
||||
*/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup fragment_iterator_concept Fragment Iterator Concept
|
||||
@{
|
||||
|
||||
\ref fragment_iterator_concept provides structured access to the elements within a fragment with an
|
||||
optional bitcast to the desired access type
|
||||
|
||||
@par \ref fragment_iterator_concept
|
||||
Types satisfying \ref fragment_iterator_concept define the following members
|
||||
- <b>AccessType& operator[](int i)</b> - provides access to the ith element of the fragment
|
||||
- <b>AccessType& at(int d, int h, int w, int c)</b> - applies \ref layout_concept to fragment and
|
||||
provides access to element at (d, h, w, c)
|
||||
|
||||
@}
|
||||
*/
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int kAlignment_>
|
||||
struct StorageType {
|
||||
typedef uint64_t Type;
|
||||
};
|
||||
template <>
|
||||
struct StorageType<4> {
|
||||
typedef uint32_t Type;
|
||||
};
|
||||
template <>
|
||||
struct StorageType<2> {
|
||||
typedef uint16_t Type;
|
||||
};
|
||||
template <>
|
||||
struct StorageType<1> {
|
||||
typedef uint8_t Type;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief A template defining \ref fragment_concept
|
||||
* @concept{fragment_concept}
|
||||
*/
|
||||
template <typename Element_, int kElements_, size_t kAlignment_ = 16>
|
||||
struct Fragment : public AlignedStruct<kAlignment_> {
|
||||
/// Make sure the alignment makes sense wrt the size of elements.
|
||||
static_assert(kAlignment_ == 16 || kAlignment_ >= sizeof(Element_), "Alignment is too small");
|
||||
/// Alignment must be a power of two
|
||||
static_assert(is_pow2<kAlignment_>::value, "Alignment must be a power of two");
|
||||
|
||||
/// This class.
|
||||
typedef Fragment<Element_, kElements_, kAlignment_> This_;
|
||||
/// The element.
|
||||
typedef Element_ Element;
|
||||
/// The number of elements.
|
||||
static int const kElements = kElements_;
|
||||
|
||||
/// Clear a fragment.
|
||||
CUTLASS_DEVICE void clear() {
|
||||
// Avoid element-wise access for sub 32b element type
|
||||
if (kAlignment_ >= 8 && (kElements * sizeof(Element)) % 8 == 0) {
|
||||
uint64_t* ptr = reinterpret_cast<uint64_t*>(storage);
|
||||
for (int i = 0; i < (kElements * sizeof(Element)) / 8; ++i) {
|
||||
ptr[i] = uint64_t(0);
|
||||
}
|
||||
} else if (kAlignment_ >= 4 && (kElements * sizeof(Element)) % 4 == 0) {
|
||||
uint32_t* ptr = reinterpret_cast<uint32_t*>(storage);
|
||||
for (int i = 0; i < (kElements * sizeof(Element)) / 4; ++i) {
|
||||
ptr[i] = uint32_t(0);
|
||||
}
|
||||
} else if (kAlignment_ >= 2 && (kElements * sizeof(Element)) % 2 == 0) {
|
||||
uint16_t* ptr = reinterpret_cast<uint16_t*>(storage);
|
||||
for (int i = 0; i < (kElements * sizeof(Element)) / 2; ++i) {
|
||||
ptr[i] = uint16_t(0);
|
||||
}
|
||||
} else {
|
||||
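// Clear whole storage words; storage[] holds kStorageCount entries, which may differ from kElements
for (int i = 0; i < kStorageCount; ++i) {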
|
||||
storage[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE Element& operator[](int i) {
|
||||
assert(i < kElements_);
|
||||
return reinterpret_cast<Element*>(storage)[i];
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE Element const& operator[](int i) const {
|
||||
assert(i < kElements_);
|
||||
return reinterpret_cast<Element const*>(storage)[i];
|
||||
}
|
||||
|
||||
private:
|
||||
/// Storage type to use for Elements
|
||||
typedef typename StorageType<kAlignment_>::Type StorageType;
|
||||
|
||||
/// Number of elements in the storage
|
||||
static int const kStorageCount =
|
||||
(sizeof(Element_) * kElements_ + sizeof(StorageType) - 1) / sizeof(StorageType);
|
||||
/// The storage.
|
||||
StorageType storage[kStorageCount];
|
||||
|
||||
/// Ensure that there's enough storage for all elements
|
||||
static_assert(sizeof(StorageType) <= kAlignment_, "StorageType is too big for given alignment");
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief A template defining \ref fragment_iterator_concept
|
||||
* @concept{fragment_iterator_concept}
|
||||
*/
|
||||
template <typename Fragment_, typename Iterations_, typename AccessType_>
|
||||
struct FragmentIterator {
|
||||
/// This class.
|
||||
typedef FragmentIterator<Fragment_, Iterations_, AccessType_> This_;
|
||||
/// The fragment.
|
||||
typedef Fragment_ Fragment;
|
||||
/// The number of iterations.
|
||||
typedef Iterations_ Iterations;
|
||||
/// The access type.
|
||||
typedef AccessType_ AccessType;
|
||||
|
||||
/// The element.
|
||||
typedef typename Fragment::Element Element;
|
||||
/// The number of elements per access.
|
||||
static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
|
||||
/// The shape of the fragment.
|
||||
typedef typename ShapeMul<Iterations, Shape<1, 1, 1, kElementsPerAccess> >::Shape FragmentShape;
|
||||
/// The linear strides for iterations.
|
||||
typedef typename ShapeStrides<FragmentShape, kElementsPerAccess>::Shape Strides;
|
||||
|
||||
/// Ctor.
|
||||
template <typename OtherFragment_>
|
||||
CUTLASS_DEVICE FragmentIterator(OtherFragment_& fragment, int offset = 0)
|
||||
: pointer(reinterpret_cast<Element*>(&fragment[offset])) {
|
||||
static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
|
||||
int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
|
||||
return reinterpret_cast<AccessType const&>(pointer[imm]);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType& at(int d, int h, int w, int c = 0) {
|
||||
int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
|
||||
return reinterpret_cast<AccessType&>(pointer[imm]);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType const& operator[](int i) const {
|
||||
return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType& operator[](int i) {
|
||||
return reinterpret_cast<AccessType&>(pointer[i * kElementsPerAccess]);
|
||||
}
|
||||
|
||||
/// Is the iterator valid?
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
|
||||
|
||||
/// The pointer.
|
||||
Element* pointer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Fragment_, typename Iterations_, typename AccessType_>
|
||||
struct FragmentConstIterator {
|
||||
/// This class.
|
||||
typedef FragmentConstIterator<Fragment_, Iterations_, AccessType_> This_;
|
||||
/// The fragment.
|
||||
typedef Fragment_ Fragment;
|
||||
/// The number of iterations.
|
||||
typedef Iterations_ Iterations;
|
||||
/// The access type.
|
||||
typedef AccessType_ AccessType;
|
||||
|
||||
/// The element.
|
||||
typedef typename Fragment::Element Element;
|
||||
/// The number of elements per access.
|
||||
static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
|
||||
/// The shape of the fragment.
|
||||
typedef typename ShapeMul<Iterations, Shape<1, 1, 1, kElementsPerAccess> >::Shape FragmentShape;
|
||||
/// The linear strides for iterations.
|
||||
typedef typename ShapeStrides<FragmentShape, kElementsPerAccess>::Shape IterationsStrides;
|
||||
|
||||
/// Ctor.
|
||||
template <typename OtherFragment_>
|
||||
CUTLASS_DEVICE FragmentConstIterator(OtherFragment_& fragment, int offset = 0)
|
||||
: pointer(reinterpret_cast<Element const*>(&fragment[offset])) {
|
||||
static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
|
||||
}
|
||||
/// Create from non-constant FragmentIterator
|
||||
CUTLASS_DEVICE FragmentConstIterator(
|
||||
FragmentIterator<Fragment_, Iterations_, AccessType_> const& rhs_)
|
||||
: pointer(reinterpret_cast<Element const*>(rhs_.pointer)) {}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
|
||||
int const imm = ComputeOffsetFromStrides<IterationsStrides>::get(d, h, w, c);
|
||||
return reinterpret_cast<AccessType const&>(pointer[imm]);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType const& operator[](int i) const {
|
||||
return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
|
||||
}
|
||||
|
||||
/// Is the iterator valid?
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
|
||||
|
||||
/// The pointer.
|
||||
Element const* pointer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
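`Fragment` stores its elements in an array of fixed-width words: `StorageType<kAlignment_>` picks the word type (1, 2, and 4 bytes for those alignments, `uint64_t` otherwise) and `kStorageCount` is a ceiling division of the payload size by the word size. The sketch below works that arithmetic through on the host. The function names are hypothetical, and sizes are passed as plain byte counts rather than taken from real element types.

```cpp
#include <cassert>
#include <cstddef>

// Width in bytes of the word type that StorageType<kAlignment> selects:
// alignments 1, 2, and 4 map to themselves; every other value falls back to 8.
inline std::size_t storage_word_bytes(std::size_t alignment) {
  return (alignment == 1 || alignment == 2 || alignment == 4) ? alignment : 8;
}

// Number of storage words needed for `elements` items of `element_bytes` each,
// rounded up: the same ceiling division used for kStorageCount.
inline std::size_t storage_words(std::size_t element_bytes, std::size_t elements,
                                 std::size_t word_bytes) {
  assert(word_bytes > 0);
  return (element_bytes * elements + word_bytes - 1) / word_bytes;
}
```

For example, eight 4-byte elements at the default alignment of 16 occupy four 8-byte words, while three 1-byte elements at the same alignment still occupy one full word, so the storage can be larger than the payload.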
@ -1,135 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines accessors for loading and storing fragments to memory efficiently.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/load_store.h>
|
||||
#include <cutlass/vector.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <IteratorFragment::Kind kIteratorFragment,
|
||||
int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentLoad {};
|
||||
|
||||
template <int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentLoad<IteratorFragment::kWmmaMatrix,
|
||||
kAccessSize,
|
||||
Scalar_,
|
||||
Memory_,
|
||||
FragmentElement_,
|
||||
kStride> {
|
||||
/// The output type.
|
||||
typedef FragmentElement_ AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) {
|
||||
value.load(&pointer[offset], kStride);
|
||||
}
|
||||
};
|
||||
|
||||
template <int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentLoad<IteratorFragment::kScalar,
|
||||
kAccessSize,
|
||||
Scalar_,
|
||||
Memory_,
|
||||
FragmentElement_,
|
||||
kStride> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) {
|
||||
Load<Scalar_, kAccessSize, Memory_>::load(value, pointer, offset);
|
||||
}
|
||||
};
|
||||
|
||||
template <IteratorFragment::Kind kIteratorFragment,
|
||||
int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentStore {};
|
||||
|
||||
template <int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentStore<IteratorFragment::kWmmaMatrix,
|
||||
kAccessSize,
|
||||
Scalar_,
|
||||
Memory_,
|
||||
FragmentElement_,
|
||||
kStride> {
|
||||
/// The input type.
|
||||
typedef FragmentElement_ AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) {
|
||||
value.store(&pointer[offset], kStride);
|
||||
}
|
||||
};
|
||||
|
||||
template <int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentStore<IteratorFragment::kScalar,
|
||||
kAccessSize,
|
||||
Scalar_,
|
||||
Memory_,
|
||||
FragmentElement_,
|
||||
kStride> {
|
||||
/// The input type.
|
||||
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) {
|
||||
Store<Scalar_, kAccessSize, Memory_>::store(value, pointer, offset);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
@ -1,149 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines multiply-add operations on fragments within a thread.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_>
|
||||
struct FragmentMultiplyAdd {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<1, 1, 1, 1> InstructionShape;
|
||||
/// The type for A.
|
||||
typedef Scalar_ ScalarA;
|
||||
/// The type for B.
|
||||
typedef Scalar_ ScalarB;
|
||||
/// The type for C and D.
|
||||
typedef Scalar_ ScalarC;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE FragmentMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b.
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply(Scalar_ a, FragmentB_ const& b, FragmentCd_& d) {
|
||||
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
|
||||
for (int j = 0; j < FragmentCd_::kElements; ++j) {
|
||||
d[j] = a * b[j * kReduction + 0];
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d[j] += a * b[j * kReduction + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Multiply-add : d = a*b + c.
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply_add(Scalar_ a,
|
||||
FragmentB_ const& b,
|
||||
FragmentCd_ const& c,
|
||||
FragmentCd_& d) {
|
||||
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
|
||||
for (int j = 0; j < FragmentCd_::kElements; ++j) {
|
||||
d[j] = a * b[j * kReduction + 0] + c[j];
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d[j] += a * b[j * kReduction + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
|
||||
template <>
|
||||
struct FragmentMultiplyAdd<half> {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<1, 1, 2, 1> InstructionShape;
|
||||
/// The type for A.
|
||||
typedef half ScalarA;
|
||||
/// The type for B.
|
||||
typedef half ScalarB;
|
||||
/// The type for C and D.
|
||||
typedef half ScalarC;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE FragmentMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b.
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply(half a, FragmentB_ const& b, FragmentCd_& d) {
|
||||
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
|
||||
|
||||
// Assemble a half2 from a.
|
||||
__half2 const a_half2 = __half2half2(a);
|
||||
// The input.
|
||||
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
|
||||
// The output.
|
||||
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
|
||||
|
||||
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
|
||||
for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
|
||||
d_half2[j] = __hmul2(a_half2, b_half2[j * kReduction + 0]);
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Multiply : d = a*b + c.
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply_add(half a,
|
||||
FragmentB_ const& b,
|
||||
FragmentCd_ const& c,
|
||||
FragmentCd_& d) {
|
||||
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
|
||||
// Assemble a half2 from a.
|
||||
__half2 const a_half2 = __half2half2(a);
|
||||
// The inputs.
|
||||
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
|
||||
__half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);
|
||||
// The output.
|
||||
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
|
||||
|
||||
int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements);
|
||||
for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
|
||||
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + 0], c_half2[j]);
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
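Review note: the generic `FragmentMultiplyAdd::multiply_add` above folds `kReduction = FragmentB_::kElements / FragmentCd_::kElements` consecutive elements of `b` into each output element. The following is a minimal host-side sketch of that loop, not the CUTLASS implementation: `std::array` stands in for the fragment types, and the names `fragment_multiply_add`, `NB`, and `NCd` are hypothetical.

```cpp
#include <array>
#include <cstddef>

// Hypothetical host-side stand-in for FragmentMultiplyAdd<Scalar_>::multiply_add.
// NB and NCd play the role of FragmentB_::kElements and FragmentCd_::kElements.
template <typename Scalar, std::size_t NB, std::size_t NCd>
void fragment_multiply_add(Scalar a,
                           std::array<Scalar, NB> const& b,
                           std::array<Scalar, NCd> const& c,
                           std::array<Scalar, NCd>& d) {
  static_assert(NB % NCd == 0, "b must hold a whole number of elements per output");
  constexpr std::size_t kReduction = NB / NCd;

  for (std::size_t j = 0; j < NCd; ++j) {
    // The first term also adds c[j], as in the original loop.
    d[j] = a * b[j * kReduction] + c[j];
    // The remaining kReduction - 1 terms accumulate into the same output.
    for (std::size_t k = 1; k < kReduction; ++k) {
      d[j] += a * b[j * kReduction + k];
    }
  }
}
```

With four elements in `b` and two in `d`, `kReduction` is 2, so each output sums two scaled inputs plus its `c` term.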
|
||||
@ -1,57 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines abstractions for efficiently clearing accumulator tiles.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/vector.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int kLanes_ = 1>
|
||||
struct ClearAccumulators {
|
||||
/// The shared storage.
|
||||
struct SharedStorage {};
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ClearAccumulators() {}
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ClearAccumulators(SharedStorage& shared_storage) {}
|
||||
|
||||
/// Clear the fragment.
|
||||
template <typename Fragment_>
|
||||
CUTLASS_DEVICE void clear(Fragment_& fragment) {
|
||||
fragment.clear();
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,127 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural traits of double-precision GEMM.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/gemm_epilogue.h>
|
||||
#include <cutlass/gemm/gemm_epilogue_traits.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/gemm/gemm_shared_tile.h>
|
||||
#include <cutlass/gemm/gemm_traits.h>
|
||||
#include <cutlass/gemm/thread_multiply_add.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_ = 1>
|
||||
struct DgemmConfig
|
||||
: public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
double,
|
||||
/// The scalar type for B.
|
||||
double,
|
||||
/// The scalar type for C.
|
||||
double,
|
||||
/// The scalar type for D.
|
||||
double,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, double, double, double>,
|
||||
/// The number of scalars per LDG for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDS for A.
|
||||
2,
|
||||
/// The number of scalars per LDG for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
2,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
1,
|
||||
/// The number of scalars per STS for D.
|
||||
2,
|
||||
/// The number of scalars per LDS for D.
|
||||
1,
|
||||
/// The number of stages in shared memory.
|
||||
2> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
typename OutputTile_ = Shape<8, 64, 128>,
|
||||
/// The functor to use in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<double>,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_ = Shape<8, 8, 8>,
|
||||
/// The number of doubles loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of doubles loaded in one LDG for B.
|
||||
int kScalarsPerLdgB_ = 1,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The DGEMM config.
|
||||
typename GemmConfig_ =
|
||||
DgemmConfig<OutputTile_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_>,
|
||||
/// The traits class for the epilogue.
|
||||
typename GemmEpilogueTraits_ =
|
||||
SimplifiedGemmEpilogueTraits<GemmConfig_, EpilogueFunctor_, Index_> >
|
||||
struct DgemmTraits : public SimplifiedGemmTraits<
|
||||
// The layout for A.
|
||||
kLayoutA_,
|
||||
// The layout for B.
|
||||
kLayoutB_,
|
||||
// The config.
|
||||
GemmConfig_,
|
||||
// The epilogue.
|
||||
GemmEpilogue<GemmEpilogueTraits_>,
|
||||
// The index.
|
||||
Index_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,344 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements a software-pipelined efficient GEMM.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/util/platform.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Gemm_>
|
||||
__global__ /*__launch_bounds__(Gemm_::kThreads)*/ void gemm_kernel(typename Gemm_::Params params) {
|
||||
// Declare shared memory.
|
||||
__shared__ typename Gemm_::SharedStorage shared_storage;
|
||||
|
||||
// Construct the GEMM object.
|
||||
Gemm_ gemm(params, shared_storage);
|
||||
// Run GEMM.
|
||||
gemm.multiply_add();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, typename Index_ = int>
|
||||
struct GemmDesc {
|
||||
/// The dimensions of the GEMM.
|
||||
Index_ m, n, k;
|
||||
/// The alpha/beta scaling values.
|
||||
Scalar_ alpha, beta;
|
||||
/// The source matrix A.
|
||||
void const* d_a;
|
||||
/// The stride for A.
|
||||
Index_ lda;
|
||||
/// The source matrix B.
|
||||
void const* d_b;
|
||||
/// The stride for B.
|
||||
Index_ ldb;
|
||||
/// The source matrix C.
|
||||
void const* d_c;
|
||||
/// The stride for C.
|
||||
Index_ ldc;
|
||||
/// The destination matrix D.
|
||||
void* d_d;
|
||||
/// The stride for D.
|
||||
Index_ ldd;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmTraits_>
|
||||
struct Gemm {
|
||||
/// This class.
|
||||
typedef Gemm<GemmTraits_> This_;
|
||||
/// The traits.
|
||||
typedef GemmTraits_ Traits;
|
||||
/// The shared storage.
|
||||
typedef typename Traits::SharedStorage SharedStorage;
|
||||
|
||||
/// The scalar for A.
|
||||
typedef typename Traits::ScalarA ScalarA;
|
||||
/// The scalar for B.
|
||||
typedef typename Traits::ScalarB ScalarB;
|
||||
/// The scalar in the epilogue.
|
||||
typedef typename Traits::Epilogue::Scalar ScalarEpilogue;
|
||||
/// The scalar for C.
|
||||
typedef typename Traits::Epilogue::ScalarC ScalarC;
|
||||
/// The scalar for D.
|
||||
typedef typename Traits::Epilogue::ScalarD ScalarD;
|
||||
/// The index.
|
||||
typedef typename Traits::Index Index;
|
||||
|
||||
/// The number of threads.
|
||||
static int const kThreads = Traits::GemmConfig::kThreads;
|
||||
|
||||
/// The params.
|
||||
struct Params : public Traits::Params {
|
||||
CUTLASS_HOST_DEVICE int initialize(Index m,
|
||||
Index n,
|
||||
Index k,
|
||||
ScalarEpilogue alpha,
|
||||
ScalarA const* d_a,
|
||||
Index lda,
|
||||
ScalarB const* d_b,
|
||||
Index ldb,
|
||||
ScalarEpilogue beta,
|
||||
ScalarC const* d_c,
|
||||
Index ldc,
|
||||
ScalarD* d_d,
|
||||
Index ldd) {
|
||||
GemmDesc<ScalarEpilogue, Index> desc;
|
||||
desc.m = m;
|
||||
desc.n = n;
|
||||
desc.k = k;
|
||||
desc.alpha = alpha;
|
||||
desc.beta = beta;
|
||||
desc.d_a = reinterpret_cast<void const*>(d_a);
|
||||
desc.lda = lda;
|
||||
desc.d_b = reinterpret_cast<void const*>(d_b);
|
||||
desc.ldb = ldb;
|
||||
desc.d_c = reinterpret_cast<void const*>(d_c);
|
||||
desc.ldc = ldc;
|
||||
desc.d_d = reinterpret_cast<void*>(d_d);
|
||||
desc.ldd = ldd;
|
||||
return Traits::Params::initialize(desc);
|
||||
}
|
||||
};
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
/// Launch the kernel.
|
||||
static __host__ cudaError_t launch(Params const& params,
|
||||
cudaStream_t stream = cudaStreamDefault) {
|
||||
// Setup the grid.
|
||||
dim3 grid;
|
||||
grid.x = (params.m + Traits::OutputTile::kW - 1) / Traits::OutputTile::kW;
|
||||
grid.y = (params.n + Traits::OutputTile::kH - 1) / Traits::OutputTile::kH;
|
||||
|
||||
// The number of threads.
|
||||
dim3 block;
|
||||
block.x = kThreads;
|
||||
|
||||
// Launch the kernel.
|
||||
void const* params_ = reinterpret_cast<void const*>(&params);
|
||||
|
||||
return cudaLaunchKernel(reinterpret_cast<void*>(&gemm_kernel<This_>),
|
||||
grid,
|
||||
block,
|
||||
const_cast<void**>(&params_),
|
||||
0,
|
||||
stream);
|
||||
}
|
||||
|
||||
/// Launch the kernel.
|
||||
static __host__ cudaError_t launch(CUfunction kernel,
|
||||
Params const& params,
|
||||
CUstream stream = CU_STREAM_LEGACY) {
|
||||
// Setup the grid.
|
||||
dim3 grid;
|
||||
grid.x = (params.m + Traits::OutputTile::kW - 1) / Traits::OutputTile::kW;
|
||||
grid.y = (params.n + Traits::OutputTile::kH - 1) / Traits::OutputTile::kH;
|
||||
|
||||
// The number of threads.
|
||||
dim3 block;
|
||||
block.x = kThreads;
|
||||
|
||||
// Launch the kernel.
|
||||
void* params_[] = {const_cast<void*>(reinterpret_cast<void const*>(&params))};
|
||||
|
||||
// return cudaLaunchKernel(reinterpret_cast<void*>(&gemm_kernel<This_>), grid, block,
|
||||
// const_cast<void**>(&params_), 0, stream);
|
||||
CUresult result = cuLaunchKernel(
|
||||
kernel, grid.x, grid.y, grid.z, block.x, block.y, block.z, 0, stream, params_, 0);
|
||||
|
||||
if (result != CUDA_SUCCESS) {
|
||||
return cudaErrorLaunchFailure;
|
||||
}
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE Gemm(Params const& params_, SharedStorage& shared_storage_)
|
||||
: params(params_), shared_storage(shared_storage_) {}
|
||||
|
||||
/// Consume a single iteration of the loop.
|
||||
template <bool kIsLastIteration>
|
||||
CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_stream,
|
||||
typename Traits::SharedLoadStream& shared_load_stream,
|
||||
typename Traits::MultiplyAdd::Accumulators& accumulators,
|
||||
Index outer_k) {
|
||||
// If that's the last "load iteration" update the predicates.
|
||||
if (!kIsLastIteration) {
|
||||
global_stream.move_to_residue<false>(outer_k);
|
||||
}
|
||||
|
||||
// Load data for the next iteration of the main loop.
|
||||
if (!kIsLastIteration) {
|
||||
global_stream.copy();
|
||||
}
|
||||
|
||||
// The unrolling steps for the main loop.
|
||||
int const kUnrollingSteps =
|
||||
Traits::MultiplyAdd::AccumulatorsPerWarp::kD / Traits::MultiplyAdd::InstructionShape::kD;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int step = 0; step < kUnrollingSteps - 1; ++step) {
|
||||
// Trigger the copy from shared memory for the next A/B values.
|
||||
shared_load_stream.copy(step + 1);
|
||||
// Make sure the values are available for the current iteration to do the multiply-add.
|
||||
shared_load_stream.commit(step);
|
||||
|
||||
// Do the math on the fragments of the current iteration.
|
||||
typename Traits::MultiplyAdd multiply_add;
|
||||
multiply_add.multiply_add(shared_load_stream.fragment_a(step),
|
||||
shared_load_stream.fragment_b(step),
|
||||
accumulators,
|
||||
accumulators);
|
||||
}
|
||||
|
||||
// Make sure the data from shared memory has been entirely consumed.
|
||||
Traits::shared_load_fence(true);
|
||||
|
||||
// Commit the data in shared memory for A/B.
|
||||
if (!kIsLastIteration) {
|
||||
global_stream.commit();
|
||||
}
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
Traits::shared_store_fence(true);
|
||||
|
||||
// Trigger the loads for the next iteration (if needed).
|
||||
if (!kIsLastIteration) {
|
||||
// Move to the next stage for the load (if it makes sense).
|
||||
shared_load_stream.inc_stage();
|
||||
// Trigger the copy from shared memory for the next loop iteration.
|
||||
shared_load_stream.copy(0);
|
||||
}
|
||||
|
||||
// Make sure the values are available for the current iteration to do the multiply-add.
|
||||
shared_load_stream.commit(kUnrollingSteps - 1);
|
||||
|
||||
// Do the math on the fragments of the current iteration.
|
||||
typename Traits::MultiplyAdd multiply_add;
|
||||
multiply_add.multiply_add(shared_load_stream.fragment_a(kUnrollingSteps - 1),
|
||||
shared_load_stream.fragment_b(kUnrollingSteps - 1),
|
||||
accumulators,
|
||||
accumulators);
|
||||
}
|
||||
|
||||
/// Do the GEMM.
|
||||
CUTLASS_DEVICE void multiply_add() {
|
||||
// Swizzle the IDs of the block (to enable better cache behavior).
|
||||
typename Traits::BlockSwizzle block_swizzle;
|
||||
dim3 block = block_swizzle.swizzle();
|
||||
|
||||
// Scale the id.
|
||||
block.x *= Traits::OutputTile::kW;
|
||||
block.y *= Traits::OutputTile::kH;
|
||||
|
||||
// We may want to use shared memory to clear the registers.
|
||||
typedef typename Traits::ClearAccumulators ClearAccumulators;
|
||||
|
||||
// The streams to read A/B from global memory to shared memory.
|
||||
typename Traits::GlobalLoadStream global_stream(params, shared_storage, block);
|
||||
|
||||
// Create the accumulator clear.
|
||||
ClearAccumulators clear(shared_storage.main_loop.clear);
|
||||
|
||||
// By how much we unroll the main loop.
|
||||
Index const kUnroll = static_cast<Index>(Traits::OutputTile::kD);
|
||||
|
||||
// If we do not have enough steps in the main loop, trigger the residue code.
|
||||
global_stream.move_to_residue<true>(params.k);
|
||||
|
||||
// Fetch the fragments for A and B from global memory.
|
||||
global_stream.copy();
|
||||
|
||||
// Copy the elements to shared memory (after transformation if needed).
|
||||
global_stream.commit();
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
Traits::shared_store_fence(false);
|
||||
|
||||
// Rollback to the beginning of the GEMM-K dimension. It may have no impact.
|
||||
global_stream.rollback();
|
||||
|
||||
// The unrolling steps for the main loop.
|
||||
int const kUnrollingSteps =
|
||||
Traits::MultiplyAdd::AccumulatorsPerWarp::kD / Traits::MultiplyAdd::InstructionShape::kD;
|
||||
|
||||
// Make sure we have at least 2 unrolling steps or our pipelining is not going to work.
|
||||
static_assert(kUnrollingSteps >= 2, "The pipelining assumes at least two steps");
|
||||
|
||||
// The stream of data from shared memory to fragments.
|
||||
typename Traits::SharedLoadStream shared_load_stream(params, shared_storage);
|
||||
|
||||
// Trigger the copy from shared memory for the 1st stream.
|
||||
shared_load_stream.copy(0);
|
||||
|
||||
// Allocate the accumulators.
|
||||
typename Traits::MultiplyAdd::Accumulators accumulators;
|
||||
// Clear the accumulators.
|
||||
clear.clear(accumulators);
|
||||
|
||||
// The loop index.
|
||||
Index outer_k = params.k - kUnroll;
|
||||
|
||||
// Enter the main loop and iterate.
|
||||
for (; outer_k > 0; outer_k -= kUnroll) {
|
||||
consume_tile<false>(global_stream, shared_load_stream, accumulators, outer_k);
|
||||
}
|
||||
|
||||
// Residual loop.
|
||||
for (; outer_k > -kUnroll; outer_k -= kUnroll) {
|
||||
consume_tile<true>(global_stream, shared_load_stream, accumulators, outer_k);
|
||||
}
|
||||
|
||||
// Epilogue.
|
||||
typedef typename Traits::Epilogue Epilogue;
|
||||
Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.m, params.n);
|
||||
epilogue.epilogue(cutlass::make_Coord(0, block.y, block.x), accumulators);
|
||||
}
|
||||
|
||||
/// The params.
|
||||
Params const& params;
|
||||
/// The shared storage.
|
||||
SharedStorage& shared_storage;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
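Review note: two pieces of arithmetic in the `Gemm` struct above are easy to misread. `launch` sizes the grid by ceiling division of the problem extents by the output tile. `multiply_add` splits the K dimension into a main loop followed by a residual loop that calls `consume_tile<true>`. The sketch below reproduces only that index arithmetic on the host. The names `ceil_div`, `LoopTrips`, and `count_trips` are hypothetical and not part of CUTLASS.

```cpp
// Ceiling division, as used for grid.x and grid.y in Gemm::launch.
inline int ceil_div(int extent, int tile) { return (extent + tile - 1) / tile; }

// Hypothetical helper type: how many times each loop body would run.
struct LoopTrips {
  int main_loop;  // iterations that call consume_tile<false>
  int residual;   // iterations that call consume_tile<true>
};

// Mirrors the loop bounds in Gemm::multiply_add, where unroll corresponds to
// Traits::OutputTile::kD and k to params.k.
inline LoopTrips count_trips(int k, int unroll) {
  LoopTrips trips = {0, 0};
  int outer_k = k - unroll;
  for (; outer_k > 0; outer_k -= unroll) {
    ++trips.main_loop;
  }
  for (; outer_k > -unroll; outer_k -= unroll) {
    ++trips.residual;
  }
  return trips;
}
```

For any `k >= 1` the residual loop runs exactly once, because the main loop leaves `outer_k` in the range `(-unroll, 0]`. The two counts together therefore equal `ceil_div(k, unroll)`.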
|
||||
@ -1,231 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory
|
||||
with
|
||||
the computed matrix product.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/fragment.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
CUTLASS_DEVICE bool is_zero(T x) {
|
||||
return x == T(0);
|
||||
}
|
||||
|
||||
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
|
||||
CUTLASS_DEVICE bool is_zero(half x) { return reinterpret_cast<int16_t&>(x) == int16_t(0); }
|
||||
#endif
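Review note: the `half` overload of `is_zero` above compares the raw 16-bit pattern rather than the numeric value. In IEEE 754 binary16, negative zero has the bit pattern `0x8000`, so a bitwise test treats it as non-zero, while the generic `x == T(0)` test treats it as zero. A minimal host-side sketch of the difference follows. It assumes a `std::uint16_t` stands in for the storage of a `half`, and the names `is_zero_bits` and `is_zero_value` are hypothetical.

```cpp
#include <cstdint>

// Bitwise test on the raw binary16 pattern: only +0 (0x0000) passes.
inline bool is_zero_bits(std::uint16_t half_bits) { return half_bits == 0; }

// Value test, as in the generic template: +0 and -0 both pass.
template <typename T>
bool is_zero_value(T x) {
  return x == T(0);
}
```

If `beta` were a negative zero in half precision, the epilogue would take the `kBetaIsZero_ == false` path and still load C.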
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmEpilogueTraits_>
|
||||
struct GemmEpilogue {
|
||||
/// The traits class.
|
||||
typedef GemmEpilogueTraits_ Traits;
|
||||
/// The params.
|
||||
typedef typename Traits::Params Params;
|
||||
/// The shared storage.
|
||||
typedef typename Traits::SharedStorage SharedStorage;
|
||||
|
||||
/// The output tile.
|
||||
typedef typename Traits::OutputTile OutputTile;
|
||||
/// The number of iterations.
|
||||
typedef typename Traits::Iterations Iterations;
|
||||
/// The accumulators.
|
||||
typedef typename Traits::Accumulators Accumulators;
|
||||
/// The scalar.
|
||||
typedef typename Traits::Scalar Scalar;
|
||||
/// The functor in charge of the math.
|
||||
typedef typename Traits::Functor Functor;
|
||||
|
||||
/// We do not support 3D or 4D shapes.
|
||||
static_assert(Iterations::kD == 1 && Iterations::kC == 1, "Unsupported 3D/4D shapes");
|
||||
|
||||
/// The iterator for C in global memory.
|
||||
typedef typename Traits::GlobalLoadIteratorC GlobalLoadIteratorC;
|
||||
/// The transformer for C.
|
||||
typedef typename Traits::GlobalTransformerC GlobalTransformerC;
|
||||
/// The transformer for D.
|
||||
typedef typename Traits::GlobalTransformerD GlobalTransformerD;
|
||||
/// The iterator for D in global memory.
|
||||
typedef typename Traits::GlobalStoreIteratorD GlobalStoreIteratorD;
|
||||
/// The iterator to store D in shared memory.
|
||||
typedef typename Traits::SharedStoreIteratorD SharedStoreIteratorD;
|
||||
/// The shared store transformer for D.
|
||||
typedef typename Traits::SharedStoreTransformerD SharedStoreTransformerD;
|
||||
/// The iterator to load D in shared memory.
|
||||
typedef typename Traits::SharedLoadIteratorD SharedLoadIteratorD;
|
||||
/// The shared load transformer for D.
|
||||
typedef Copy<typename SharedLoadIteratorD::Fragment> SharedLoadTransformerD;
|
||||
|
||||
/// The index.
|
||||
typedef typename Traits::Index Index;
|
||||
|
||||
/// The scalar for C.
|
||||
typedef typename GlobalLoadIteratorC::Scalar ScalarC;
|
||||
/// The scalar for D.
|
||||
typedef typename GlobalStoreIteratorD::Scalar ScalarD;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GemmEpilogue(Params const& params_,
|
||||
SharedStorage& shared_storage_,
|
||||
Index m_,
|
||||
Index n_)
|
||||
: params(params_), shared_storage(shared_storage_), m(m_), n(n_) {}
|
||||
|
||||
/// Execute the epilogue.
|
||||
CUTLASS_DEVICE void epilogue(Coord<3> const& block, Accumulators& accumulators) {
|
||||
if (is_zero(params.functor.beta)) {
|
||||
epilogue_with_or_without_beta<true>(block, accumulators);
|
||||
} else {
|
||||
epilogue_with_or_without_beta<false>(block, accumulators);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool kBetaIsZero_>
|
||||
CUTLASS_DEVICE void epilogue_with_or_without_beta(Coord<3> const& block,
|
||||
Accumulators& accumulators) {
|
||||
|
||||
// The problem size.
|
||||
Coord<3> const bounds = cutlass::make_Coord(0, n, m);
|
||||
|
||||
// The functor.
|
||||
Functor functor(params.functor);
|
||||
// The C fragment.
|
||||
typename GlobalLoadIteratorC::Fragment fragment_c;
|
||||
// The transformed C fragment.
|
||||
typename GlobalTransformerC::OutputFragment transformed_c;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
// Compute pointer and predicate offsets for C and D global iterators.
|
||||
int const pointer_offset =
|
||||
((params.iterator_d.inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
|
||||
params.iterator_d.inc_advance) *
|
||||
Iterations::kW +
|
||||
params.stride_h) *
|
||||
h;
|
||||
int const predicate_offset =
|
||||
((params.iterator_d.predicate_inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
|
||||
params.iterator_d.predicate_inc_advance) *
|
||||
Iterations::kW +
|
||||
Traits::Delta::kH) *
|
||||
h;
|
||||
|
||||
// The iterator to load the elements of the C matrix.
|
||||
GlobalLoadIteratorC global_load_iterator(
|
||||
params.iterator_c, bounds, block, pointer_offset, predicate_offset);
|
||||
// The transformer for C.
|
||||
GlobalTransformerC transformer_c;
|
||||
// The transformer for D.
|
||||
GlobalTransformerD transformer_d;
|
||||
// The iterator to store into the D matrix.
|
||||
GlobalStoreIteratorD global_store_iterator(
|
||||
params.iterator_d, bounds, block, pointer_offset, predicate_offset);
|
||||
|
||||
// The transformer to transform before storing to shared memory.
|
||||
SharedStoreTransformerD shared_store_transformer;
|
||||
typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d;
|
||||
|
||||
// The iterator to store to shared memory.
|
||||
SharedStoreIteratorD shared_store_iterator(params.shared_store_iterator_d,
|
||||
shared_storage.shared_stream.store);
|
||||
|
||||
// The iterator to load from shared memory. TODO: Use a stream.
|
||||
SharedLoadIteratorD shared_load_iterator(params.shared_load_iterator_d,
|
||||
shared_storage.shared_stream.load);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
// Load the C matrix into fragment.
|
||||
if (!kBetaIsZero_) {
|
||||
iterator_load(global_load_iterator, fragment_c);
|
||||
}
|
||||
|
||||
// Make sure we can write to shared memory.
|
||||
shared_load_fence();
|
||||
|
||||
// Copy the accumulators to shared memory.
|
||||
int const offset = (h * Iterations::kW + w) * SharedStoreIteratorD::Fragment::kElements;
|
||||
|
||||
shared_store_transformer.transform(accumulators, offset, shared_store_transformed_d);
|
||||
shared_iterator_store(shared_store_iterator, shared_store_transformed_d);
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
shared_store_fence();
|
||||
|
||||
// Copy the accumulators back to registers from shared memory.
|
||||
typename SharedLoadIteratorD::Fragment fetched_d;
|
||||
shared_iterator_load(shared_load_iterator, fetched_d);
|
||||
|
||||
// Do the math.
|
||||
typename GlobalTransformerD::InputFragment fragment_d;
|
||||
|
||||
if (kBetaIsZero_) {
|
||||
functor.evaluate(fetched_d, fragment_d);
|
||||
} else {
|
||||
// Transform C fragment.
|
||||
transformer_c.transform(fragment_c, transformed_c);
|
||||
// Do the math.
|
||||
functor.evaluate(fetched_d, transformed_c, fragment_d);
|
||||
}
|
||||
|
||||
// Transform D fragment.
|
||||
typename GlobalTransformerD::OutputFragment transformed_d;
|
||||
transformer_d.transform(fragment_d, transformed_d);
|
||||
|
||||
// Copy the results to global memory.
|
||||
iterator_store(global_store_iterator, transformed_d);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The memory fence for shared loads.
|
||||
CUTLASS_DEVICE void shared_load_fence() { __syncthreads(); }
|
||||
|
||||
/// The memory fence for shared stores.
|
||||
CUTLASS_DEVICE void shared_store_fence() { __syncthreads(); }
|
||||
|
||||
/// The params.
|
||||
Params const& params;
|
||||
/// The shared storage.
|
||||
SharedStorage& shared_storage;
|
||||
/// The dimensions of the GEMM.
|
||||
Index m, n;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
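The epilogue in the file above checks `beta` once at run time and then selects a `bool` template instantiation, so the load of C can be compiled out of the inner loops when `beta` is zero. The following host-only sketch illustrates that dispatch pattern under simplifying assumptions. The names `linear_scale` and `apply` are hypothetical and not part of CUTLASS, and plain vectors stand in for fragments.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the epilogue functor: d = alpha * acc (+ beta * c).
template <bool kBetaIsZero>
void linear_scale(float alpha, float beta, const std::vector<float>& acc,
                  const std::vector<float>& c, std::vector<float>& d) {
  for (std::size_t i = 0; i < acc.size(); ++i) {
    if (kBetaIsZero) {
      // C is never read on this path, mirroring the skipped iterator_load().
      d[i] = alpha * acc[i];
    } else {
      d[i] = alpha * acc[i] + beta * c[i];
    }
  }
}

// The run-time test happens once, outside the loop, as in epilogue().
inline void apply(float alpha, float beta, const std::vector<float>& acc,
                  const std::vector<float>& c, std::vector<float>& d) {
  if (beta == 0.0f) {
    linear_scale<true>(alpha, beta, acc, c, d);
  } else {
    linear_scale<false>(alpha, beta, acc, c, d);
  }
}
```

Because `kBetaIsZero` is a compile-time constant, each instantiation contains only one of the two branches after optimization.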
|
||||
@ -1,331 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural properties of the GEMM epilogue.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/gemm/gemm_global_stream.h>
|
||||
#include <cutlass/gemm/gemm_shared_stream.h>
|
||||
#include <cutlass/gemm/linear_scaling.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include <cutlass/tile_iterator.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The output tile.
|
||||
typename OutputTile_,
|
||||
/// The accumulators.
|
||||
typename Accumulators_,
|
||||
/// The iterator to load C from global memory.
|
||||
typename GlobalLoadIteratorC_,
|
||||
/// The transformer for C.
|
||||
typename GlobalTransformerC_,
|
||||
/// The transformer for D.
|
||||
typename GlobalTransformerD_,
|
||||
/// The iterator to store D to global memory.
|
||||
typename GlobalStoreIteratorD_,
|
||||
/// The iterator to store D to shared memory.
|
||||
typename SharedStoreIteratorD_,
|
||||
/// The shared store transformer for D.
|
||||
typename SharedStoreTransformerD_,
|
||||
/// The iterator to load D from shared memory.
|
||||
typename SharedLoadIteratorD_,
|
||||
/// The number of iterations in the epilogue.
|
||||
typename Iterations_,
|
||||
/// The iteration strides.
|
||||
typename Delta_,
|
||||
/// The functor to be used in the epilogue.
|
||||
typename Functor_,
|
||||
/// The index.
|
||||
typename Index_ = int>
|
||||
struct GemmEpilogueTraits {
|
||||
//
|
||||
/// The output tile.
|
||||
typedef OutputTile_ OutputTile;
|
||||
/// The accumulators.
|
||||
typedef Accumulators_ Accumulators;
|
||||
/// The iterator for C in global memory.
|
||||
typedef GlobalLoadIteratorC_ GlobalLoadIteratorC;
|
||||
/// The transformer for C.
|
||||
typedef GlobalTransformerC_ GlobalTransformerC;
|
||||
/// The transformer for D.
|
||||
typedef GlobalTransformerD_ GlobalTransformerD;
|
||||
/// The iterator for D in global memory.
|
||||
typedef GlobalStoreIteratorD_ GlobalStoreIteratorD;
|
||||
/// The iterator to store D in shared memory.
|
||||
typedef SharedStoreIteratorD_ SharedStoreIteratorD;
|
||||
/// The shared store transformer for D.
|
||||
typedef SharedStoreTransformerD_ SharedStoreTransformerD;
|
||||
/// The iterator to load D from shared memory.
|
||||
typedef SharedLoadIteratorD_ SharedLoadIteratorD;
|
||||
/// The number of iterations in the epilogue.
|
||||
typedef Iterations_ Iterations;
|
||||
/// The iteration strides.
|
||||
typedef Delta_ Delta;
|
||||
|
||||
/// The functor in charge of the math.
|
||||
typedef Functor_ Functor;
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
|
||||
/// We do not support 3D or 4D shapes.
|
||||
static_assert(Iterations::kD == 1 && Iterations::kC == 1, "Unsupported 3D/4D shapes");
|
||||
|
||||
/// The scalar.
|
||||
typedef typename Functor::Scalar Scalar;
|
||||
/// The scalar for C.
|
||||
typedef typename GlobalLoadIteratorC::Scalar ScalarC;
|
||||
/// The scalar for D.
|
||||
typedef typename GlobalStoreIteratorD::Scalar ScalarD;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
/// The strides for H and W in the different iterations of the epilogue.
|
||||
Index stride_h, stride_w;
|
||||
/// The params for the C iterator.
|
||||
typename GlobalLoadIteratorC::Params iterator_c;
|
||||
/// The params for the D global iterator.
|
||||
typename GlobalStoreIteratorD::Params iterator_d;
|
||||
/// The params for the D shared store iterator.
|
||||
typename SharedStoreIteratorD::Params shared_store_iterator_d;
|
||||
/// The params for the D shared load iterator.
|
||||
typename SharedLoadIteratorD::Params shared_load_iterator_d;
|
||||
/// The functor params.
|
||||
typename Functor::Params functor;
|
||||
|
||||
/// Setup the params.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
|
||||
// The parameters for the functor.
|
||||
int error_code = functor.initialize(desc);
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// At the end of the H iteration, we jump over a number of columns.
|
||||
this->stride_h = desc.ldd * Delta::kH;
|
||||
// Nothing to do here.
|
||||
this->stride_w = 0;
|
||||
|
||||
// Setup the params for the global memory iterator for C.
|
||||
error_code = iterator_c.initialize(
|
||||
reinterpret_cast<ScalarC const*>(desc.d_c), desc.ldc, desc.n, stride_w, Delta::kW);
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// Setup the params for the global memory iterator for D.
|
||||
return iterator_d.initialize(
|
||||
reinterpret_cast<ScalarD*>(desc.d_d), desc.ldd, desc.n, stride_w, Delta::kW);
|
||||
}
|
||||
};
|
||||
|
||||
/// The shared memory storage to exchange data.
|
||||
union StreamSharedStorage {
|
||||
// The storage for the store iterator.
|
||||
typename SharedStoreIteratorD::SharedStorage store;
|
||||
// The storage for the load iterator.
|
||||
typename SharedLoadIteratorD::SharedStorage load;
|
||||
};
|
||||
|
||||
/// The shared memory to swizzle the data in the epilogue.
|
||||
struct SharedStorage {
|
||||
// The storage for the shared stream D.
|
||||
StreamSharedStorage shared_stream;
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_, typename EpilogueFunctor_, typename Index_ = int>
|
||||
struct GemmEpilogueTraitsHelper {
|
||||
/// The scalar.
|
||||
typedef typename EpilogueFunctor_::Scalar Scalar;
|
||||
/// The output tile.
|
||||
typedef typename GemmConfig_::OutputTile OutputTile;
|
||||
|
||||
/// The number of iterations in the epilogue.
|
||||
typedef Shape<1,
|
||||
GemmConfig_::MultiplyAdd::AccumulatorsPerThread::kH /
|
||||
GemmConfig_::kAccumulatorsPerLdsB,
|
||||
GemmConfig_::kAccumulatorsPerLdsB>
|
||||
Iterations;
|
||||
// The iteration strides in the H/W dimension.
|
||||
typedef Shape<0,
|
||||
GemmConfig_::kAccumulatorsPerLdsB*(
|
||||
GemmConfig_::Warps::kH* GemmConfig_::MultiplyAdd::ThreadsPerWarp::kH - 1),
|
||||
0>
|
||||
Delta;
|
||||
/// The functor to do the math in the epilogue.
|
||||
typedef EpilogueFunctor_ Functor;
|
||||
|
||||
/// The traits class to build the iterator to store to shared memory for D.
|
||||
typedef GemmSharedStoreTileDTraits<
|
||||
// The pointer is float.
|
||||
typename Functor::Scalar,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The number of scalars per STS.
|
||||
GemmConfig_::kScalarsPerStsD,
|
||||
// The skew -- 128 / sizeof(ScalarD) / kScalarsPerStsD is the number of threads involved in
|
||||
// a single STS. We divide by 2 as our objective is to add a skew to the odd threads to
|
||||
// avoid bank conflicts between odd and even threads.
|
||||
128 / sizeof(typename GemmConfig_::ScalarD) / GemmConfig_::kScalarsPerStsD / 2 *
|
||||
GemmConfig_::kScalarsPerStsD>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The iterator to store D to shared memory.
|
||||
typedef TileStoreIterator<SharedStoreTileTraits,
|
||||
typename SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorD;
|
||||
|
||||
/// The shared store transformer for D.
|
||||
typedef Copy<typename SharedStoreIteratorD::Fragment> SharedStoreTransformerD;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for D.
|
||||
typedef GemmSharedLoadTileDTraits<
|
||||
// The pointer is float.
|
||||
typename Functor::Scalar,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The number of columns of the output tile written per iteration.
|
||||
GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
|
||||
// The number of scalars per LDS.
|
||||
GemmConfig_::kScalarsPerLdsD,
|
||||
// The skew.
|
||||
SharedStoreTileTraits::kSkew>
|
||||
SharedLoadTileTraits;
|
||||
|
||||
/// The iterator to load D from shared memory.
|
||||
typedef TileLoadIterator<SharedLoadTileTraits,
|
||||
typename SharedLoadTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorD;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for C^N.
|
||||
typedef GemmGlobalTileCdTraits<
|
||||
// The pointer is float const.
|
||||
typename GemmConfig_::ScalarC const,
|
||||
// The tile has size (N / Iterations)xM in GEMM's terminology.
|
||||
Shape<1,
|
||||
GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
|
||||
GemmConfig_::OutputTile::kW>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// How many elements do we jump over at each iteration?
|
||||
Iterations::kW,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgC>
|
||||
GlobalLoadTileTraits;
|
||||
|
||||
/// The iterator to load C.
|
||||
typedef GemmGlobalIteratorCd<GlobalLoadTileTraits, Index_> GlobalLoadIteratorC;
|
||||
/// The transformer for C.
|
||||
typedef Copy<typename GlobalLoadIteratorC::Fragment> GlobalTransformerC;
|
||||
|
||||
/// The traits class to build the iterator to store data to global memory for D^N.
|
||||
typedef GemmGlobalTileCdTraits<
|
||||
// The pointer is float.
|
||||
typename GemmConfig_::ScalarD,
|
||||
// The tile has size (N / Iterations)xM in GEMM's terminology.
|
||||
Shape<1,
|
||||
GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
|
||||
GemmConfig_::OutputTile::kW>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// How many elements do we jump over at each iteration?
|
||||
Iterations::kW,
|
||||
// The number of scalars per STG (STG.32 or STG.128, etc.).
|
||||
GemmConfig_::kScalarsPerStgD>
|
||||
GlobalStoreTileTraits;
|
||||
|
||||
/// The iterator to store D.
|
||||
typedef GemmGlobalIteratorCd<GlobalStoreTileTraits, Index_> GlobalStoreIteratorD;
|
||||
/// The transformer for D.
|
||||
typedef Copy<typename GlobalStoreIteratorD::Fragment> GlobalTransformerD;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The GEMM config.
|
||||
typename GemmConfig_,
|
||||
/// The epilogue functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The helper to create the traits class.
|
||||
typename Helper_ = GemmEpilogueTraitsHelper<GemmConfig_, EpilogueFunctor_, Index_> >
|
||||
struct SimplifiedGemmEpilogueTraits : public GemmEpilogueTraits<
|
||||
// The output tile.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The accumulators.
|
||||
typename GemmConfig_::Accumulators,
|
||||
// The global iterator for C.
|
||||
typename Helper_::GlobalLoadIteratorC,
|
||||
// The transformer for C.
|
||||
typename Helper_::GlobalTransformerC,
|
||||
// The transformer for D.
|
||||
typename Helper_::GlobalTransformerD,
|
||||
// The global iterator for D.
|
||||
typename Helper_::GlobalStoreIteratorD,
|
||||
// The iterator to store D to shared memory.
|
||||
typename Helper_::SharedStoreIteratorD,
|
||||
// The shared store transformer for D.
|
||||
typename Helper_::SharedStoreTransformerD,
|
||||
// The iterator to load D from shared memory.
|
||||
typename Helper_::SharedLoadIteratorD,
|
||||
// The number of iterations.
|
||||
typename Helper_::Iterations,
|
||||
// The strides between iterations.
|
||||
typename Helper_::Delta,
|
||||
// The functor to be used in the epilogue.
|
||||
EpilogueFunctor_,
|
||||
// The index.
|
||||
Index_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
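The skew passed to `GemmSharedStoreTileDTraits` in the file above is computed as `128 / sizeof(ScalarD) / kScalarsPerStsD / 2 * kScalarsPerStsD`. As a worked example, the sketch below evaluates the same integer expression for a few scalar sizes. The helper `epilogue_skew` is hypothetical, takes the scalar size in bytes as a plain integer, and assumes the result is measured in scalars.

```cpp
#include <cassert>

// Hypothetical helper (not a CUTLASS API): evaluates the skew expression with
// the same left-to-right integer division as the template argument.
constexpr int epilogue_skew(int scalar_bytes, int scalars_per_sts) {
  // 128 / scalar_bytes / scalars_per_sts is the number of threads involved in
  // a single STS (per the source comment); half of that, scaled by the STS
  // width, is the skew.
  return 128 / scalar_bytes / scalars_per_sts / 2 * scalars_per_sts;
}

static_assert(epilogue_skew(4, 1) == 16, "float, 1 scalar per STS");
static_assert(epilogue_skew(4, 4) == 16, "float, 4 scalars per STS");
static_assert(epilogue_skew(8, 1) == 8, "double, 1 scalar per STS");
static_assert(epilogue_skew(2, 2) == 32, "2-byte scalar, 2 scalars per STS");
```

For 4-byte scalars the skew stays at 16 regardless of the STS width, because the division and the final multiplication by `kScalarsPerStsD` cancel.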
|
||||
@ -1,182 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements efficient loading of the thread block-level tile from global memory and
storing to shared memory.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/iterator_access.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The load iterator.
|
||||
typename LoadIterator_,
|
||||
/// The store iterator to copy to shared memory.
|
||||
typename StoreIterator_,
|
||||
/// The transformer to be applied after the data has been copied from global memory.
|
||||
typename Transformer_>
|
||||
|
||||
struct GlobalLoadStreamBase {
|
||||
/// The load iterator.
|
||||
typedef LoadIterator_ LoadIterator;
|
||||
/// The transformer.
|
||||
typedef Transformer_ Transformer;
|
||||
/// The store iterator to write to shared memory.
|
||||
typedef StoreIterator_ StoreIterator;
|
||||
|
||||
/// The fragment that is copied from global memory.
|
||||
typedef typename LoadIterator::Fragment FetchedFragment;
|
||||
/// The fragment that is obtained after the transformation by the transformer.
|
||||
typedef typename Transformer::OutputFragment TransformedFragment;
|
||||
/// Make sure the fragments match.
|
||||
static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
|
||||
"");
|
||||
/// The output fragment.
|
||||
typedef TransformedFragment Fragment;
|
||||
/// Make sure the transformed fragment is the same as the store fragment.
|
||||
static_assert((platform::is_same<TransformedFragment, typename StoreIterator::Fragment>::value),
|
||||
"");
|
||||
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = LoadIterator::kLayout;
|
||||
/// The scalar type of the iterator.
|
||||
typedef typename LoadIterator::Scalar Scalar;
|
||||
/// The pointer.
|
||||
typedef typename LoadIterator::Pointer Pointer;
|
||||
/// The index.
|
||||
typedef typename LoadIterator::Index Index;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
// The load iterator.
|
||||
typename LoadIterator::Params load_iterator;
|
||||
// The store iterator.
|
||||
typename StoreIterator::Params store_iterator;
|
||||
|
||||
/// Setup the params.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc, Pointer pointer, Index ld) {
|
||||
int error_code = load_iterator.initialize(desc, pointer, ld);
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
return store_iterator.initialize();
|
||||
}
|
||||
};
|
||||
|
||||
/// The amount of storage in shared memory needed to store the tile.
|
||||
typedef typename StoreIterator::SharedStorage SharedStoreStorage;
|
||||
|
||||
/// The storage in shared memory needed by that stream.
|
||||
union SharedStorage {
|
||||
// The load iterator.
|
||||
typename LoadIterator::SharedStorage load_iterator;
|
||||
// The store iterator.
|
||||
SharedStoreStorage store_iterator;
|
||||
};
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GlobalLoadStreamBase(Params const& params,
|
||||
SharedStorage& shared_storage,
|
||||
Coord<3> const bounds,
|
||||
Coord<3> const& block)
|
||||
: load_iterator(params.load_iterator, bounds, block),
|
||||
transformer(),
|
||||
store_iterator(params.store_iterator, shared_storage.store_iterator)
|
||||
|
||||
{
|
||||
fetched_fragment.clear();
|
||||
}
|
||||
|
||||
/// Load the data from global memory into the fetched fragment.
|
||||
CUTLASS_DEVICE void copy() { iterator_load(load_iterator, fetched_fragment); }
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit() {
|
||||
transformer.transform(fetched_fragment, transformed_fragment);
|
||||
iterator_store(store_iterator, transformed_fragment);
|
||||
store_iterator.inc_stage();
|
||||
}
|
||||
|
||||
/// Move to the beginning of the residue code. That's a new code path in CUTLASS 1.0.1.
|
||||
CUTLASS_DEVICE void move_to_residue(Index k) { load_iterator.move_to_residue(k); }
|
||||
|
||||
/// Execute the residue code.
|
||||
CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
|
||||
load_iterator.residue(k);
|
||||
if (!skip_clear) {
|
||||
fetched_fragment.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to the beginning of the GEMM-k dimension.
|
||||
CUTLASS_DEVICE void rollback() { load_iterator.rollback(); }
|
||||
|
||||
/// The iterator.
|
||||
LoadIterator load_iterator;
|
||||
/// The fragment fetched from global memory.
|
||||
FetchedFragment fetched_fragment;
|
||||
/// The transformer.
|
||||
Transformer transformer;
|
||||
/// The fragment holding the converted data after it has been fetched from global memory.
|
||||
TransformedFragment transformed_fragment;
|
||||
/// The store iterator.
|
||||
StoreIterator store_iterator;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The load iterator.
|
||||
typename LoadIterator_,
|
||||
/// The store iterator to copy to shared memory.
|
||||
typename StoreIterator_,
|
||||
/// The transformer to be applied after the data has been copied from global memory.
|
||||
typename Transformer_ = Copy<typename LoadIterator_::Fragment> >
|
||||
|
||||
struct GlobalLoadStream : public GlobalLoadStreamBase<LoadIterator_, StoreIterator_, Transformer_> {
|
||||
/// The base class.
|
||||
typedef GlobalLoadStreamBase<LoadIterator_, StoreIterator_, Transformer_> Base;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GlobalLoadStream(typename Base::Params const& params,
|
||||
typename Base::SharedStorage& shared_storage,
|
||||
Coord<3> const& bounds,
|
||||
Coord<3> const& block)
|
||||
: Base(params, shared_storage, bounds, block) {}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
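`GlobalLoadStreamBase` in the file above splits a tile transfer into two phases: `copy()` fetches a fragment from global memory into registers, and `commit()` transforms it and stores it to shared memory. The host-only sketch below mimics that two-phase interface under loose assumptions. `ToyLoadStream`, its fragment size, and the zero fill for out-of-range reads are all hypothetical and only illustrate the control flow, not the CUTLASS iterators or predicates.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-only analogue of GlobalLoadStreamBase (not CUTLASS code).
class ToyLoadStream {
 public:
  static const std::size_t kFragment = 4;  // elements per fragment (assumed)

  // `shared` must hold at least kFragment elements per commit() call.
  ToyLoadStream(const std::vector<float>& global, std::vector<float>& shared)
      : global_(global), shared_(shared), load_pos_(0), store_pos_(0) {
    fetched_.fill(0.0f);  // mirrors fetched_fragment.clear() in the constructor
  }

  // Phase 1: fetch the next fragment from "global memory" into registers.
  void copy() {
    for (std::size_t i = 0; i < kFragment; ++i) {
      std::size_t idx = load_pos_ + i;
      // Assumption: out-of-range elements read as zero.
      fetched_[i] = idx < global_.size() ? global_[idx] : 0.0f;
    }
    load_pos_ += kFragment;
  }

  // Phase 2: transform (identity here, like Copy<>) and store to "shared memory".
  void commit() {
    for (std::size_t i = 0; i < kFragment; ++i) {
      shared_[store_pos_ + i] = fetched_[i];
    }
    store_pos_ += kFragment;  // stands in for store_iterator.inc_stage()
  }

 private:
  const std::vector<float>& global_;
  std::vector<float>& shared_;
  std::size_t load_pos_;
  std::size_t store_pos_;
  std::array<float, kFragment> fetched_;
};
```

Keeping the two phases separate lets a caller issue the global loads early and defer the shared-memory stores until after other work, which is the usual reason for splitting them in a software pipeline.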
|
||||
@ -1,541 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines iterators for efficiently loading and storing to global memory.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/util/platform.h>
|
||||
|
||||
#include <cutlass/gemm/gemm_operand.h>
|
||||
#include <cutlass/matrix_traits.h>
|
||||
#include <cutlass/predicate_vector.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include <cutlass/tile_iterator.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// The following functor reshapes a tile of threads to match a tile of data. When building the
// iterator traits, the user may want to specify the tile independently of the number of scalars
// loaded/stored per instruction. For example, in the row-major version with a tile of size
// 128x8, the user may want the iterator to work with 32x8 threads if each thread loads 1 scalar
// per LDG. If the user changes to 4 scalars per LDG, then the tile of threads has to change. The
// code below detects that case and corrects the thread arrangement automatically; it is a helper
// for when the user does not specify the right configuration.
|
||||
|
||||
template <typename Tile_, typename Threads_, bool = (Tile_::kW < Threads_::kW)>
|
||||
struct ReshapeThreads {
|
||||
typedef Threads_ Threads;
|
||||
};
|
||||
|
||||
template <typename Tile_, typename Threads_>
|
||||
struct ReshapeThreads<Tile_, Threads_, true> {
|
||||
typedef Shape<Threads_::kD, Threads_::kH * Threads_::kW / Tile_::kW, Tile_::kW, 1> Threads;
|
||||
};
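As a worked example of the reshaping rule above, the sketch below reproduces the two `ReshapeThreads` templates against a toy `Shape` type and checks the resulting thread arrangement at compile time. The toy `Shape` (extents in D, H, W, C order) and the tile and thread sizes are assumptions chosen for illustration. The real `cutlass::Shape` carries more members.

```cpp
#include <cassert>

// Toy stand-in for cutlass::Shape (an assumption for this sketch).
template <int kD_, int kH_, int kW_, int kC_ = 1>
struct Shape {
  static int const kD = kD_;
  static int const kH = kH_;
  static int const kW = kW_;
  static int const kC = kC_;
};

// Primary template: the tile is at least as wide as the threads, keep them as given.
template <typename Tile_, typename Threads_, bool = (Tile_::kW < Threads_::kW)>
struct ReshapeThreads {
  typedef Threads_ Threads;
};

// Specialization: the tile is narrower than the threads, so fold the surplus
// W threads into H while keeping the total thread count unchanged.
template <typename Tile_, typename Threads_>
struct ReshapeThreads<Tile_, Threads_, true> {
  typedef Shape<Threads_::kD, Threads_::kH * Threads_::kW / Tile_::kW, Tile_::kW, 1> Threads;
};

// A tile only 8 wide cannot use 32 threads along W: 8x32 threads become 32x8.
typedef ReshapeThreads<Shape<1, 32, 8, 4>, Shape<1, 8, 32> >::Threads Narrowed;
static_assert(Narrowed::kH == 32 && Narrowed::kW == 8, "threads folded into H");

// A tile 32 wide matches the 8x32 arrangement, so nothing changes.
typedef ReshapeThreads<Shape<1, 8, 32, 4>, Shape<1, 8, 32> >::Threads Unchanged;
static_assert(Unchanged::kH == 8 && Unchanged::kW == 32, "threads kept as given");
```

In both cases the product `kH * kW` stays at 256, so the reshaping only redistributes threads and never changes how many participate.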
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GemmOperand::Kind kOperand_,
|
||||
MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename Tile_,
|
||||
typename Threads_,
|
||||
int kAccessSize_>
|
||||
struct GemmGlobalTileTraits {
|
||||
/// Identity of the operand
|
||||
static GemmOperand::Kind const kOperand = kOperand_;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = kLayout_;
|
||||
/// The scalar.
|
||||
typedef typename platform::remove_const<Scalar_>::type Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar_* Pointer;
|
||||
/// The number of scalars per LDG/STG.
|
||||
static int const kAccessSize = kAccessSize_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGlobal;
|
||||
|
||||
/// The tile shape
|
||||
typedef typename ReshapeTile<Tile_, kAccessSize_>::Tile Tile;
|
||||
/// The threads shape
|
||||
typedef typename ReshapeThreads<Tile, Threads_>::Threads Threads;
|
||||
/// The relative offset between two elements in the H/W dimension in adjacent threads.
|
||||
typedef Shape<1, 1, Tile::kC> ThreadsDelta;
|
||||
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, Threads::kH, Threads::kW * kAccessSize> Delta;
|
||||
/// Strides for immediate offset computation
|
||||
typedef Shape<0, 0, Threads::kW * ThreadsDelta::kW, kAccessSize> ImmediateOffsetStrides;
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<1, Tile::kH / Threads::kH, Tile::kW / Threads::kW, Tile::kC / kAccessSize>
|
||||
Iterations;
|
||||
|
||||
typedef GemmMultiplicandTraits<Tile, kOperand, kLayout> MultiplicandTraits;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
|
||||
int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
|
||||
|
||||
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
|
||||
}
|
||||
};
|
||||
};
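The `ThreadOffset` functor above maps a linear thread ID to an (H, W) coordinate: the quotient by `Threads::kW` selects the row and the remainder selects the column, each scaled by the matching `ThreadsDelta` extent. A host-only sketch of the same arithmetic follows. The `Offset` struct and the `thread_offset` function are hypothetical, and the thread ID is passed explicitly instead of being read from `threadIdx.x`.

```cpp
#include <cassert>

// Hypothetical result type for this sketch.
struct Offset {
  int h;
  int w;
};

// Same arithmetic as ThreadOffset::operator(), with all inputs made explicit.
// `/`, `%` and `*` share one precedence level and associate left to right,
// so the division or remainder is applied before the scaling.
inline Offset thread_offset(int thread_id, int threads_w, int delta_h, int delta_w) {
  Offset offset;
  offset.h = thread_id / threads_w * delta_h;
  offset.w = thread_id % threads_w * delta_w;
  return offset;
}
```

With 8 threads along W, a row step of 1 and 4 scalars per access, thread 13 lands on row 1 at column offset 20, since 13 / 8 = 1 and (13 % 8) * 4 = 20.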
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, typename Tile_, typename Threads_, int kStrideH_, int kAccessSize_>
|
||||
struct GemmGlobalTileCdTraits : public GemmGlobalTileTraits<GemmOperand::kC,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Scalar_,
|
||||
Tile_,
|
||||
Threads_,
|
||||
kAccessSize_> {
|
||||
/// The base class.
|
||||
typedef GemmGlobalTileTraits<GemmOperand::kC,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Scalar_,
|
||||
Tile_,
|
||||
Threads_,
|
||||
kAccessSize_>
|
||||
Base;
|
||||
|
||||
/// The stride in the H dimension.
|
||||
static int const kStrideH = kStrideH_;
|
||||
/// Override the strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Base::Delta::kW, Base::Delta::kC> Delta;
|
||||
|
||||
typedef typename Base::Iterations Iterations;
|
||||
|
||||
typedef typename Base::Threads Threads;
|
||||
|
||||
typedef typename Base::ThreadsDelta ThreadsDelta;
|
||||
|
||||
typedef typename Base::ImmediateOffsetStrides ImmediateOffsetStrides;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int thread_offset_h = threadIdx.x / Threads::kW * kStrideH * Iterations::kH;
|
||||
int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
|
||||
|
||||
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TileTraits_, typename Index_ = int>
|
||||
struct GemmGlobalIteratorAb
|
||||
: public TileLoadIterator<TileTraits_,
|
||||
typename TileTraits_::Scalar,
|
||||
TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
|
||||
: IteratorAdvance::kW,
|
||||
MemorySpace::kGlobal,
|
||||
Index_> {
|
||||
/// This class.
|
||||
typedef GemmGlobalIteratorAb<TileTraits_, Index_> This_;
/// The base class.
|
||||
|
||||
typedef TileLoadIterator<TileTraits_,
|
||||
typename TileTraits_::Scalar,
|
||||
TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
|
||||
: IteratorAdvance::kW,
|
||||
MemorySpace::kGlobal,
|
||||
Index_>
|
||||
Base;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
|
||||
/// Fragment type loaded by the iterator
|
||||
typedef typename Base::Fragment Fragment;
|
||||
/// The scalar.
|
||||
typedef typename TileTraits_::Scalar Scalar;
|
||||
/// The threads.
|
||||
typedef typename TileTraits_::Threads Threads;
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
/// The thread offset
|
||||
typedef typename TileTraits_::ThreadOffset ThreadOffset;
|
||||
/// Specifies in which dimension post-increment accesses advance.
|
||||
static IteratorAdvance::Kind const kAdvance = Base::kAdvance;
|
||||
|
||||
typedef cutlass::PredicateVector<ShapeCount<typename Base::Iterations>::kCount> PredicateVector;
|
||||
|
||||
/// Iterator parameters type
|
||||
typedef typename Base::Params BaseParams;
|
||||
|
||||
struct Params : public BaseParams {
|
||||
/// Initializes params to load a strip-mined tile, given pointer and stride_h.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc, Scalar const* ptr, Index stride_h) {
|
||||
Index inc_d = 0;
|
||||
Index inc_advance = 0;
|
||||
// Move by some columns for each iteration in the H dimension.
|
||||
Index inc_h = Base::Delta::kH * stride_h;
|
||||
|
||||
// When iterating in D (Delta::kD > 0), step to the next D slice and undo the H increments already applied.
|
||||
if (Base::Delta::kD > 0) {
|
||||
inc_d = Base::Delta::kD * stride_h - (Base::Iterations::kH - 1) * inc_h;
|
||||
}
|
||||
|
||||
// Move to the beginning of the next iteration.
|
||||
if (kAdvance == IteratorAdvance::kH && Base::Delta::kD > 0) {
|
||||
inc_advance = inc_d;
|
||||
} else if (kAdvance == IteratorAdvance::kH) {
|
||||
inc_advance = inc_h;
|
||||
} else if (Base::Delta::kD > 0) {
|
||||
inc_advance = (Base::Iterations::kW + 0) * ShapeCount<typename Base::Delta>::kWc -
|
||||
(Base::Iterations::kH - 1) * inc_h -
|
||||
(Base::Iterations::kD - 1) * Base::Delta::kD * stride_h;
|
||||
} else {
|
||||
inc_advance = (Base::Iterations::kW + 0) * ShapeCount<typename Base::Delta>::kWc -
|
||||
(Base::Iterations::kH - 1) * inc_h;
|
||||
}
|
||||
|
||||
// The dimensions of the tile.
|
||||
int const kH = TileTraits_::Tile::kH;
|
||||
int const kW = TileTraits_::Tile::kW * TileTraits_::kAccessSize;
|
||||
|
||||
// Move to the residue.
|
||||
Index const kBlock = kAdvance == IteratorAdvance::kH ? kH : kW;
|
||||
// The jump in the gemm-k dimension.
|
||||
Index const stride = kAdvance == IteratorAdvance::kH ? stride_h : 1;
|
||||
|
||||
// Compute the offset to the residue and how to "come" back.
|
||||
Index const kResidue = desc.k % kBlock;
|
||||
if (kResidue > 0) {
|
||||
move_to_residue_offset = (desc.k - kResidue) * stride;
|
||||
} else {
|
||||
move_to_residue_offset = (desc.k - kBlock) * stride;
|
||||
}
|
||||
|
||||
Base::Params::initialize(ptr, 0, stride_h, 1, inc_d, inc_h, 0, inc_advance);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// The extra offset to control moving to the residue.
|
||||
Index move_to_residue_offset;
|
||||
};
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GemmGlobalIteratorAb(Params const& _params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(_params) {
|
||||
thread_offset = thread_offset_func();
|
||||
// The column.
|
||||
Index block_h = thread_offset[1];
|
||||
// The contiguous dimension.
|
||||
Index block_w = thread_offset[2];
|
||||
|
||||
// Add the blocks indices.
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
block_h += block[1];
|
||||
block_w += block[2];
|
||||
|
||||
} else {
|
||||
block_h += block[2];
|
||||
block_w += block[1];
|
||||
}
|
||||
|
||||
// Setup the pointer.
|
||||
params.pointer += (block_h * params.stride_h + block_w);
|
||||
|
||||
// Initialize predicates
|
||||
initialize_predicates(bounds, make_Coord(0, block_h, block_w));
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Load<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::load(value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Increment the pointer in the H dimension.
|
||||
CUTLASS_DEVICE void inc_h() { params.pointer += params.inc_h; }
|
||||
/// Increment the pointer in the D dimension.
|
||||
CUTLASS_DEVICE void inc_d() { params.pointer += params.inc_d; }
|
||||
/// Increment the pointer to move to the next iteration.
|
||||
CUTLASS_DEVICE void inc_advance() { params.pointer += params.inc_advance; }
|
||||
|
||||
/// Initialize the predicates.
|
||||
CUTLASS_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& block) {
|
||||
// Setup the masks to control loads.
|
||||
predicates.fill(0);
|
||||
|
||||
int bounds_h, bounds_w;
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
bounds_w = bounds[2] - block[2];
|
||||
bounds_h = bounds[1];
|
||||
|
||||
} else {
|
||||
bounds_w = bounds[1];
|
||||
bounds_h = bounds[2] - block[1];
|
||||
}
|
||||
|
||||
// Fill in the bits of the predicate vector.
|
||||
for (int d = 0; d < Base::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Base::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Base::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < Base::Iterations::kC; ++c) {
|
||||
bool flag = w * Base::Delta::kW < bounds_w;
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
flag = flag && (h * Base::Delta::kH + d * Base::Delta::kD) < bounds_h;
|
||||
} else {
|
||||
flag = flag && (h * Base::Delta::kH) < bounds_h;
|
||||
}
|
||||
int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
|
||||
predicates.set(bit, flag);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Move to residue portion.
|
||||
CUTLASS_DEVICE void move_to_residue(Index k) {
|
||||
// Store the pointer and the predicates.
|
||||
stored_pointer = params.pointer;
|
||||
stored_predicates = predicates;
|
||||
|
||||
// Move the pointer to the residue.
|
||||
params.pointer += params.move_to_residue_offset;
|
||||
|
||||
// The dimensions of the tile.
|
||||
int const kH = TileTraits_::Tile::kH;
|
||||
int const kW = TileTraits_::Tile::kW * TileTraits_::kAccessSize;
|
||||
|
||||
// The unrolling factor.
|
||||
int const kUnroll = kAdvance == IteratorAdvance::kH ? kH : kW;
|
||||
|
||||
// Clear the predicates for the residue. TODO: We can do something smarter.
|
||||
int const kResidue = (int)(k % (Index)kUnroll);
|
||||
if (kResidue > 0) {
|
||||
residue(kResidue);
|
||||
}
|
||||
}
|
||||
|
||||
/// That's the residue! Update the predicates.
|
||||
CUTLASS_DEVICE void residue(Index k) {
|
||||
// The coordinates of the thread.
|
||||
Index block_h = thread_offset[1];
|
||||
// The contiguous dimension.
|
||||
Index block_w = thread_offset[2];
|
||||
|
||||
// Update the predicate vector.
|
||||
for (int d = 0; d < Base::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Base::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Base::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < Base::Iterations::kC; ++c) {
|
||||
Index offset = 0;
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
offset += block_h + h * Base::Delta::kH + d * Base::Delta::kD;
|
||||
} else {
|
||||
offset += block_w + w * Base::Delta::kW;
|
||||
}
|
||||
|
||||
int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
|
||||
if (offset >= k) {
|
||||
predicates.set(bit, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile and initialize predicates.
|
||||
CUTLASS_DEVICE void rollback() {
|
||||
params.pointer = stored_pointer;
|
||||
predicates = stored_predicates;
|
||||
}
|
||||
|
||||
/// Is the iterator valid?
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
|
||||
int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
|
||||
return predicates[bit];
|
||||
}
|
||||
|
||||
/// Offset of an individual lane from the start of the tile
|
||||
Coord<4> thread_offset;
|
||||
/// The parameters
|
||||
Params params;
|
||||
/// The pointer.
|
||||
typename Base::Scalar const* stored_pointer;
|
||||
/// The predicates.
|
||||
PredicateVector predicates, stored_predicates;
|
||||
};
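The residue handling in `GemmGlobalIteratorAb` reduces to two small rules: where the last (possibly partial) tile along K starts, and which accesses stay enabled inside it. The sketch below restates both on the host, assuming plain integers in place of `desc.k`, `kBlock`, and `stride`. The function names are hypothetical and the code is a simplified reading of `Params::initialize` and `residue()`, not the library's implementation.

```cpp
// Sketch of the residue-offset rule from Params::initialize (illustrative, host-only).
// k      : extent of the GEMM K dimension (desc.k in the source)
// block  : tile extent along the advance dimension (kBlock in the source)
// stride : element stride of one step along K (stride_h or 1 in the source)
constexpr long move_to_residue_offset(long k, long block, long stride) {
  // A partial last tile starts at the last multiple of `block`; when K divides
  // evenly, the last full tile plays that role instead.
  return (k % block > 0 ? k - k % block : k - block) * stride;
}

// Sketch of the predicate rule in residue(): an access at `offset` along K stays
// enabled only while it lies inside a residue of length `residue`.
constexpr bool keeps_predicate(long offset, long residue) { return offset < residue; }
```

For example, with K = 70 and a tile extent of 32, the residue has length 6 and begins 64 elements along K.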
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TileTraits_, typename Index_ = int>
|
||||
struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
typename TileTraits_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kGlobal,
|
||||
Index_> {
|
||||
/// This class.
|
||||
typedef GemmGlobalIteratorCd<TileTraits_, Index_> This_;
|
||||
/// The base class.
|
||||
typedef TileIteratorBase<TileTraits_,
|
||||
typename TileTraits_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kGlobal,
|
||||
Index_>
|
||||
Base;
|
||||
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
|
||||
|
||||
/// The scalar.
|
||||
typedef typename TileTraits_::Scalar Scalar;
|
||||
/// The pointer.
|
||||
typedef typename TileTraits_::Pointer Pointer;
|
||||
/// The threads.
|
||||
typedef typename TileTraits_::Threads Threads;
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
/// The thread offset
|
||||
typedef typename TileTraits_::ThreadOffset ThreadOffset;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
/// The pointer.
|
||||
Pointer pointer;
|
||||
/// The stride in the H dimension to setup the thread in the block.
|
||||
Index stride_h;
|
||||
/// The strides to increment the pointer.
|
||||
Index inc_advance, inc_h;
|
||||
/// The strides to increment the predicate offset
|
||||
Index predicate_inc_advance, predicate_inc_h;
|
||||
/// The column offset to compute the predicate for the columns.
|
||||
Index predicate_offset;
|
||||
|
||||
/// Setup the params.
|
||||
CUTLASS_HOST_DEVICE int initialize(
|
||||
Pointer pointer, Index ld, Index bound, Index epilogue_stride_w, Index epilogue_delta_w) {
|
||||
// The pointer.
|
||||
this->pointer = pointer;
|
||||
// Each column of the matrix.
|
||||
stride_h = TileTraits_::ThreadsDelta::kH * ld;
|
||||
// Each thread outputs 1 column per iteration. The stride between columns is given by the
|
||||
// number of scalars that are loaded per LDS for B.
|
||||
inc_h = ld * TileTraits_::kStrideH;
|
||||
inc_advance =
|
||||
(ld - ld * TileTraits_::kStrideH * (Base::Iterations::kH - 1)) + epilogue_stride_w;
|
||||
|
||||
predicate_offset = bound;
|
||||
predicate_inc_h = TileTraits_::kStrideH;
|
||||
predicate_inc_advance =
|
||||
-((TileTraits_::kStrideH * (Base::Iterations::kH - 1) - 1) + epilogue_delta_w);
|
||||
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
Params params;
|
||||
/// Offset of an individual lane from the start of the tile
|
||||
Coord<4> thread_offset;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GemmGlobalIteratorCd() {}
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GemmGlobalIteratorCd(Params const& params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block,
|
||||
int offset = 0,
|
||||
int pred_offset = 0,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(params) {
|
||||
thread_offset = thread_offset_func();
|
||||
// Each warp works on a different column of the tile.
|
||||
int const h = thread_offset[1] + block[1];
|
||||
// Each lane writes a different element.
|
||||
int const w = thread_offset[2] + block[2];
|
||||
// Setup the pointer.
|
||||
this->params.pointer += ((h * params.stride_h + w) + offset);
|
||||
|
||||
// Prepare the vector of predicates.
|
||||
for (int i = 0; i < Base::Iterations::kW; ++i) {
|
||||
predicates.set(i, w + i * Base::Delta::kW < bounds[2]);
|
||||
}
|
||||
this->params.predicate_offset -= (h + pred_offset);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Load<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::load(value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Increment the pointer in the C dimension.
|
||||
CUTLASS_DEVICE void inc_c() {}
|
||||
/// Increment the pointer in the W dimension.
|
||||
CUTLASS_DEVICE void inc_w() {}
|
||||
/// Increment the pointer in the H dimension.
|
||||
CUTLASS_DEVICE void inc_h() {
|
||||
params.pointer += params.inc_h;
|
||||
params.predicate_offset -= params.predicate_inc_h;
|
||||
}
|
||||
/// Increment the pointer in the D dimension.
|
||||
CUTLASS_DEVICE void inc_d() {}
|
||||
/// Increment the pointer to move to the next iteration.
|
||||
CUTLASS_DEVICE void inc_advance() {
|
||||
params.pointer += params.inc_advance;
|
||||
this->params.predicate_offset -= params.predicate_inc_advance;
|
||||
}
|
||||
|
||||
/// The store accessor.
|
||||
CUTLASS_DEVICE void set(typename Base::AccessType const& value, int d, int h, int w, int c) {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Store<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::store(
|
||||
value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Test the validity of the iterator.
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
|
||||
return predicates.at(w) && params.predicate_offset > 0;
|
||||
}
|
||||
|
||||
/// The predicates for the row.
|
||||
cutlass::PredicateVector<Base::Iterations::kW> predicates;
|
||||
};
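`GemmGlobalIteratorCd` guards the H dimension with a countdown rather than a per-access comparison: `predicate_offset` starts at the bound, loses the thread's starting H index in the constructor, and loses `kStrideH` on every `inc_h()`. A closed-form host-side sketch of that countdown is shown below. It deliberately ignores the `pred_offset` argument and the `inc_advance()` adjustment, and both function names are hypothetical.

```cpp
// Closed form of the H-bound countdown in GemmGlobalIteratorCd (illustrative, host-only).
// bound    : number of valid H indices (the `bound` argument of Params::initialize)
// h        : first H index owned by the thread
// steps    : number of inc_h() calls made so far
// stride_h : H indices skipped per inc_h() (kStrideH in the source)
constexpr int predicate_offset_after(int bound, int h, int steps, int stride_h) {
  return bound - h - steps * stride_h;
}

// The H half of valid(): the access is in bounds while the countdown stays positive,
// which is the same condition as h + steps * stride_h < bound.
constexpr bool h_in_bounds(int bound, int h, int steps, int stride_h) {
  return predicate_offset_after(bound, h, steps, stride_h) > 0;
}
```

Keeping a single running counter means `valid()` only has to test a sign, which is presumably why the source tracks the offset instead of recomputing the index.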
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,141 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear
|
||||
memory.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/matrix_traits.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include <cutlass/util/platform.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to describe attributes of GEMM matrix operands
|
||||
template <GemmOperand::Kind kOperand_, MatrixLayout::Kind kLayout_>
|
||||
struct GemmOperandTraitsAb {
|
||||
static const bool Congruous =
|
||||
(kOperand_ == GemmOperand::kA ^ kLayout_ == MatrixLayout::kRowMajor);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmOperand::Kind kOperand_, typename Tile_>
|
||||
struct GetExtent;
|
||||
|
||||
template <typename Tile_>
|
||||
struct GetExtent<GemmOperand::kA, Tile_> {
|
||||
static const int kExtent = Tile_::kW;
|
||||
};
|
||||
|
||||
template <typename Tile_>
|
||||
struct GetExtent<GemmOperand::kB, Tile_> {
|
||||
static const int kExtent = Tile_::kH;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Determines the shape of a multiplicand tile in terms of strided (H) and contiguous (W)
|
||||
/// dimensions
|
||||
template <typename ThreadBlockTile_, GemmOperand::Kind Usage, MatrixLayout::Kind Layout>
|
||||
struct GemmMultiplicandTraits {
|
||||
// Only defined for A or B
|
||||
static_assert(Usage == GemmOperand::kA || Usage == GemmOperand::kB,
|
||||
"GemmMultiplicandTraits is defined only for A or B operands.");
|
||||
|
||||
/// Shape of GEMM thread block tile (K, N, M)
|
||||
typedef ThreadBlockTile_ ThreadBlockTile;
|
||||
|
||||
/// Identifies multiplicand
|
||||
static GemmOperand::Kind const kUsage = Usage;
|
||||
|
||||
/// Layout of tile
|
||||
static MatrixLayout::Kind const kLayout = Layout;
|
||||
|
||||
// True if K is the strided dimension
|
||||
static bool const kKstrided = (kUsage == GemmOperand::kA ^ kLayout == MatrixLayout::kRowMajor);
|
||||
|
||||
/// Map the ThreadBlockShape onto (kH, kW) dimensions for the A and B operands
|
||||
typedef typename platform::conditional<
|
||||
kKstrided,
|
||||
Shape<1, ThreadBlockTile::kD, GetExtent<Usage, ThreadBlockTile>::kExtent>,
|
||||
Shape<1, GetExtent<Usage, ThreadBlockTile>::kExtent, ThreadBlockTile::kD> >::type Shape;
|
||||
};
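The `kKstrided` flag above is an exclusive-or of two conditions: the operand is A, and the layout is row-major. The sketch below spells out the resulting truth table with stand-in enums. `Operand`, `Layout`, and `k_strided` are illustrative names, not the CUTLASS types.

```cpp
// Stand-ins for GemmOperand::Kind and MatrixLayout::Kind (illustrative only).
enum class Operand { kA, kB };
enum class Layout { kColumnMajor, kRowMajor };

// Mirrors kKstrided: true when K is the strided dimension of the operand tile.
// `==` binds tighter than `^`, so the parentheses only make the grouping explicit.
constexpr bool k_strided(Operand op, Layout layout) {
  return (op == Operand::kA) ^ (layout == Layout::kRowMajor);
}
```

A is M-by-K, so in column-major storage M is contiguous and K is strided. B is K-by-N, so the same holds for row-major storage. The other two combinations leave K contiguous.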
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Projects a coordinate (K, N, M) onto inner and outer dimensions defined for each
|
||||
/// operand.
|
||||
template <GemmOperand::Kind operand, bool Kstrided = true>
|
||||
struct ProjectOperand;
|
||||
|
||||
/// Project A operand - (0, K, M)
|
||||
template <bool Kstrided>
|
||||
struct ProjectOperand<GemmOperand::kA, Kstrided> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<3> project(Coord<3> const &coord) {
|
||||
if (Kstrided) {
|
||||
return make_Coord(0, coord[0], coord[2]);
|
||||
} else {
|
||||
return make_Coord(0, coord[2], coord[0]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Project B operand - (0, K, N)
|
||||
template <bool Kstrided>
|
||||
struct ProjectOperand<GemmOperand::kB, Kstrided> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<3> project(Coord<3> const &coord) {
|
||||
if (Kstrided) {
|
||||
return make_Coord(0, coord[0], coord[1]);
|
||||
} else {
|
||||
return make_Coord(0, coord[1], coord[0]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Project C operand - (0, N, M)
|
||||
template <>
|
||||
struct ProjectOperand<GemmOperand::kC, true> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<3> project(Coord<3> const &coord) { return make_Coord(0, coord[1], coord[2]); }
|
||||
};
|
||||
|
||||
/// Project D operand - (0, N, M)
|
||||
template <>
|
||||
struct ProjectOperand<GemmOperand::kD, true> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<3> project(Coord<3> const &coord) { return make_Coord(0, coord[1], coord[2]); }
|
||||
};
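The `ProjectOperand` specializations above each drop one component of a (K, N, M) coordinate and order the remaining two by which dimension is strided. A host-side sketch of the same mapping is given below, assuming a plain three-element struct in place of `Coord<3>`. `Coord3`, `project_a`, `project_b`, and `project_cd` are hypothetical names.

```cpp
// Problem coordinate ordered (K, N, M), as in the source comments. Illustrative type.
struct Coord3 {
  int v[3];
};

// Sketch of ProjectOperand<GemmOperand::kA, Kstrided>::project: keep K and M.
constexpr Coord3 project_a(Coord3 const& c, bool k_strided) {
  return k_strided ? Coord3{{0, c.v[0], c.v[2]}} : Coord3{{0, c.v[2], c.v[0]}};
}

// Sketch of the B projection: keep K and N.
constexpr Coord3 project_b(Coord3 const& c, bool k_strided) {
  return k_strided ? Coord3{{0, c.v[0], c.v[1]}} : Coord3{{0, c.v[1], c.v[0]}};
}

// Sketch of the C and D projections: keep N and M, in that order.
constexpr Coord3 project_cd(Coord3 const& c) { return Coord3{{0, c.v[1], c.v[2]}}; }
```

For a coordinate with K = 8, N = 16, M = 32, the A projection with K strided would be (0, 8, 32), and the B projection with K contiguous would be (0, 16, 8).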
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,113 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines abstractions for managing loading and storing fragments to shared memory in the
|
||||
efficient GEMM pipeline.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/gemm/gemm_shared_tile.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The load iterator.
|
||||
typename Iterator_,
|
||||
/// The transformer to be applied after the data has been copied from shared memory.
|
||||
typename Transformer_ = Copy<typename Iterator_::Fragment> >
|
||||
|
||||
struct SharedLoadStream {
|
||||
/// The load iterator.
|
||||
typedef Iterator_ Iterator;
|
||||
/// The transformer.
|
||||
typedef Transformer_ Transformer;
|
||||
|
||||
/// The fragment that is copied from shared memory.
|
||||
typedef typename Iterator::Fragment FetchedFragment;
|
||||
/// The fragment that is obtained after the transformation by the transformer.
|
||||
typedef typename Transformer::OutputFragment TransformedFragment;
|
||||
/// Make sure the fragments match.
|
||||
static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
|
||||
"The fetched fragment must match the transformer's input fragment.");
|
||||
/// The output fragment.
|
||||
typedef TransformedFragment Fragment;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
/// The iterator params.
|
||||
typename Iterator::Params iterator;
|
||||
|
||||
/// Setup the params.
|
||||
CUTLASS_HOST_DEVICE int initialize() { return iterator.initialize(); }
|
||||
};
|
||||
|
||||
/// The storage in shared memory needed by that stream.
|
||||
typedef typename Iterator::Storage SharedStorage;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE SharedLoadStream() {}
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE SharedLoadStream(Params const& params, SharedStorage &shared_storage) {
|
||||
this->initialize(params, shared_storage);
|
||||
}
|
||||
|
||||
/// Initialize the stream.
|
||||
CUTLASS_DEVICE void initialize(Params const& params, SharedStorage &shared_storage) {
|
||||
// The iterator.
|
||||
iterator = Iterator(params.iterator, shared_storage);
|
||||
// The transformer.
|
||||
transformer = Transformer();
|
||||
}
|
||||
|
||||
/// Load the data from shared memory to the fetch fragment.
|
||||
CUTLASS_DEVICE void copy(FetchedFragment &fetched) { shared_iterator_load(iterator, fetched); }
|
||||
|
||||
/// Load the data from shared memory to the fetch fragment.
|
||||
CUTLASS_DEVICE void copy(int d, FetchedFragment &fetched) {
|
||||
shared_iterator_load(iterator, fetched, d);
|
||||
}
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit(FetchedFragment &fetched, TransformedFragment &transformed) {
|
||||
transformer.transform(fetched, transformed);
|
||||
}
|
||||
|
||||
/// Increment the stage.
|
||||
CUTLASS_DEVICE void inc_stage() { iterator.inc_stage(); }
|
||||
|
||||
/// The iterator.
|
||||
Iterator iterator;
|
||||
/// The transformer.
|
||||
Transformer transformer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,417 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines iterators for efficiently loading and storing tiles to and from shared memory.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/gemm/gemm_operand.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerSts_>
|
||||
struct GemmSharedStoreTileAbTraits {
|
||||
/// The scalar.
|
||||
typedef typename platform::remove_const<Scalar_>::type Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar_* Pointer;
|
||||
/// The tile.
|
||||
typedef typename ReshapeTile<Tile_, kScalarsPerSts_>::Tile Tile;
|
||||
/// The threads.
|
||||
typedef Threads_ Threads;
|
||||
/// The strides to compute the base position of the thread.
|
||||
typedef Shape<0, ShapeCount<Tile>::kWc, Tile::kC, kScalarsPerSts_> ThreadsStrides;
|
||||
/// The skew.
|
||||
static int const kSkew = 0;
|
||||
/// The number of scalars per STS.
|
||||
static int const kAccessSize = kScalarsPerSts_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<1,
|
||||
Tile::kH / Threads::kH,
|
||||
Tile::kW / Threads::kW,
|
||||
Tile::kC / Threads::kC / kAccessSize>
|
||||
Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kAccessSize> Delta;
|
||||
/// Strides for immediate offset computation.
|
||||
typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kAccessSize>
|
||||
ImmediateOffsetStrides;
|
||||
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int offset = ComputeThreadOffsetFromStrides<Threads, ThreadsStrides>::get();
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerSts_, int kSkew_>
struct GemmSharedStoreWithSkewTileAbTraits {
  /// The scalar.
  typedef typename platform::remove_const<Scalar_>::type Scalar;
  /// The pointer.
  typedef Scalar_* Pointer;
  /// The tile without skews.
  typedef typename ReshapeTile<Tile_, kScalarsPerSts_>::Tile TileWithoutSkew;
  /// The tile.
  typedef typename ReshapeTile<Shape<Tile_::kD, Tile_::kH, Tile_::kW + kSkew_>,
                               kScalarsPerSts_>::Tile Tile;
  /// The threads.
  typedef Threads_ Threads;
  /// The skew.
  static int const kSkew = kSkew_;
  /// The number of scalars per STS.
  static int const kAccessSize = kScalarsPerSts_;
  /// The memory space.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;

  /// The number of iterations needed to load/store the tile.
  typedef Shape<1, TileWithoutSkew::kH / Threads::kW, TileWithoutSkew::kW / Threads::kH> Iterations;
  /// The strides in each dimension between different loads/stores.
  typedef Shape<0, ShapeCount<Tile>::kWc, Threads::kH * kAccessSize> Delta;
  /// The strides in each dimension between different loads/stores.
  typedef Shape<0, ShapeCount<Tile>::kWc, Threads::kH * kAccessSize> ImmediateOffsetStrides;

  struct ThreadOffset {
    CUTLASS_HOST_DEVICE Coord<4> operator()() const {
      int offset = ComputeThreadOffsetFromStrides<Threads, ThreadsStrides>::get();
      return make_Coord(0, 0, offset, 0);
    }
  };

 protected:
  /// The strides to compute the base position of the thread.
  typedef Shape<0, kScalarsPerSts_, ShapeCount<Tile>::kHwc / Threads::kW> ThreadsStrides;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
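The skew in `GemmSharedStoreWithSkewTileAbTraits` widens each shared-memory row (`Tile_::kW + kSkew_`). A common reason for such padding is to keep threads that touch the same column of consecutive rows out of the same bank. The host-side sketch below illustrates that general effect only. It assumes 32 banks of 4-byte words and a simplified pattern in which thread `t` accesses word 0 of row `t`; `worst_bank_conflict` and `kBanks` are hypothetical names, not CUTLASS symbols, and the real store pattern of the traits is more involved.

```cpp
// Illustrative only: models shared memory as 32 banks of 4-byte words.
constexpr int kBanks = 32;

// Returns the largest number of threads that hit the same bank when thread t
// accesses word 0 of row t and each row holds `row_words` 4-byte words.
constexpr int worst_bank_conflict(int row_words, int threads) {
  int hits[kBanks] = {};
  int worst = 0;
  for (int t = 0; t < threads; ++t) {
    int bank = (t * row_words) % kBanks;  // bank of word 0 in row t
    ++hits[bank];
    if (hits[bank] > worst) worst = hits[bank];
  }
  return worst;
}

// With 32-word rows every thread of a warp maps to bank 0; one word of
// padding per row spreads the 32 accesses over 32 distinct banks.
static_assert(worst_bank_conflict(32, 32) == 32, "unpadded rows collide");
static_assert(worst_bank_conflict(32 + 1, 32) == 1, "padded rows do not collide");
```

The function needs C++14 or later because of the loop inside a `constexpr` function.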

template <typename Scalar_,
          typename OutputTile_,
          typename Warps_,
          typename ThreadsPerWarp_,
          typename InstructionShape_,
          int kStages_,
          int kScalarsPerLds_,
          int kSkew_ = 0>
struct GemmSharedLoadTileATraits {
  static GemmOperand::Kind const kOperand = GemmOperand::kA;
  /// The scalar.
  typedef typename platform::remove_const<Scalar_>::type Scalar;
  /// The pointer.
  typedef Scalar_* Pointer;
  /// The tile without skew.
  typedef Shape<kStages_,
                OutputTile_::kD / InstructionShape_::kD,
                GetExtent<kOperand, OutputTile_>::kExtent * InstructionShape_::kD>
      TileWithoutSkew_;
  /// The tile with skew.
  typedef Shape<kStages_, TileWithoutSkew_::kH, TileWithoutSkew_::kW + kSkew_> TileWithSkew;
  /// The tile without skew after reshaping.
  typedef typename ReshapeTile<TileWithoutSkew_, kScalarsPerLds_>::Tile TileWithoutSkew;
  /// The tile.
  typedef typename ReshapeTile<TileWithSkew, kScalarsPerLds_>::Tile Tile;
  /// The number of warps.
  typedef Warps_ Warps;
  /// The threads in a warp.
  typedef ThreadsPerWarp_ ThreadsPerWarp;
  /// The number of scalars per LDS.
  // static int const kScalarsPerLds = kScalarsPerLds_;
  static int const kAccessSize = kScalarsPerLds_;
  /// The skew.
  static int const kSkew = kSkew_;
  /// The memory space.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;

  /// The number of warps.
  static int const kWarps = GetExtent<kOperand, Warps>::kExtent;
  /// The number of threads in one dimension of the warp.
  static int const kThreadsPerWarp = GetExtent<kOperand, ThreadsPerWarp>::kExtent;

  /// The number of iterations needed to load/store the tile.
  typedef Shape<1, 1, TileWithoutSkew::kW / kWarps / kThreadsPerWarp /* / kScalarsPerLds*/>
      Iterations;
  /// The strides in each dimension between different loads/stores.
  typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0>
      ImmediateOffsetStrides;
  typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0> Delta;

  /// Computes the thread offset in (H, W) based on thread ID.
  struct ThreadOffset {
    CUTLASS_HOST_DEVICE Coord<4> operator()() const {
      // Extract the warp.
      int const warp = threadIdx.x / kWarpSize;
      // Extract the slice.
      int const slice = warp / (Warps::kH * Warps::kW);
      // Compute the row offset for each warp.
      int const warp_row = warp % Warps::kW;
      // Compute the row offset for each thread.
      int const lane_row = (threadIdx.x & 0x0e) / 2;
      // The offset.
      int const offset =
          slice * Tile::kW * Tile::kC + (warp_row * ThreadsPerWarp::kW + lane_row) * kAccessSize;
      // Embed the offset in a 4D coordinate vector.
      return make_Coord(0, 0, offset, 0);
    }
  };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Scalar_,
          typename OutputTile_,
          typename Warps_,
          typename ThreadsPerWarp_,
          typename InstructionShape_,
          int kStages_,
          int kScalarsPerLds_,
          int kSkew_ = 0>
struct GemmSharedLoadTileBTraits {
  static GemmOperand::Kind const kOperand = GemmOperand::kB;
  /// The scalar.
  typedef typename platform::remove_const<Scalar_>::type Scalar;
  /// The pointer.
  typedef Scalar_* Pointer;
  /// The tile without skew.
  typedef Shape<kStages_,
                OutputTile_::kD / InstructionShape_::kD,
                GetExtent<kOperand, OutputTile_>::kExtent * InstructionShape_::kD>
      TileWithoutSkew_;
  /// The tile with skew.
  typedef Shape<kStages_, TileWithoutSkew_::kH, TileWithoutSkew_::kW + kSkew_> TileWithSkew;
  /// The tile without skew after reshaping.
  typedef typename ReshapeTile<TileWithoutSkew_, kScalarsPerLds_>::Tile TileWithoutSkew;
  /// The tile.
  typedef typename ReshapeTile<TileWithSkew, kScalarsPerLds_>::Tile Tile;
  /// The number of warps.
  typedef Warps_ Warps;
  /// The threads in a warp.
  typedef ThreadsPerWarp_ ThreadsPerWarp;
  /// The number of scalars per LDS.
  static int const kAccessSize = kScalarsPerLds_;
  /// The skew.
  static int const kSkew = kSkew_;
  /// The memory space.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;

  /// The number of warps.
  static int const kWarps = GetExtent<kOperand, Warps>::kExtent;
  /// The number of threads in one dimension of the warp.
  static int const kThreadsPerWarp = GetExtent<kOperand, ThreadsPerWarp>::kExtent;

  /// The number of iterations needed to load/store the tile.
  typedef Shape<1, 1, TileWithoutSkew::kW / kWarps / kThreadsPerWarp /* / kAccessSize*/> Iterations;
  /// The strides in each dimension between different loads/stores.
  typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0>
      ImmediateOffsetStrides;
  typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0> Delta;

  /// Computes the thread offset in (H, W) based on thread ID.
  struct ThreadOffset {
    CUTLASS_HOST_DEVICE Coord<4> operator()() const {
      // Extract the warp.
      int const warp = threadIdx.x / kWarpSize;
      // Extract the slice.
      int const slice = warp / (Warps::kH * Warps::kW);
      // The warp in the slice.
      int const warp_in_slice = warp % (Warps::kH * Warps::kW);
      // Compute the column offset for each warp.
      int const warp_col = warp_in_slice / Warps::kW;
      // Compute the column offset for each thread.
      int const lane_col = (threadIdx.x & 0x10) / 8 + (threadIdx.x & 0x01);
      // The offset.
      int const offset =
          slice * Tile::kW * Tile::kC + (warp_col * ThreadsPerWarp::kH + lane_col) * kAccessSize;
      // Embed the offset in a 4D coordinate.
      return make_Coord(0, 0, offset, 0);
    }
  };
};

////////////////////////////////////////////////////////////////////////////////////////////////////
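The `ThreadOffset` functors of the A and B load traits derive a row index and a column index from the low five bits of `threadIdx.x`. The sketch below repeats that bit arithmetic on the host with the lane id passed in explicitly, and checks that the 32 lanes of a warp cover an 8x4 grid exactly once. `lane_row`, `lane_col` and `covers_8x4_once` are hypothetical helper names used for illustration only.

```cpp
// Illustrative only: same masks as the ThreadOffset functors, applied to a
// lane id (0..31) instead of threadIdx.x.
constexpr int lane_row(int lane) { return (lane & 0x0e) / 2; }  // bits 1..3 -> 0..7
constexpr int lane_col(int lane) {
  return (lane & 0x10) / 8 + (lane & 0x01);  // bit 4 and bit 0 -> 0..3
}

// Returns true when every (row, col) cell of an 8x4 grid is produced by
// exactly one of the 32 lanes.
constexpr bool covers_8x4_once() {
  bool seen[8][4] = {};
  for (int lane = 0; lane < 32; ++lane) {
    int r = lane_row(lane);
    int c = lane_col(lane);
    if (seen[r][c]) return false;  // two lanes mapped to the same cell
    seen[r][c] = true;
  }
  return true;  // 32 distinct cells out of 32
}

static_assert(lane_col(0) == 0 && lane_col(1) == 1, "even/odd lanes: columns 0 and 1");
static_assert(lane_col(16) == 2 && lane_col(17) == 3, "upper half-warp: columns 2 and 3");
static_assert(covers_8x4_once(), "the mapping is a bijection onto the 8x4 grid");
```

This agrees with the layout described in the D store traits: columns 0 and 1 hold the even and odd lanes of the lower half-warp, columns 2 and 3 those of the upper half-warp.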

template <typename Scalar_,
          typename OutputTile_,
          typename Warps_,
          typename ThreadsPerWarp_,
          int kScalarsPerSts_,
          int kSkew_ = 0>
struct GemmSharedStoreTileDTraits {
  /// The scalar.
  typedef typename platform::remove_const<Scalar_>::type Scalar;
  /// The pointer.
  typedef Scalar_* Pointer;
  /// The dimension of the output tile.
  typedef OutputTile_ OutputTile;
  /// The warps in the tile.
  typedef Warps_ Warps;
  /// The threads in the warps.
  typedef ThreadsPerWarp_ ThreadsPerWarp;
  /// The number of scalars per STS.
  static int const kAccessSize = kScalarsPerSts_;
  /// The skew.
  static int const kSkew = kSkew_;
  /// The memory space.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;

  /// The number of scalars per thread.
  static int const kScalarsPerThread = OutputTile_::kW / Warps::kW / ThreadsPerWarp::kW;
  /// The number of threads.
  static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;
  /// The number of scalars per row. We build a tile with 2 rows (to avoid bank conflicts).
  static int const kScalarsPerRow = kThreads / 2 * kScalarsPerThread + kSkew;

  /// The tile.
  typedef Shape<1, 2, kScalarsPerRow / kAccessSize, kAccessSize> Tile;
  /// The number of iterations needed to store the tile.
  typedef Shape<1, 1, kScalarsPerThread / kAccessSize> Iterations;
  /// The strides in each dimension between different loads/stores.
  typedef Shape<0, 0, Warps::kW * ThreadsPerWarp::kW * kAccessSize> Delta;
  /// The strides in each dimension between different loads/stores.
  typedef Shape<0, 0, Warps::kW * ThreadsPerWarp::kW * kAccessSize> ImmediateOffsetStrides;

  /// Computes the thread offset in (H, W) based on thread ID.
  struct ThreadOffset {
    CUTLASS_HOST_DEVICE Coord<4> operator()() const {
      // The warp.
      int const warp = threadIdx.x / kWarpSize;

      // The position of the warp in the 2D tile.
      int const warp_row = warp % Warps::kW;
      int const warp_col = warp / Warps::kW;

      // We assume that the elements are distributed in a warp as 4 columns of 8 elements. The
      // columns are stored in threads col0=[0, 2, 4, 6, 8, 10, 12, 14], col1=[1, 3, 5, 7, .., 15],
      // col2=[16, 18, 20, ..., 30] and col3=[17, 19, ..., 31].
      int hi_halfwarp_offset = ((threadIdx.x >> 4) & 0x1) * OutputTile::kW;
      int lo_halfwarp_offset = ((threadIdx.x >> 1) & 0x7) + ThreadsPerWarp::kW * warp_row;

      // Odd threads go to the second half of shared memory.
      int const row = threadIdx.x & 0x01;
      int col = warp_col * (ThreadsPerWarp::kH / 2) * OutputTile::kW +
                lo_halfwarp_offset * kAccessSize + hi_halfwarp_offset;
      // Embed the offset in a 4D coordinate.
      return make_Coord(0, 0, row * kScalarsPerRow + col, 0);
    }
  };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Scalar_,
          typename OutputTile_,
          typename Warps_,
          typename ThreadsPerWarp_,
          int kTileH_,
          int kScalarsPerLds_,
          int kSkew_ = 0>
struct GemmSharedLoadTileDTraits {
  /// The scalar.
  typedef typename platform::remove_const<Scalar_>::type Scalar;
  /// The pointer.
  typedef Scalar_* Pointer;
  /// The dimension of the output tile.
  typedef OutputTile_ OutputTile;
  /// The warps in the tile.
  typedef Warps_ Warps;
  /// The threads in the warps.
  typedef ThreadsPerWarp_ ThreadsPerWarp;
  /// The number of scalars per LDS.
  static int const kAccessSize = kScalarsPerLds_;
  /// The skew.
  static int const kSkew = kSkew_;
  /// The memory space.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;

  /// The number of scalars per thread.
  static int const kScalarsPerThread = OutputTile_::kW / Warps::kW / ThreadsPerWarp::kW;
  /// The number of threads.
  static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;
  /// The number of scalars per row. We build a tile with 2 rows (to avoid bank conflicts).
  static int const kScalarsPerRow = kThreads / 2 * kScalarsPerThread + kSkew;

  /// The tile. We have 2 rows of scalars. We use those two rows to make sure we do not have bank
  /// conflicts in the epilogue.
  typedef Shape<1, 2, kScalarsPerRow / kAccessSize, kAccessSize> Tile;

  // Compute the number of iterations per warp in the Tile::kH dimension.
  static int const kIterationsInHPerWarp = kTileH_ / ShapeCount<Warps>::kCount;

  // As explained above, the shared memory tile is composed of 2 rows and each row is made of
  // kScalarsPerRow. A warp is expected to read from the 1st row, then move to the 2nd row and go
  // back to the 1st row. To model that scheme we define the Iterations shape as Shape<X, 2, ...>.
  // However, in some cases, we have only 1 iteration per warp. In that case, we must define the
  // shape as Shape<1, 1, ...>. The following code does that except that we hijack the kH dimension
  // to keep the number of elements to reduce for split-K.
  static int const kIterationsH = kIterationsInHPerWarp == 1 ? 1 : 2;
  // As soon as we know kIterationsH, it is trivial to compute kIterationsD:
  static int const kIterationsD = kIterationsInHPerWarp / kIterationsH;

  // If we have split-K enabled, we have to jump over the elements from the "odd/even" column of
  // threads to grab the other elements.
  static int const kSplitK = OutputTile::kW * ThreadsPerWarp::kH / 2 * Warps::kH;

  /// The number of iterations needed to load the tile.
  typedef Shape<kIterationsD, kIterationsH, OutputTile::kW / kWarpSize / kAccessSize, Warps::kD>
      Iterations;
  /// The strides in each dimension between different loads/stores.
  typedef Shape<OutputTile::kW, kScalarsPerRow, kWarpSize * kAccessSize, kSplitK>
      ImmediateOffsetStrides;
  /// The strides in each dimension between different loads/stores.
  typedef Shape<OutputTile::kW, kScalarsPerRow, kWarpSize * kAccessSize, kSplitK> Delta;

  /// Computes the thread offset in (H, W) based on thread ID.
  struct ThreadOffset {
    CUTLASS_HOST_DEVICE Coord<4> operator()() const {
      // Each warp works on a different column.
      int const h = threadIdx.x / kWarpSize;
      // Compute the row.
      int const w = (threadIdx.x & (kWarpSize - 1)) * kAccessSize;
      int offset = 0;
      if (Iterations::kH == 1) {
        int const row = h & 0x1;
        int const col = h / 2;
        offset = row * ShapeCount<Tile>::kWc + col * OutputTile::kW * Iterations::kD + w;
      } else {
        offset = h * OutputTile::kW * Iterations::kD + w;
      }
      return make_Coord(0, 0, offset, 0);
    }
  };
};

////////////////////////////////////////////////////////////////////////////////////////////////////
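The comments in `GemmSharedLoadTileDTraits` describe how the per-warp iteration count is split into an H part (1 or 2, for the two shared-memory rows) and a D part (how often that pattern repeats). The sketch below replays the same integer arithmetic with plain values. `EpilogueIterations` and `split_iterations` are hypothetical names, and the tile heights and warp counts in the checks are made-up inputs rather than values taken from a real configuration.

```cpp
// Illustrative only: mirrors kIterationsInHPerWarp, kIterationsH and
// kIterationsD with plain integers instead of traits types.
struct EpilogueIterations {
  int per_warp;  // iterations each warp performs in the Tile::kH dimension
  int h;         // 1 when a warp has a single iteration, otherwise 2
  int d;         // per_warp / h
};

constexpr EpilogueIterations split_iterations(int tile_h, int warp_count) {
  return EpilogueIterations{
      tile_h / warp_count,
      tile_h / warp_count == 1 ? 1 : 2,
      (tile_h / warp_count) / (tile_h / warp_count == 1 ? 1 : 2)};
}

// One iteration per warp collapses to Shape<1, 1, ...>.
static_assert(split_iterations(4, 4).h == 1 && split_iterations(4, 4).d == 1, "1 -> 1 x 1");
// Four iterations per warp become two passes over the two rows.
static_assert(split_iterations(16, 4).h == 2 && split_iterations(16, 4).d == 2, "4 -> 2 x 2");
```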

} // namespace gemm
} // namespace cutlass
@ -1,818 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines structural properties of complete GEMM computation.
*/
#pragma once

#include <cutlass/convert.h>
#include <cutlass/gemm/clear_accumulators.h>
#include <cutlass/gemm/gemm_global_stream.h>
#include <cutlass/gemm/gemm_operand.h>
#include <cutlass/gemm/gemm_shared_stream.h>
#include <cutlass/gemm/identity_block_swizzle.h>
#include <cutlass/matrix_traits.h>
#include <cutlass/reshape_tile.h>
#include <cutlass/tile_iterator.h>

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// The scalar type for A.
    typename ScalarA_,
    /// The scalar type for B.
    typename ScalarB_,
    /// The scalar type for C.
    typename ScalarC_,
    /// The scalar type for D.
    typename ScalarD_,
    /// The output tile size for the GEMM KxNxM.
    typename OutputTile_,
    /// The functor to do the math.
    typename MultiplyAdd_,
    /// The number of scalars per LDG for A.
    int kScalarsPerLdgA_,
    /// The number of scalars per STS for A.
    int kScalarsPerStsA_,
    /// The number of scalars per LDS for A.
    int kScalarsPerLdsA_,
    /// The number of scalars per LDG for B.
    int kScalarsPerLdgB_,
    /// The number of scalars per STS for B.
    int kScalarsPerStsB_,
    /// The number of scalars per LDS for B.
    int kScalarsPerLdsB_,
    /// The number of scalars per LDG for C and STG for D.
    int kScalarsPerLdgCAndStgD_,
    /// The number of scalars per STS for D.
    int kScalarsPerStsD_,
    /// The number of scalars per LDS for D.
    int kScalarsPerLdsD_,
    /// The number of stages in shared memory to do single/double/triple-buffering.
    int kStages_,
    /// Do we do the residue in the prologue?
    bool kResidueInPrologue_ = false>

struct GemmConfig {
  //
  /// The scalar for A.
  typedef ScalarA_ ScalarA;
  /// The scalar for B.
  typedef ScalarB_ ScalarB;
  /// The scalar for C.
  typedef ScalarC_ ScalarC;
  /// The scalar for D.
  typedef ScalarD_ ScalarD;

  /// The tile.
  typedef OutputTile_ OutputTile;
  /// The functor to do D = A*B + C.
  typedef MultiplyAdd_ MultiplyAdd;
  /// The shape of the instruction.
  typedef typename MultiplyAdd::InstructionShape InstructionShape;
  /// The number of accumulators per warp.
  typedef typename MultiplyAdd::AccumulatorsPerWarp AccumulatorsPerWarp;
  /// The accumulators.
  typedef typename MultiplyAdd::Accumulators Accumulators;

  /// The number of warps.
  typedef typename ShapeDiv<OutputTile, AccumulatorsPerWarp>::Shape Warps;
  /// The default warp size (32 threads per warp).
  static int const kWarpSize = cutlass::kWarpSize;
  /// The number of threads.
  static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;

  /// The number of scalars per LDG/STS/LDS for A.
  static int const kScalarsPerLdgA = kScalarsPerLdgA_;
  static int const kScalarsPerStsA = kScalarsPerStsA_;
  static int const kScalarsPerLdsA = kScalarsPerLdsA_;

  /// The number of scalars per LDG/STS/LDS for B.
  static int const kScalarsPerLdgB = kScalarsPerLdgB_;
  static int const kScalarsPerStsB = kScalarsPerStsB_;
  static int const kScalarsPerLdsB = kScalarsPerLdsB_;

  /// The number of scalars per LDG for C.
  static int const kScalarsPerLdgC = kScalarsPerLdgCAndStgD_;

  /// The number of scalars per STS/LDS/STG for D.
  static int const kScalarsPerStgD = kScalarsPerLdgCAndStgD_;
  static int const kScalarsPerStsD = kScalarsPerStsD_;
  static int const kScalarsPerLdsD = kScalarsPerLdsD_;

  /// The number of accumulators that are going to be fed from one LDS A/B.
  static int const kAccumulatorsPerLdsA = kScalarsPerLdsA / InstructionShape::kD;
  static int const kAccumulatorsPerLdsB = kScalarsPerLdsB / InstructionShape::kD;

  /// The number of stages in shared memory to implement double, triple, more-buffering.
  static int const kStages = kStages_;

  /// Do we do the residue in the prologue?
  static bool const kResidueInPrologue = kResidueInPrologue_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
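`GemmConfig` derives the warp arrangement by dividing the output tile by the accumulators each warp owns, then multiplies the warp count by the warp size to get the thread count. The sketch below reproduces that derivation with simplified three-dimensional stand-ins. `Shape3`, `ShapeDiv3` and `ShapeCount3` are hypothetical types written for this example (the real `Shape` also carries a fourth dimension), and the tile sizes are example values, not a documented CUTLASS configuration.

```cpp
// Hypothetical, simplified stand-ins for cutlass::Shape, ShapeDiv and ShapeCount.
template <int kD_, int kH_, int kW_>
struct Shape3 {
  static int const kD = kD_;
  static int const kH = kH_;
  static int const kW = kW_;
};

// Element-wise division of two shapes.
template <typename A, typename B>
struct ShapeDiv3 {
  typedef Shape3<A::kD / B::kD, A::kH / B::kH, A::kW / B::kW> Shape;
};

// Product of all extents.
template <typename S>
struct ShapeCount3 {
  static int const kCount = S::kD * S::kH * S::kW;
};

// Example values only: a K x N x M output tile and the part owned by one warp.
typedef Shape3<8, 128, 128> ExampleOutputTile;
typedef Shape3<8, 32, 64> ExampleAccumulatorsPerWarp;

// Same derivation as GemmConfig::Warps and GemmConfig::kThreads.
typedef ShapeDiv3<ExampleOutputTile, ExampleAccumulatorsPerWarp>::Shape ExampleWarps;
static int const kExampleWarpSize = 32;
static int const kExampleThreads = ShapeCount3<ExampleWarps>::kCount * kExampleWarpSize;

static_assert(ExampleWarps::kH == 4 && ExampleWarps::kW == 2, "a 4 x 2 grid of warps");
static_assert(kExampleThreads == 256, "8 warps of 32 threads");
```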

template <enum MatrixLayout::Kind, typename GemmConfig_>
struct GemmTileTraitsHelperA {};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmConfig_>
struct GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
  /// The layout.
  static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;

  /// The input scalar.
  typedef typename GemmConfig_::ScalarA Scalar;
  /// The scalar stored in shared memory.
  typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;

  /// The traits class to build the iterator to load data from global memory for A^N.
  typedef GemmGlobalTileTraits<
      // That's A.
      GemmOperand::kA,
      // A is column-major.
      MatrixLayout::kColumnMajor,
      // The pointer is float const.
      Scalar const,
      // The tile has size KxM in GEMM's terminology.
      Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kW>,
      // The threads are distributed as warps x 32 (the traits may reorganize).
      Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
      // The number of scalars per LDG (LDG.32 or LDG.128, etc).
      GemmConfig_::kScalarsPerLdgA>
      GlobalTileTraits;

  /// The traits class to build the iterator to store data to shared memory for A^N.
  typedef GemmSharedStoreTileAbTraits<
      // The pointer is float.
      MultiplyAddScalar,
      // The tile has size KxM in GEMM's terminology.
      Shape<GemmConfig_::kStages,
            GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
            GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
      // The threads are distributed as warps x 32 (the traits may reorganize).
      typename GlobalTileTraits::Threads,
      // The number of scalars per STS (STS.32 or STS.128, etc).
      GemmConfig_::kScalarsPerStsA>
      SharedStoreTileTraits;

  /// The traits class to build the iterator to load from shared memory for A^N.
  typedef GemmSharedLoadTileATraits<
      // The pointer is float const.
      MultiplyAddScalar const,
      // The output tile size.
      typename GemmConfig_::OutputTile,
      // The number of warps.
      typename GemmConfig_::Warps,
      // The number of threads per warp.
      typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
      // The shape of the FMA instruction.
      typename GemmConfig_::InstructionShape,
      // The number of stages.
      GemmConfig_::kStages,
      // The number of scalars per LDS.
      GemmConfig_::kScalarsPerLdsA,
      // The skew.
      0>
      SharedLoadTileTraits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmConfig_>
struct GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
  /// The layout.
  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;

  /// The input scalar.
  typedef typename GemmConfig_::ScalarA Scalar;
  /// The scalar stored in shared memory.
  typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;

  /// The traits class to build the iterator to load data from global memory for A^T.
  typedef GemmGlobalTileTraits<
      // That's A.
      GemmOperand::kA,
      // A is row-major.
      MatrixLayout::kRowMajor,
      // The pointer is float const.
      Scalar const,
      // The tile has size MxK in GEMM's terminology.
      Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
      // The threads are distributed as (threads / K) x K (the traits may reorganize).
      Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
      // The number of scalars per LDG (LDG.32 or LDG.128, etc).
      GemmConfig_::kScalarsPerLdgA>
      GlobalTileTraits;

  /// The number of scalars in 4B.
  static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
  /// The skew for A.
  static int const kSkewA = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsA /
                            GlobalTileTraits::Threads::kW * kScalarsIn4B;

  /// The traits class to build the iterator to store data to shared memory for A^T.
  typedef GemmSharedStoreWithSkewTileAbTraits<
      // The pointer is float.
      MultiplyAddScalar,
      // The tile has size KxM in GEMM's terminology.
      Shape<GemmConfig_::kStages,
            GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
            GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
      // The threads are distributed as (threads / K) x K (the traits may reorganize).
      typename GlobalTileTraits::Threads,
      // The number of scalars per STS.
      GemmConfig_::kScalarsPerStsA,
      // The skew to avoid bank conflicts added in the tile W dimension.
      kSkewA < GemmConfig_::kScalarsPerLdsA ? GemmConfig_::kScalarsPerLdsA : kSkewA>
      SharedStoreTileTraits;

  /// The traits class to build the iterator to load from shared memory for A^T.
  typedef GemmSharedLoadTileATraits<
      // The pointer is float const.
      MultiplyAddScalar const,
      // The output tile size.
      typename GemmConfig_::OutputTile,
      // The number of warps.
      typename GemmConfig_::Warps,
      // The number of threads per warp.
      typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
      // The shape of the FMA instruction.
      typename GemmConfig_::InstructionShape,
      // The number of stages.
      GemmConfig_::kStages,
      // The number of scalars per LDS.
      GemmConfig_::kScalarsPerLdsA,
      // The skew.
      SharedStoreTileTraits::kSkew>
      SharedLoadTileTraits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
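The row-major A helper computes its skew from the scalar size, the STS width and the thread arrangement, then raises it to at least the LDS width before handing it to the store traits. The sketch below repeats that integer arithmetic with plain parameters so the result can be checked for a few inputs. The function names are hypothetical, and the argument values in the checks are example inputs chosen for illustration; in the real traits `threads_w` comes from `GlobalTileTraits::Threads::kW`, which the global tile traits may reorganize.

```cpp
// Illustrative only: the kScalarsIn4B / kSkewA arithmetic with the traits
// values passed as plain integers (sizes in bytes).
constexpr int scalars_in_4b(int scalar_bytes) {
  return scalar_bytes > 4 ? 1 : 4 / scalar_bytes;
}

// Same left-to-right integer division as the kSkewA expression.
constexpr int raw_skew(int scalar_bytes, int scalars_per_sts, int threads_w) {
  return 128 / scalar_bytes / scalars_per_sts / threads_w * scalars_in_4b(scalar_bytes);
}

// The store traits receive the larger of the raw skew and the LDS width.
constexpr int effective_skew(int scalar_bytes, int scalars_per_sts, int threads_w,
                             int scalars_per_lds) {
  return raw_skew(scalar_bytes, scalars_per_sts, threads_w) < scalars_per_lds
             ? scalars_per_lds
             : raw_skew(scalar_bytes, scalars_per_sts, threads_w);
}

// 4-byte scalars, STS.32, 8 threads along W: 128 / 4 / 1 / 8 * 1 = 4.
static_assert(raw_skew(4, 1, 8) == 4, "4-byte example");
// A raw skew of 1 is raised to an LDS width of 4.
static_assert(effective_skew(4, 4, 8, 4) == 4, "skew is at least the LDS width");
```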

template <enum MatrixLayout::Kind, typename GemmConfig_>
struct GemmTileTraitsHelperB {};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmConfig_>
struct GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
  /// The layout.
  static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;

  /// The input scalar.
  typedef typename GemmConfig_::ScalarB Scalar;
  /// The scalar stored in shared memory.
  typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;

  /// The traits class to build the iterator to load data from global memory for B^N.
  typedef GemmGlobalTileTraits<
      // That's B.
      GemmOperand::kB,
      // B is column-major.
      MatrixLayout::kColumnMajor,
      // The pointer is float const.
      Scalar const,
      // The tile has size NxK in GEMM's terminology.
      Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
      // The threads are distributed as (threads / K) x K (the traits may reorganize).
      Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
      // The number of scalars per LDG (LDG.32 or LDG.128, etc).
      GemmConfig_::kScalarsPerLdgB>
      GlobalTileTraits;

  /// The number of scalars in 4B.
  static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
  /// The skew for B.
  static int const kSkewB = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsB /
                            GlobalTileTraits::Threads::kW * kScalarsIn4B;

  /// The traits class to build the iterator to store data to shared memory for B^N.
  typedef GemmSharedStoreWithSkewTileAbTraits<
      // The pointer is float.
      MultiplyAddScalar,
      // The tile has size KxN in GEMM's terminology.
      Shape<GemmConfig_::kStages,
            GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
            GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
      // The threads are distributed as (threads / K) x K (the traits may reorganize).
      typename GlobalTileTraits::Threads,
      // The number of scalars per STS.
      GemmConfig_::kScalarsPerStsB,
      // The skew to avoid bank conflicts added in the tile W dimension.
      kSkewB < GemmConfig_::kScalarsPerLdsB ? GemmConfig_::kScalarsPerLdsB : kSkewB>
      SharedStoreTileTraits;

  /// The traits class to build the iterator to load from shared memory for B^N.
  typedef GemmSharedLoadTileBTraits<
      // The pointer is float const.
      MultiplyAddScalar const,
      // The output tile size.
      typename GemmConfig_::OutputTile,
      // The number of warps.
      typename GemmConfig_::Warps,
      // The number of threads per warp.
      typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
      // The shape of the FMA instruction.
      typename GemmConfig_::InstructionShape,
      // The number of stages.
      GemmConfig_::kStages,
      // The number of scalars per LDS.
      GemmConfig_::kScalarsPerLdsB,
      // The skew.
      SharedStoreTileTraits::kSkew>
      SharedLoadTileTraits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmConfig_>
struct GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
  /// The layout.
  static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;

  /// The input scalar.
  typedef typename GemmConfig_::ScalarB Scalar;
  /// The scalar stored in shared memory.
  typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;

  /// The traits class to build the iterator to load data from global memory for B^T.
  typedef GemmGlobalTileTraits<
      // That's B.
      GemmOperand::kB,
      // B is row-major.
      MatrixLayout::kRowMajor,
      // The pointer is float const.
      Scalar const,
      // The tile has size KxN in GEMM's terminology.
      Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kH>,
      // The threads are distributed as warps x 32 (the traits may reorganize).
      Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
      // The number of scalars per LDG (LDG.32 or LDG.128, etc).
      GemmConfig_::kScalarsPerLdgB>
      GlobalTileTraits;

  /// The traits class to build the iterator to store data to shared memory for B^T.
  typedef GemmSharedStoreTileAbTraits<
      // The pointer is float.
      MultiplyAddScalar,
      // The tile has size KxN in GEMM's terminology.
      Shape<GemmConfig_::kStages,
            GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
            GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
      // The threads are distributed as warps x 32 (the traits may reorganize).
      typename GlobalTileTraits::Threads,
      // The number of scalars per STS (STS.32 or STS.128, etc).
      GemmConfig_::kScalarsPerStsB>
      SharedStoreTileTraits;

  /// The traits class to build the iterator to load from shared memory for B^T.
  typedef GemmSharedLoadTileBTraits<
      // The pointer is float const.
      MultiplyAddScalar const,
      // The output tile size.
      typename GemmConfig_::OutputTile,
      // The number of warps.
      typename GemmConfig_::Warps,
      // The number of threads per warp.
      typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
      // The shape of the FMA instruction.
      typename GemmConfig_::InstructionShape,
      // The number of stages.
      GemmConfig_::kStages,
      // The number of scalars per LDS.
      GemmConfig_::kScalarsPerLdsB,
      // The skew.
      0>
      SharedLoadTileTraits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmTraits_, bool kResidueInPrologue_ = GemmTraits_::kResidueInPrologue>
|
||||
struct GemmResidue {
|
||||
/// Move to residue portion.
|
||||
template <bool kIsPrologue>
|
||||
static CUTLASS_DEVICE void move_to_residue(typename GemmTraits_::GlobalLoadStreamA& stream_a,
|
||||
typename GemmTraits_::GlobalLoadStreamB& stream_b,
|
||||
typename GemmTraits_::Index k) {
|
||||
// The new code path in CUTLASS 1.0.1: We treat the residue in the prologue so we can have
|
||||
// complete main loops after that. It helps simplify the logic in the main loop.
|
||||
if (kIsPrologue) {
|
||||
stream_a.move_to_residue(k);
|
||||
stream_b.move_to_residue(k);
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile and initialize predicates.
|
||||
static CUTLASS_DEVICE void rollback(typename GemmTraits_::GlobalLoadStreamA& stream_a,
|
||||
typename GemmTraits_::GlobalLoadStreamB& stream_b) {
|
||||
stream_a.rollback();
|
||||
stream_b.rollback();
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmTraits_>
|
||||
struct GemmResidue<GemmTraits_, false> {
|
||||
/// Move to residue portion.
|
||||
template <bool kIsPrologue>
|
||||
static CUTLASS_DEVICE void move_to_residue(typename GemmTraits_::GlobalLoadStreamA& stream_a,
|
||||
typename GemmTraits_::GlobalLoadStreamB& stream_b,
|
||||
typename GemmTraits_::Index k) {
|
||||
// The index.
|
||||
typedef typename GemmTraits_::Index Index;
|
||||
// By how much we unroll the main loop.
|
||||
Index const kUnroll = static_cast<Index>(GemmTraits_::OutputTile::kD);
|
||||
|
||||
// Call the residue code. That's the same path as CUTLASS 1.0.0.
|
||||
if (kIsPrologue && k < kUnroll) {
|
||||
stream_a.residue(k, true);
|
||||
stream_b.residue(k, true);
|
||||
} else if (k <= kUnroll) {
|
||||
stream_a.residue(k, false);
|
||||
stream_b.residue(k, false);
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile and initialize predicates.
|
||||
static CUTLASS_DEVICE void rollback(typename GemmTraits_::GlobalLoadStreamA& stream_a,
|
||||
typename GemmTraits_::GlobalLoadStreamB& stream_b) {}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The GEMM configuration.
|
||||
typename GemmConfig_,
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typename GlobalLoadStreamA_,
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typename GlobalLoadStreamB_,
|
||||
/// The stream to load A from shared memory.
|
||||
typename SharedLoadStreamA_,
|
||||
/// The stream to load B from shared memory.
|
||||
typename SharedLoadStreamB_,
|
||||
/// The epilogue.
|
||||
typename Epilogue_,
|
||||
/// The block swizzle to reorganize the grid.
|
||||
typename BlockSwizzle_ = IdentityBlockSwizzle,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The tool used to clear accumulators.
|
||||
typename ClearAccumulators_ = ClearAccumulators<typename GemmConfig_::Accumulators::Scalar> >
|
||||
|
||||
struct GemmTraits {
|
||||
/// This class.
|
||||
typedef GemmTraits<GemmConfig_,
|
||||
GlobalLoadStreamA_,
|
||||
GlobalLoadStreamB_,
|
||||
SharedLoadStreamA_,
|
||||
SharedLoadStreamB_,
|
||||
Epilogue_,
|
||||
BlockSwizzle_,
|
||||
Index_,
|
||||
ClearAccumulators_>
|
||||
This_;
|
||||
|
||||
/// The configuration.
|
||||
typedef GemmConfig_ GemmConfig;
|
||||
/// The output tile.
|
||||
typedef typename GemmConfig::OutputTile OutputTile;
|
||||
/// Is the residue treated in the prologue?
|
||||
static bool const kResidueInPrologue = GemmConfig::kResidueInPrologue;
|
||||
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typedef GlobalLoadStreamA_ GlobalLoadStreamA;
|
||||
/// The layout of A.
|
||||
static MatrixLayout::Kind const kLayoutA = GlobalLoadStreamA::kLayout;
|
||||
/// The scalar for A.
|
||||
typedef typename GlobalLoadStreamA_::Scalar ScalarA;
|
||||
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typedef GlobalLoadStreamB_ GlobalLoadStreamB;
|
||||
/// The layout of B.
|
||||
static MatrixLayout::Kind const kLayoutB = GlobalLoadStreamB::kLayout;
|
||||
/// The scalar for B.
|
||||
typedef typename GlobalLoadStreamB_::Scalar ScalarB;
|
||||
|
||||
/// The iterator for A to load from shared memory.
|
||||
typedef SharedLoadStreamA_ SharedLoadStreamA;
|
||||
/// The iterator for B to load from shared memory.
|
||||
typedef SharedLoadStreamB_ SharedLoadStreamB;
|
||||
|
||||
/// The multiply-add functor.
|
||||
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
|
||||
/// The epilogue.
|
||||
typedef Epilogue_ Epilogue;
|
||||
/// The scalars in the epilogue.
|
||||
typedef typename Epilogue::ScalarC ScalarC;
|
||||
typedef typename Epilogue::ScalarD ScalarD;
|
||||
|
||||
/// The block swizzle to reorganize the grid.
|
||||
typedef BlockSwizzle_ BlockSwizzle;
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
/// Clear the accumulators.
|
||||
typedef ClearAccumulators_ ClearAccumulators;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
/// The dimensions of the GEMM.
|
||||
Index m, n, k;
|
||||
/// The params for the A stream.
|
||||
typename GlobalLoadStreamA::Params global_stream_a;
|
||||
/// The params for the B stream.
|
||||
typename GlobalLoadStreamB::Params global_stream_b;
|
||||
/// The params for the A stream from shared memory.
|
||||
typename SharedLoadStreamA::Params shared_stream_a;
|
||||
/// The params for the B stream from shared memory.
|
||||
typename SharedLoadStreamB::Params shared_stream_b;
|
||||
/// The params for the epilogue.
|
||||
typename Epilogue::Params epilogue;
|
||||
|
||||
/// Initialize the parameters.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
|
||||
// Set the problem size.
|
||||
this->m = desc.m;
|
||||
this->n = desc.n;
|
||||
this->k = desc.k;
|
||||
|
||||
// Initialize the iterator for A.
|
||||
int error_code =
|
||||
global_stream_a.initialize(desc, reinterpret_cast<ScalarA const*>(desc.d_a), desc.lda);
|
||||
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// Initialize the iterator for B.
|
||||
error_code =
|
||||
global_stream_b.initialize(desc, reinterpret_cast<ScalarB const*>(desc.d_b), desc.ldb);
|
||||
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// The epilogue.
|
||||
return epilogue.initialize(desc);
|
||||
}
|
||||
};
|
||||
|
||||
// The shared storage for one operand stream (instantiated for both A and B).
|
||||
template <typename GlobalLoadStream_, typename SharedLoadStream_>
|
||||
union StreamSharedStorage {
|
||||
// The storage needed by the global stream.
|
||||
typename GlobalLoadStream_::SharedStorage global;
|
||||
// The storage needed by the shared stream.
|
||||
typename SharedLoadStream_::SharedStorage shared;
|
||||
};
|
||||
|
||||
// The storage for the main loop + prologue.
|
||||
struct MainLoopSharedStorage {
|
||||
// The storage to shuffle the A matrix in shared memory.
|
||||
StreamSharedStorage<GlobalLoadStreamA, SharedLoadStreamA> stream_a;
|
||||
// The storage to shuffle the B matrix in shared memory.
|
||||
StreamSharedStorage<GlobalLoadStreamB, SharedLoadStreamB> stream_b;
|
||||
// The storage to clear the accumulators if needed.
|
||||
typename ClearAccumulators::SharedStorage clear;
|
||||
};
|
||||
|
||||
/// The storage in shared memory.
|
||||
union SharedStorage {
|
||||
// The storage for the main loop.
|
||||
MainLoopSharedStorage main_loop;
|
||||
// The storage for the epilogue.
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
/// Assemble the global load streams for A/B.
|
||||
struct GlobalLoadStream {
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GlobalLoadStream(Params const& params,
|
||||
SharedStorage& shared_storage,
|
||||
dim3 const& block)
|
||||
: stream_a(params.global_stream_a,
|
||||
shared_storage.main_loop.stream_a.global,
|
||||
cutlass::make_Coord(0, params.k, params.m),
|
||||
cutlass::make_Coord(0, 0, block.x)),
|
||||
stream_b(params.global_stream_b,
|
||||
shared_storage.main_loop.stream_b.global,
|
||||
cutlass::make_Coord(0, params.k, params.n),
|
||||
make_Coord(0, 0, block.y)) {}
|
||||
|
||||
/// Trigger the copies from global memory to registers.
|
||||
CUTLASS_DEVICE void copy() {
|
||||
stream_a.copy();
|
||||
stream_b.copy();
|
||||
}
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit() {
|
||||
stream_a.commit();
|
||||
stream_b.commit();
|
||||
}
|
||||
|
||||
/// Move to residue portion.
|
||||
template <bool kIsPrologue>
|
||||
CUTLASS_DEVICE void move_to_residue(Index k) {
|
||||
GemmResidue<This_>::template move_to_residue<kIsPrologue>(stream_a, stream_b, k);
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile and initialize predicates.
|
||||
CUTLASS_DEVICE void rollback() { GemmResidue<This_>::rollback(stream_a, stream_b); }
|
||||
|
||||
/// The stream for A.
|
||||
GlobalLoadStreamA stream_a;
|
||||
/// The stream for B.
|
||||
GlobalLoadStreamB stream_b;
|
||||
};
|
||||
|
||||
/// Assemble the shared load stream for A/B.
|
||||
struct SharedLoadStream {
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE SharedLoadStream(Params const& params, SharedStorage& shared_storage) {
|
||||
stream_a.initialize(params.shared_stream_a, shared_storage.main_loop.stream_a.shared);
|
||||
stream_b.initialize(params.shared_stream_b, shared_storage.main_loop.stream_b.shared);
|
||||
}
|
||||
|
||||
/// Trigger the copies from shared memory to registers.
|
||||
CUTLASS_DEVICE void copy(int step) {
|
||||
stream_a.copy(step, fetched_a[step % 2]);
|
||||
stream_b.copy(step, fetched_b[step % 2]);
|
||||
}
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit(int step) {
|
||||
stream_a.commit(fetched_a[step % 2], transformed_a[step % 2]);
|
||||
stream_b.commit(fetched_b[step % 2], transformed_b[step % 2]);
|
||||
}
|
||||
|
||||
/// The fragment A.
|
||||
CUTLASS_DEVICE typename SharedLoadStreamA::Fragment const& fragment_a(int step) const {
|
||||
return transformed_a[step % 2];
|
||||
}
|
||||
|
||||
/// The fragment B.
|
||||
CUTLASS_DEVICE typename SharedLoadStreamB::Fragment const& fragment_b(int step) const {
|
||||
return transformed_b[step % 2];
|
||||
}
|
||||
|
||||
/// Increment the stage.
|
||||
CUTLASS_DEVICE void inc_stage() {
|
||||
stream_a.inc_stage();
|
||||
stream_b.inc_stage();
|
||||
}
|
||||
|
||||
/// The stream for A.
|
||||
SharedLoadStreamA stream_a;
|
||||
/// The fragments to fetch A.
|
||||
typename SharedLoadStreamA::FetchedFragment fetched_a[2];
|
||||
/// The fragments to transform A.
|
||||
typename SharedLoadStreamA::TransformedFragment transformed_a[2];
|
||||
/// The stream for B.
|
||||
SharedLoadStreamB stream_b;
|
||||
/// The fragments to fetch B.
|
||||
typename SharedLoadStreamB::FetchedFragment fetched_b[2];
|
||||
/// The fragments to transform B.
|
||||
typename SharedLoadStreamB::TransformedFragment transformed_b[2];
|
||||
};
|
||||
|
||||
/// The memory fence for shared loads.
|
||||
static CUTLASS_DEVICE void shared_load_fence(bool in_loop) {
|
||||
if (SharedLoadStreamA::Iterator::kRequiresLoadFence ||
|
||||
SharedLoadStreamB::Iterator::kRequiresLoadFence) {
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
/// The memory fence for shared stores.
|
||||
static CUTLASS_DEVICE void shared_store_fence(bool in_loop) { __syncthreads(); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmTileTraitsHelperA_, typename GemmTileTraitsHelperB_, typename Index_>
|
||||
struct SimplifiedGemmTraitsHelper {
|
||||
/// The global iterator to load A from global memory.
|
||||
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA_::GlobalTileTraits, Index_>
|
||||
GlobalLoadIteratorA;
|
||||
/// The data converter for A before storing to shared memory.
|
||||
typedef Copy<typename GlobalLoadIteratorA::Fragment> GlobalTransformerA;
|
||||
/// The iterator to store A to shared memory.
|
||||
typedef TileStoreIterator<typename GemmTileTraitsHelperA_::SharedStoreTileTraits,
|
||||
typename GemmTileTraitsHelperA_::SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorA;
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA>
|
||||
GlobalLoadStreamA;
|
||||
|
||||
/// The global iterator to load B from global memory.
|
||||
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB_::GlobalTileTraits, Index_>
|
||||
GlobalLoadIteratorB;
|
||||
/// The data converter for B before storing to shared memory.
|
||||
typedef Copy<typename GlobalLoadIteratorB::Fragment> GlobalTransformerB;
|
||||
/// The iterator to store B to shared memory.
|
||||
typedef TileStoreIterator<typename GemmTileTraitsHelperB_::SharedStoreTileTraits,
|
||||
typename GemmTileTraitsHelperB_::SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorB;
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB>
|
||||
GlobalLoadStreamB;
|
||||
|
||||
/// The iterator to load A from shared memory.
|
||||
typedef TileLoadIterator<typename GemmTileTraitsHelperA_::SharedLoadTileTraits,
|
||||
typename GemmTileTraitsHelperA_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorA;
|
||||
/// The stream to load A from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorA> SharedLoadStreamA;
|
||||
/// The iterator to load B from shared memory.
|
||||
typedef TileLoadIterator<typename GemmTileTraitsHelperB_::SharedLoadTileTraits,
|
||||
typename GemmTileTraitsHelperB_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorB;
|
||||
/// The stream to load B from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorB> SharedLoadStreamB;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The config for the GEMM.
|
||||
typename GemmConfig_,
|
||||
/// The epilogue.
|
||||
typename Epilogue_,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
// The configuration for the A matrix.
|
||||
typename GemmTileTraitsHelperA_ = GemmTileTraitsHelperA<kLayoutA_, GemmConfig_>,
|
||||
// The configuration for the B matrix.
|
||||
typename GemmTileTraitsHelperB_ = GemmTileTraitsHelperB<kLayoutB_, GemmConfig_>,
|
||||
// The helper class to create the streams and iterators.
|
||||
typename Helper_ =
|
||||
SimplifiedGemmTraitsHelper<GemmTileTraitsHelperA_, GemmTileTraitsHelperB_, Index_> >
|
||||
struct SimplifiedGemmTraits : public GemmTraits<
|
||||
// The config.
|
||||
GemmConfig_,
|
||||
// The stream to load A from global memory to shared memory.
|
||||
typename Helper_::GlobalLoadStreamA,
|
||||
// The stream to load B from global memory to shared memory.
|
||||
typename Helper_::GlobalLoadStreamB,
|
||||
// The stream to load A from shared memory.
|
||||
typename Helper_::SharedLoadStreamA,
|
||||
// The stream to load B from shared memory.
|
||||
typename Helper_::SharedLoadStreamB,
|
||||
// The epilogue.
|
||||
Epilogue_,
|
||||
// The block swizzle to reorganize the grid.
|
||||
IdentityBlockSwizzle,
|
||||
// The index.
|
||||
Index_,
|
||||
// The tool used to clear accumulators.
|
||||
ClearAccumulators<typename GemmConfig_::Accumulators::Element> > {
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
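The `GemmResidue` helper in the file above picks one of two residue strategies at compile time: the primary template handles the case where the flag is `true`, and a partial specialization handles `false`, with the flag defaulting to a constant read from the traits class. The sketch below isolates that dispatch pattern in host-only C++. `ResiduePolicy`, `PrologueTraits`, and `MainLoopTraits` are hypothetical names used for illustration only; they are not part of CUTLASS.

```cpp
#include <cassert>
#include <string>

// Hypothetical traits classes; only the kResidueInPrologue flag matters here.
struct PrologueTraits {
  static bool const kResidueInPrologue = true;
};
struct MainLoopTraits {
  static bool const kResidueInPrologue = false;
};

// Primary template: chosen when the flag read from the traits is true.
template <typename Traits_, bool kResidueInPrologue_ = Traits_::kResidueInPrologue>
struct ResiduePolicy {
  static char const* where() { return "prologue"; }
};

// Partial specialization: chosen when the flag is false.
template <typename Traits_>
struct ResiduePolicy<Traits_, false> {
  static char const* where() { return "main loop"; }
};
```

Because the second template argument has a default, callers write `ResiduePolicy<SomeTraits>` and the compiler selects the matching specialization without any runtime branch, which is how `GemmResidue<This_>` appears to be used in `GlobalLoadStream`.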
@ -1,90 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Tile traits used to construct global tile iterator for HGEMM. This is intended to
    partition the thread block-level tile into 2D subtiles loaded by the threads and facilitate
    memory accesses larger than 16 bits.
*/
#pragma once

#include <cutlass/coord.h>
#include <cutlass/gemm/gemm_global_tile.h>
#include <cutlass/matrix_traits.h>
#include <cutlass/reshape_tile.h>

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <GemmOperand::Kind kOperand_,
          MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename Tile_,
          typename Threads_,
          int kAccessSize_>
struct HgemmCrosswiseGlobalTileTraits : public GemmGlobalTileTraits<
                                            // Which GEMM operand?
                                            kOperand_,
                                            // The layout.
                                            kLayout_,
                                            // The scalar.
                                            Scalar_,
                                            // The tile.
                                            Tile_,
                                            // The threads.
                                            Threads_,
                                            // The number of scalars per LDG/STG.
                                            kAccessSize_> {
  /// The base class.
  typedef GemmGlobalTileTraits<kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_> Base;
  /// The threads.
  typedef typename Base::Threads Threads;
  /// The thread strides.
  typedef Shape<1, 2, Base::Tile::kC> ThreadsDelta;
  /// The strides in each dimension between different loads/stores.
  typedef Shape<Base::Threads::kH * 2, 1, Base::Threads::kW, Base::kAccessSize> Delta;
  /// The number of iterations needed to load/store the tile.
  typedef Shape<Base::Tile::kH / Base::Threads::kH / 2,
                2,
                Base::Tile::kW / Base::Threads::kW,
                Base::Tile::kC / Base::kAccessSize>
      Iterations;
  /// Computes the thread offset in (H, W) based on thread ID
  struct ThreadOffset {
    CUTLASS_HOST_DEVICE
    Coord<4> operator()() const {
      int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
      int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;

      return make_Coord(0, thread_offset_h, thread_offset_w, 0);
    }
  };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass
|
||||
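The `ThreadOffset` functor in the file above turns a linear thread index into an (H, W) offset by dividing and taking the remainder against the thread-grid width, then scaling each component by the per-thread stride. A minimal host-side sketch of that arithmetic follows. The `Offset` struct and the `thread_offset` function are hypothetical stand-ins for `Coord<4>` and `ThreadOffset::operator()`, and the compile-time constants `Threads::kW`, `ThreadsDelta::kH`, and `ThreadsDelta::kW` are passed as ordinary arguments so the code runs without CUDA.

```cpp
#include <cassert>

// Hypothetical plain struct standing in for the (H, W) part of Coord<4>.
struct Offset {
  int h;
  int w;
};

// Maps a linear thread id to an (h, w) offset.
//   threads_w : number of threads along W (Threads::kW in the header)
//   delta_h   : stride between threads along H (ThreadsDelta::kH)
//   delta_w   : stride between threads along W (ThreadsDelta::kW)
inline Offset thread_offset(int thread_id, int threads_w, int delta_h, int delta_w) {
  Offset offset;
  offset.h = thread_id / threads_w * delta_h;  // row of the thread, scaled by the H stride
  offset.w = thread_id % threads_w * delta_w;  // column of the thread, scaled by the W stride
  return offset;
}
```

For example, with 8 threads along W, an H stride of 2, and a W stride of 4, thread 10 sits in row 1 and column 2 of the thread grid, so its offset is (2, 8).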
@ -1,104 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Specialization implementing multiply-add operation on half-precision floating point
    fragments.
*/
#pragma once

#include <cutlass/fragment.h>

#include <cutlass/gemm/thread_multiply_add.h>

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Template performing matrix multiply-add operation within a thread
template <typename AccumulatorsPerThread_, typename ThreadsPerWarp_>
struct ThreadMultiplyAdd<AccumulatorsPerThread_, ThreadsPerWarp_, half, half, half> {
  /// The shape of the instruction.
  typedef Shape<1, 1, 2, 1> InstructionShape;
  /// The number of accumulators per thread.
  typedef AccumulatorsPerThread_ AccumulatorsPerThread;
  /// The number of threads per warp.
  typedef ThreadsPerWarp_ ThreadsPerWarp;
  /// The number of accumulators per warp.
  typedef typename ShapeMul<AccumulatorsPerThread, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
  /// The type for A.
  typedef half ScalarA;
  /// The fragment for A.
  typedef Fragment<ScalarA, AccumulatorsPerThread::kW> FragmentA;
  /// The type for B.
  typedef half ScalarB;
  /// The fragment for B.
  typedef Fragment<ScalarB, AccumulatorsPerThread::kH> FragmentB;
  /// The type for C and D.
  typedef half ScalarC;
  /// The accumulators.
  typedef Fragment<half, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW> Accumulators;

  /// Make sure there's an even number of elements in both dimensions.
  static_assert(AccumulatorsPerThread::kH % 2 == 0, "Invalid size");
  static_assert(AccumulatorsPerThread::kW % 2 == 0, "Invalid size");

  /// Ctor.
  CUTLASS_DEVICE ThreadMultiplyAdd() {}

  /// Multiply-add: d = a*b + c.
  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
                                   FragmentB const& b,
                                   Accumulators const& c,
                                   Accumulators& d) {
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
    // The inputs.
    __half2 const* a_half2 = reinterpret_cast<__half2 const*>(&a[0]);
    __half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
    __half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);

    // The output.
    __half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);

    for (int j = 0; j < AccumulatorsPerThread::kH / 2; ++j) {
      for (int i = 0; i < AccumulatorsPerThread::kW / 2; ++i) {
        // The offsets in the output fragment.
        int const k0 = (2 * j + 0) * (AccumulatorsPerThread::kW / 2) + i;
        int const k1 = (2 * j + 1) * (AccumulatorsPerThread::kW / 2) + i;

        // Compute the product a[i] * b[j].H0_H0.
        d_half2[k0] = __hfma2(a_half2[i], __low2half2(b_half2[j]), c_half2[k0]);
        // Compute the product a[i] * b[j].H1_H1.
        d_half2[k1] = __hfma2(a_half2[i], __high2half2(b_half2[j]), c_half2[k1]);
      }
    }
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass
|
||||
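The `multiply_add` loop in the file above computes an outer product two elements at a time: each `__half2` from `a` is multiplied by the low half of a `__half2` from `b` broadcast to both lanes (`__low2half2`), then by the broadcast high half (`__high2half2`), and the results land in two consecutive accumulator rows at pair indices `k0` and `k1`. The sketch below reproduces only that index pattern on the host, using `float` in place of `half` and plain arrays in place of `Fragment`. The function name `paired_multiply_add` is hypothetical, and the sketch makes no claim about half-precision rounding.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Host-side emulation of the paired outer product, assuming float storage.
// a has kW elements, b has kH elements, c and d have kH * kW elements (row-major, kW per row).
template <std::size_t kH, std::size_t kW>
std::array<float, kH * kW> paired_multiply_add(std::array<float, kW> const& a,
                                               std::array<float, kH> const& b,
                                               std::array<float, kH * kW> const& c) {
  static_assert(kH % 2 == 0 && kW % 2 == 0, "both dimensions must be even");
  std::array<float, kH * kW> d{};
  for (std::size_t j = 0; j < kH / 2; ++j) {
    for (std::size_t i = 0; i < kW / 2; ++i) {
      // Pair indices of the two output rows fed by b[2j] and b[2j + 1].
      std::size_t const k0 = (2 * j + 0) * (kW / 2) + i;
      std::size_t const k1 = (2 * j + 1) * (kW / 2) + i;
      float const b_low = b[2 * j + 0];   // plays the role of __low2half2(b_half2[j])
      float const b_high = b[2 * j + 1];  // plays the role of __high2half2(b_half2[j])
      // Each pair holds two lanes, mirroring the two halves of a __half2.
      for (std::size_t lane = 0; lane < 2; ++lane) {
        float const a_value = a[2 * i + lane];
        d[2 * k0 + lane] = a_value * b_low + c[2 * k0 + lane];
        d[2 * k1 + lane] = a_value * b_high + c[2 * k1 + lane];
      }
    }
  }
  return d;
}
```

Written this way, `d[r * kW + col]` equals `a[col] * b[r] + c[r * kW + col]`, which is the ordinary outer-product update; the pairing only changes how many elements one fused multiply-add instruction handles.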
@ -1,94 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Transposes a tile of 16b elements. Used by HGEMM to construct a K-strided layout in
    shared memory for multiplicands.
*/
#pragma once

#include <cuda_fp16.h>
#include <cutlass/fragment.h>

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GlobalIterator_>
struct HgemmSwizzle {
  /// The global iterator.
  typedef GlobalIterator_ GlobalIterator;
  /// The source fragment.
  typedef typename GlobalIterator::Fragment Fragment;
  /// The shape of the source fragment.
  typedef typename GlobalIterator::FragmentShape FragmentShape;

  /// The input fragment.
  typedef Fragment InputFragment;
  /// The output fragment.
  typedef Fragment OutputFragment;

  /// The src/dst must be half fragments.
  static_assert((platform::is_same<typename Fragment::Element, half>::value), "Works on half");

  /// The fragment must hold two rows (H == 2) of two elements each (W * C == 2).
  static_assert(FragmentShape::kH == 2 && ShapeCount<FragmentShape>::kWc == 2, "Not multiple of 2");

  /// Ctor.
  CUTLASS_DEVICE HgemmSwizzle() {}

  /// Transform a fragment.
  CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) {
    // Expose src/dst as int arrays.
    int const* src_int = reinterpret_cast<int const*>(&src[0]);
    int* dst_int = reinterpret_cast<int*>(&dst[0]);

    // Transpose the data.
    for (int d = 0; d < FragmentShape::kD; ++d) {
      // The indices to read two consecutive "rows".
      int const i0 = 2 * d + 0;
      int const i1 = 2 * d + 1;

      int a0 = src_int[i0];
      int a1 = src_int[i1];

      int b0, b1;
      asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(a0), "r"(a1));
      asm volatile("prmt.b32 %0, %1, %2, 0x7632;" : "=r"(b1) : "r"(a0), "r"(a1));

      // The indices to store with "strides".
      int const j0 = 0 * (ShapeCount<FragmentShape>::kDhw / 2) + d;
      int const j1 = 1 * (ShapeCount<FragmentShape>::kDhw / 2) + d;

      dst_int[j0] = b0;
      dst_int[j1] = b1;
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass
|
||||
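The two `prmt.b32` instructions in `HgemmSwizzle::transform` above perform a 2x2 transpose of 16-bit values held in two 32-bit registers: selector `0x5410` gathers the low halves of both inputs, and `0x7632` gathers the high halves. The sketch below emulates the default (non-replicating) mode of `prmt` on the host so the selectors can be checked without a GPU. The function name `byte_permute` is hypothetical, and the emulation covers only selector nibbles 0 through 7, which is all this swizzle uses.

```cpp
#include <cassert>
#include <cstdint>

// Host emulation of PTX prmt.b32 in its default mode, for selector nibbles in [0, 7].
// Bytes 0-3 of the source pool come from a, bytes 4-7 come from b.
// Nibble i of the selector names the pool byte copied into byte i of the result.
inline std::uint32_t byte_permute(std::uint32_t a, std::uint32_t b, std::uint32_t selector) {
  std::uint64_t const pool = (static_cast<std::uint64_t>(b) << 32) | a;
  std::uint32_t result = 0;
  for (int i = 0; i < 4; ++i) {
    std::uint32_t const index = (selector >> (4 * i)) & 0x7u;
    std::uint32_t const byte = static_cast<std::uint32_t>((pool >> (8 * index)) & 0xFFu);
    result |= byte << (8 * i);
  }
  return result;
}
```

With `a0` holding the 16-bit pair (x0, x1) and `a1` holding (y0, y1), `byte_permute(a0, a1, 0x5410)` yields (x0, y0) and `byte_permute(a0, a1, 0x7632)` yields (x1, y1), which is the transposed pair layout written to `dst_int[j0]` and `dst_int[j1]`.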
@ -1,397 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural properties of half-precision GEMM computation.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/gemm_epilogue.h>
|
||||
#include <cutlass/gemm/gemm_epilogue_traits.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/gemm/gemm_shared_tile.h>
|
||||
#include <cutlass/gemm/gemm_traits.h>
|
||||
#include <cutlass/gemm/hgemm_global_tile.h>
|
||||
#include <cutlass/gemm/hgemm_multiply_add.h>
|
||||
#include <cutlass/gemm/hgemm_swizzle.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_ = 2,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_ = 2>
|
||||
struct HgemmConfig
|
||||
: public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
half,
|
||||
/// The scalar type for B.
|
||||
half,
|
||||
/// The scalar type for C.
|
||||
half,
|
||||
/// The scalar type for D.
|
||||
half,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, half, half, half>,
|
||||
/// The number of scalars per LDG for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDS for A.
|
||||
8,
|
||||
/// The number of scalars per LDG for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
8,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
2,
|
||||
/// The number of scalars per STS for D.
|
||||
8,
|
||||
/// The number of scalars per LDS for D.
|
||||
2,
|
||||
/// The number of stages in shared memory.
|
||||
2> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
|
||||
struct HgemmTransformerA {};
|
||||
|
||||
template <typename Iterator_>
|
||||
struct HgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
|
||||
typedef Convert<typename Iterator_::Fragment, typename Iterator_::Fragment> Transformer;
|
||||
};
|
||||
|
||||
template <typename Iterator_>
|
||||
struct HgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
|
||||
typedef HgemmSwizzle<Iterator_> Transformer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
|
||||
struct HgemmTransformerB {};
|
||||
|
||||
template <typename Iterator_>
|
||||
struct HgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
|
||||
typedef Convert<typename Iterator_::Fragment, typename Iterator_::Fragment> Transformer;
|
||||
};
|
||||
|
||||
template <typename Iterator_>
|
||||
struct HgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
|
||||
typedef HgemmSwizzle<Iterator_> Transformer;
|
||||
};
|
||||
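`HgemmTransformerA` and `HgemmTransformerB` pick a transformer type by specializing on the matrix layout. The following is a stripped-down sketch of that pattern only; `Layout`, `PassThrough`, `Swizzle`, `TransformerForA`, and `TransformerForB` are hypothetical stand-ins, not the CUTLASS types.

```cpp
#include <type_traits>

// Hypothetical stand-ins for illustration.
enum class Layout { kColumnMajor, kRowMajor };
struct PassThrough {};  // plays the role of Convert<Fragment, Fragment>
struct Swizzle {};      // plays the role of HgemmSwizzle<Iterator>

// Primary templates are empty, mirroring the empty primary templates above.
template <Layout kLayout>
struct TransformerForA {};
template <Layout kLayout>
struct TransformerForB {};

// A: column-major passes through, row-major is swizzled.
template <>
struct TransformerForA<Layout::kColumnMajor> { typedef PassThrough Transformer; };
template <>
struct TransformerForA<Layout::kRowMajor> { typedef Swizzle Transformer; };

// B: the mapping is mirrored.
template <>
struct TransformerForB<Layout::kRowMajor> { typedef PassThrough Transformer; };
template <>
struct TransformerForB<Layout::kColumnMajor> { typedef Swizzle Transformer; };

static_assert(std::is_same<TransformerForA<Layout::kRowMajor>::Transformer, Swizzle>::value,
              "row-major A is swizzled");
static_assert(std::is_same<TransformerForB<Layout::kColumnMajor>::Transformer, Swizzle>::value,
              "column-major B is swizzled");
```

The selection happens entirely at compile time, so the traits helper can name `Transformer` without any runtime branch on the layout.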
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
|
||||
struct HgemmTileTraitsHelperA : public GemmTileTraitsHelperA<kLayout_, GemmConfig_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct HgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_>
|
||||
: public GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
/// The base config.
|
||||
typedef GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> Base;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for A^T.
|
||||
typedef HgemmCrosswiseGlobalTileTraits<
|
||||
GemmOperand::kA,
|
||||
// The layout.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer.
|
||||
half const,
|
||||
// The tile has size MxK in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as (threads / K ) x K (the traits may reorganize).
|
||||
Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc)
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkewA = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^T.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits <
|
||||
// The pointer.
|
||||
half,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
|
||||
GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
2,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
kSkewA<GemmConfig_::kScalarsPerLdsA ? GemmConfig_::kScalarsPerLdsA : kSkewA>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for A^T.
|
||||
typedef GemmSharedLoadTileATraits<
|
||||
// The pointer.
|
||||
half const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
8,
|
||||
// The skew.
|
||||
SharedStoreTileTraits::kSkew>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
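The skew arithmetic in this helper can be worked through with plain integers. This is a sketch under stated assumptions: `base_skew` and `effective_skew` are hypothetical helpers, `threads_w` stands in for `GlobalTileTraits::Threads::kW` (its real value depends on how the traits reorganize the threads), and the inputs are made up.

```cpp
// Mirrors `128 / sizeof(half) / Threads::kW / 2`, with sizeof(half) passed as bytes_per_scalar.
constexpr int base_skew(int bytes_per_scalar, int threads_w) {
  return 128 / bytes_per_scalar / threads_w / 2;
}

// Mirrors the ternary passed to the shared-store traits: keep the larger of the two values.
constexpr int effective_skew(int skew, int scalars_per_lds) {
  return skew < scalars_per_lds ? scalars_per_lds : skew;
}

// Made-up inputs: 2-byte scalars, 8 scalars per LDS.
static_assert(base_skew(2, 8) == 4, "128 / 2 / 8 / 2");
static_assert(effective_skew(base_skew(2, 8), 8) == 8, "the LDS width wins");
static_assert(effective_skew(base_skew(2, 2), 8) == 16, "the base skew wins");
```

In other words, the skew handed to the shared-memory store is never smaller than the number of scalars read by one LDS.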
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
|
||||
struct HgemmTileTraitsHelperB : public GemmTileTraitsHelperB<kLayout_, GemmConfig_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct HgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_>
|
||||
: public GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
/// The base config.
|
||||
typedef GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> Base;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for B^N.
|
||||
typedef HgemmCrosswiseGlobalTileTraits<
|
||||
GemmOperand::kB,
|
||||
// The layout.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer.
|
||||
half const,
|
||||
// The tile has size KxN in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc)
|
||||
GemmConfig_::kScalarsPerLdgB>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew for B.
|
||||
static int const kSkewB = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits <
|
||||
// The pointer.
|
||||
half,
|
||||
// The tile has size KxN in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
|
||||
GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
2,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
kSkewB<GemmConfig_::kScalarsPerLdsB ? GemmConfig_::kScalarsPerLdsB : kSkewB>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for B^N.
|
||||
typedef GemmSharedLoadTileBTraits<
|
||||
// The pointer.
|
||||
half const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
8,
|
||||
// The skew.
|
||||
SharedStoreTileTraits::kSkew>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
typename OutputTile_,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_ = Shape<8, 8, 16>,
|
||||
/// The number of halfs loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 2,
|
||||
/// The number of halfs loaded in one LDG for B.
|
||||
int kScalarsPerLdgB_ = 2,
|
||||
/// The index.
|
||||
typename Index_ = int>
|
||||
struct HgemmTraitsHelper {
|
||||
/// The HGEMM config.
|
||||
typedef HgemmConfig<OutputTile_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_>
|
||||
GemmConfig;
|
||||
/// The GEMM config for A.
|
||||
typedef HgemmTileTraitsHelperA<kLayoutA_, GemmConfig> GemmTileTraitsHelperA;
|
||||
/// The GEMM config for B.
|
||||
typedef HgemmTileTraitsHelperB<kLayoutB_, GemmConfig> GemmTileTraitsHelperB;
|
||||
|
||||
/// The iterator to load A from global memory.
|
||||
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA::GlobalTileTraits, Index_>
|
||||
GlobalLoadIteratorA;
|
||||
/// The default transformer for A.
|
||||
typedef typename HgemmTransformerA<GemmTileTraitsHelperA::kLayout,
|
||||
GlobalLoadIteratorA>::Transformer GlobalTransformerA;
|
||||
/// The iterator to store A to shared memory.
|
||||
typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
|
||||
typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorA;
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA>
|
||||
GlobalLoadStreamA;
|
||||
|
||||
/// The iterator to load B from global memory.
|
||||
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB::GlobalTileTraits, Index_>
|
||||
GlobalLoadIteratorB;
|
||||
/// The default transformer for B.
|
||||
typedef typename HgemmTransformerB<GemmTileTraitsHelperB::kLayout,
|
||||
GlobalLoadIteratorB>::Transformer GlobalTransformerB;
|
||||
/// The iterator to store B to shared memory.
|
||||
typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
|
||||
typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorB;
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB>
|
||||
GlobalLoadStreamB;
|
||||
|
||||
/// The iterator to load A from shared memory.
|
||||
typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
|
||||
typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorA;
|
||||
/// The stream to load A from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorA> SharedLoadStreamA;
|
||||
/// The iterator to load B from shared memory.
|
||||
typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
|
||||
typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorB;
|
||||
/// The stream to load B from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorB> SharedLoadStreamB;
|
||||
|
||||
/// The functor to do the multiply-add in the main loop.
|
||||
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
|
||||
/// The object to clear accumulators.
|
||||
typedef ClearAccumulators<typename MultiplyAdd::ScalarC> ClearAccumulators;
|
||||
|
||||
/// The traits class for the epilogue.
|
||||
typedef SimplifiedGemmEpilogueTraits<GemmConfig, EpilogueFunctor_, Index_> GemmEpilogueTraits;
|
||||
/// The epilogue.
|
||||
typedef GemmEpilogue<GemmEpilogueTraits> Epilogue;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
typename OutputTile_ = Shape<8, 128, 128>,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<half>,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_ = Shape<8, 8, 16>,
|
||||
/// The number of halfs loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 2,
|
||||
/// The number of halfs loaded in one LDG for B.
|
||||
int kScalarsPerLdgB_ = 2,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The helper class.
|
||||
typename Helper_ = HgemmTraitsHelper<kLayoutA_,
|
||||
kLayoutB_,
|
||||
OutputTile_,
|
||||
EpilogueFunctor_,
|
||||
AccumulatorsPerThread_,
|
||||
kScalarsPerLdgA_,
|
||||
kScalarsPerLdgB_,
|
||||
Index_> >
|
||||
struct HgemmTraits : public GemmTraits<
|
||||
// The config.
|
||||
typename Helper_::GemmConfig,
|
||||
// The stream to load A from global memory to shared memory.
|
||||
typename Helper_::GlobalLoadStreamA,
|
||||
// The stream to load B from global memory to shared memory.
|
||||
typename Helper_::GlobalLoadStreamB,
|
||||
// The stream to load A from shared memory.
|
||||
typename Helper_::SharedLoadStreamA,
|
||||
// The stream to load B from shared memory.
|
||||
typename Helper_::SharedLoadStreamB,
|
||||
// The epilogue.
|
||||
typename Helper_::Epilogue,
|
||||
// The block swizzle to reorganize the grid.
|
||||
IdentityBlockSwizzle,
|
||||
// The index.
|
||||
Index_,
|
||||
// The tool used to clear accumulators.
|
||||
typename Helper_::ClearAccumulators> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,48 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines functors for mapping blockIdx to partitions of the GEMM computation.
|
||||
|
||||
Currently, we only implement an identity mapping.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct IdentityBlockSwizzle {
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE IdentityBlockSwizzle() {}
|
||||
|
||||
/// Swizzle the block index.
|
||||
CUTLASS_DEVICE dim3 swizzle() { return blockIdx; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,320 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines the epilogue phase of the GEMM computation for IGEMM, supporting integer and
|
||||
floating-point output matrix formats.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/fragment.h>
|
||||
#include <cutlass/gemm/gemm_global_stream.h>
|
||||
#include <cutlass/gemm/gemm_shared_stream.h>
|
||||
#include <cutlass/gemm/igemm_global_tile.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include <cutlass/tile_iterator.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int kElements_>
|
||||
struct IgemmFloatToInt8Converter {
|
||||
/// The input fragment.
|
||||
typedef Fragment<float, kElements_> InputFragment;
|
||||
/// The output fragment.
|
||||
typedef Fragment<int8_t, kElements_> OutputFragment;
|
||||
|
||||
// We pack 4 converted floats into each int32 register, so kElements must be a multiple of 4.
|
||||
static_assert(kElements_ % 4 == 0, "kElements must be multiple of 4");
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE IgemmFloatToInt8Converter() {}
|
||||
|
||||
/// Transform a fragment.
|
||||
CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
|
||||
transform(src, 0, dst);
|
||||
}
|
||||
|
||||
/// Transform a fragment.
|
||||
template <typename Fragment_>
|
||||
CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
|
||||
// The inputs.
|
||||
float4 const* src_f4 = reinterpret_cast<float4 const*>(&src[0]);
|
||||
// The outputs.
|
||||
int* dst_int = reinterpret_cast<int*>(&dst[0]);
|
||||
|
||||
// Iterate over the floats and pack them together to produce ints.
|
||||
for (int i = 0; i < kElements_ / 4; ++i) {
|
||||
// Read the float4.
|
||||
float4 f4 = src_f4[i];
|
||||
|
||||
// Clamp the 4 elements of the floats to the [-128, +127] range.
|
||||
float x = fmaxf(-128.f, fminf(127.f, f4.x));
|
||||
float y = fmaxf(-128.f, fminf(127.f, f4.y));
|
||||
float z = fmaxf(-128.f, fminf(127.f, f4.z));
|
||||
float w = fmaxf(-128.f, fminf(127.f, f4.w));
|
||||
|
||||
// Convert to integers.
|
||||
int ix = (int)x;
|
||||
int iy = (int)y;
|
||||
int iz = (int)z;
|
||||
int iw = (int)w;
|
||||
|
||||
// Extract the lower bytes to build an int32 with 4 int8.
|
||||
asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(ix) : "r"(iy));
|
||||
asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(iz) : "r"(iw));
|
||||
asm volatile("prmt.b32 %0, %0, %1, 0x5410;" : "+r"(ix) : "r"(iz));
|
||||
|
||||
// Store the int.
|
||||
dst_int[i] = ix;
|
||||
}
|
||||
}
|
||||
};
|
||||
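The clamp, truncate, and pack steps of `IgemmFloatToInt8Converter` can be modeled on the host with shifts instead of `prmt`. This is a minimal sketch, assuming non-NaN inputs; `clamp_to_int8_range`, `low_byte`, and `pack_int8x4` are hypothetical names. As I read the three selectors, the device code ends with x in the lowest byte and w in the highest, which is the order used here.

```cpp
#include <cstdint>

// Clamp to [-128, 127], then truncate toward zero like the C-style cast in the loop.
// NaN inputs are not modeled.
constexpr std::int32_t clamp_to_int8_range(float v) {
  return static_cast<std::int32_t>(v < -128.f ? -128.f : (v > 127.f ? 127.f : v));
}

// Keep only the lowest byte of a 32-bit integer.
constexpr std::uint32_t low_byte(std::int32_t v) {
  return static_cast<std::uint32_t>(v) & 0xFFu;
}

// Pack four clamped values into one 32-bit word: x in byte 0, w in byte 3.
constexpr std::uint32_t pack_int8x4(float x, float y, float z, float w) {
  return low_byte(clamp_to_int8_range(x)) | (low_byte(clamp_to_int8_range(y)) << 8) |
         (low_byte(clamp_to_int8_range(z)) << 16) | (low_byte(clamp_to_int8_range(w)) << 24);
}

// Made-up inputs: 300 saturates to 127 (0x7F), -300 saturates to -128 (0x80).
static_assert(pack_int8x4(1.f, -1.f, 300.f, -300.f) == 0x807FFF01u, "saturating pack");
// Truncation toward zero: 2.9 becomes 2 and -2.9 becomes -2 (0xFE).
static_assert(pack_int8x4(2.9f, -2.9f, 0.f, 0.f) == 0x0000FE02u, "truncating pack");
```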
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename InputScalar_, typename OutputFragment_>
|
||||
struct IgemmGlobalStoreTransformer {
|
||||
typedef Convert<Fragment<InputScalar_, OutputFragment_::kElements>, OutputFragment_> Transformer;
|
||||
};
|
||||
|
||||
template <int kElements_>
|
||||
struct IgemmGlobalStoreTransformer<float, Fragment<int8_t, kElements_> > {
|
||||
typedef IgemmFloatToInt8Converter<kElements_> Transformer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int kElements_>
|
||||
struct IgemmInt8ToFloatConverter {
|
||||
/// The input fragment.
|
||||
typedef Fragment<int8_t, kElements_> InputFragment;
|
||||
/// The output fragment.
|
||||
typedef Fragment<float, kElements_> OutputFragment;
|
||||
|
||||
// We are unpacking 4 int8s from int32.
|
||||
static_assert(kElements_ % 4 == 0, "kElements must be multiple of 4");
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE IgemmInt8ToFloatConverter() {}
|
||||
|
||||
/// Transform a fragment.
|
||||
CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
|
||||
transform(src, 0, dst);
|
||||
}
|
||||
|
||||
/// Transform a fragment.
|
||||
template <typename Fragment_>
|
||||
CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
|
||||
// The inputs.
|
||||
int const* src_int = reinterpret_cast<int const*>(&src[0]);
|
||||
// The outputs.
|
||||
float4* dst_f4 = reinterpret_cast<float4*>(&dst[0]);
|
||||
|
||||
// Iterate over the int8 and unpack them together to produce floats.
|
||||
for (int i = 0; i < kElements_ / 4; ++i) {
|
||||
// Read the int.
|
||||
int ix, iy, iz, iw = src_int[i];
|
||||
|
||||
// Extract the 4 bytes.
|
||||
asm volatile("prmt.b32 %0, 0x0, %1, 0x4440;" : "=r"(ix) : "r"(iw));
|
||||
asm volatile("prmt.b32 %0, 0x0, %1, 0x4441;" : "=r"(iy) : "r"(iw));
|
||||
asm volatile("prmt.b32 %0, 0x0, %1, 0x4442;" : "=r"(iz) : "r"(iw));
|
||||
asm volatile("prmt.b32 %0, 0x0, %1, 0x4443;" : "=r"(iw) : "r"(iw));
|
||||
|
||||
// The floats.
|
||||
float fx, fy, fz, fw;
|
||||
|
||||
// Convert to floats (make sure we generate I2F.F32.S8).
|
||||
asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fx) : "r"(ix));
|
||||
asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fy) : "r"(iy));
|
||||
asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fz) : "r"(iz));
|
||||
asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fw) : "r"(iw));
|
||||
|
||||
// Store the float4.
|
||||
dst_f4[i] = make_float4(fx, fy, fz, fw);
|
||||
}
|
||||
}
|
||||
};
|
||||
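The comments in `IgemmInt8ToFloatConverter` describe unpacking four int8 values from one int32 and converting each to float. The host-side sketch below models that stated intent with shifts. It does not model the specific `prmt` selectors used above, and `sign_extend_byte` and `int8_lane_to_float` are hypothetical names.

```cpp
#include <cstdint>

// Interpret an unsigned byte value (0..255) as a signed 8-bit integer.
constexpr std::int32_t sign_extend_byte(std::uint32_t byte) {
  return byte < 128u ? static_cast<std::int32_t>(byte) : static_cast<std::int32_t>(byte) - 256;
}

// Extract byte `lane` (0 = lowest) from `word` and convert it to float.
constexpr float int8_lane_to_float(std::uint32_t word, int lane) {
  return static_cast<float>(sign_extend_byte((word >> (8 * lane)) & 0xFFu));
}

// Made-up input: bytes 0x01, 0xFF, 0x7F, 0x80 from lowest to highest.
static_assert(int8_lane_to_float(0x807FFF01u, 0) == 1.f, "lane 0");
static_assert(int8_lane_to_float(0x807FFF01u, 1) == -1.f, "lane 1");
static_assert(int8_lane_to_float(0x807FFF01u, 2) == 127.f, "lane 2");
static_assert(int8_lane_to_float(0x807FFF01u, 3) == -128.f, "lane 3");
```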
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename InputFragment_, typename OutputScalar_>
|
||||
struct IgemmGlobalLoadTransformer {
|
||||
typedef Convert<InputFragment_, Fragment<OutputScalar_, InputFragment_::kElements> > Transformer;
|
||||
};
|
||||
|
||||
template <int kElements_>
|
||||
struct IgemmGlobalLoadTransformer<Fragment<int8_t, kElements_>, float> {
|
||||
typedef IgemmInt8ToFloatConverter<kElements_> Transformer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename InputScalar_, typename OutputFragment_>
|
||||
struct IgemmSharedStoreTransformer {
|
||||
typedef Convert<Fragment<InputScalar_, OutputFragment_::kElements>, OutputFragment_> Transformer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename IgemmConfig_, typename EpilogueFunctor_, typename Index_>
|
||||
struct IgemmEpilogueTraitsHelper
|
||||
: public GemmEpilogueTraitsHelper<IgemmConfig_, EpilogueFunctor_, Index_> {
|
||||
/// The base class.
|
||||
typedef GemmEpilogueTraitsHelper<IgemmConfig_, EpilogueFunctor_, Index_> Base;
|
||||
/// The config.
|
||||
typedef IgemmConfig_ IgemmConfig;
|
||||
|
||||
/// The scalar type of the epilogue.
|
||||
typedef typename Base::Scalar Scalar;
|
||||
/// The iterations.
|
||||
typedef typename Base::Iterations Iterations;
|
||||
/// The iterations strides.
|
||||
typedef typename Base::Delta Delta;
|
||||
|
||||
/// The traits class for the iterator.
|
||||
typedef typename Base::GlobalLoadTileTraits GlobalLoadTileTraits;
|
||||
/// The iterator to load C from global memory.
|
||||
typedef GemmGlobalIteratorCd<GlobalLoadTileTraits> GlobalLoadIteratorC;
|
||||
/// The fragment that needs to be produced by the load iterator.
|
||||
typedef typename GlobalLoadIteratorC::Fragment GlobalFragmentC;
|
||||
/// The transformer from loaded data to math fragment.
|
||||
typedef
|
||||
typename IgemmGlobalLoadTransformer<GlobalFragmentC, Scalar>::Transformer GlobalTransformerC;
|
||||
|
||||
/// The traits class for the iterator.
|
||||
typedef typename Base::GlobalStoreTileTraits GlobalStoreTileTraits;
|
||||
/// The iterator to store D to global memory.
|
||||
typedef GemmGlobalIteratorCd<GlobalStoreTileTraits> GlobalStoreIteratorD;
|
||||
/// The fragment that needs to be passed to that store iterator.
|
||||
typedef typename GlobalStoreIteratorD::Fragment GlobalFragmentD;
|
||||
/// The transformer from epilogue scalars to the fragments stored to global memory.
|
||||
typedef
|
||||
typename IgemmGlobalStoreTransformer<Scalar, GlobalFragmentD>::Transformer GlobalTransformerD;
|
||||
|
||||
/// The traits class for the shared iterator to store D to shared memory.
|
||||
typedef typename Base::SharedStoreTileTraits SharedStoreTileTraits;
|
||||
/// The shared iterator to store D to shared memory.
|
||||
typedef TileStoreIterator<SharedStoreTileTraits,
|
||||
typename SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kGlobal>
|
||||
SharedStoreIteratorD;
|
||||
/// The fragment that needs to be passed to that store iterator.
|
||||
typedef typename SharedStoreIteratorD::Fragment SharedStoreFragmentD;
|
||||
/// The transformer from accumulators to shared memory fragments.
|
||||
typedef typename IgemmSharedStoreTransformer<typename IgemmConfig::Accumulators::Element,
|
||||
SharedStoreFragmentD>::Transformer
|
||||
SharedStoreTransformerD;
|
||||
/// The traits class for the shared iterator to load D from shared memory.
|
||||
typedef typename Base::SharedLoadTileTraits SharedLoadTileTraits;
|
||||
/// The shared iterator to load D from shared memory.
|
||||
typedef TileLoadIterator<SharedLoadTileTraits,
|
||||
typename SharedLoadTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorD;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The config.
|
||||
typename IgemmConfig_,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The helper class to assemble the traits.
|
||||
typename Helper_ = IgemmEpilogueTraitsHelper<IgemmConfig_, EpilogueFunctor_, Index_> >
|
||||
struct IgemmEpilogueTraits : public GemmEpilogueTraits<
|
||||
// The output tile.
|
||||
typename IgemmConfig_::OutputTile,
|
||||
// The accumulators.
|
||||
typename IgemmConfig_::Accumulators,
|
||||
// The global iterator for C.
|
||||
typename Helper_::GlobalLoadIteratorC,
|
||||
// The transformer for C.
|
||||
typename Helper_::GlobalTransformerC,
|
||||
// The transformer for D.
|
||||
typename Helper_::GlobalTransformerD,
|
||||
// The global iterator for D.
|
||||
typename Helper_::GlobalStoreIteratorD,
|
||||
// The iterator to store D to shared memory.
|
||||
typename Helper_::SharedStoreIteratorD,
|
||||
// The shared store transformer for D.
|
||||
typename Helper_::SharedStoreTransformerD,
|
||||
// The iterator to load D from shared memory.
|
||||
typename Helper_::SharedLoadIteratorD,
|
||||
// The iterations.
|
||||
typename Helper_::Iterations,
|
||||
// The strides between iterations.
|
||||
typename Helper_::Delta,
|
||||
// The functor to be used in the epilogue.
|
||||
EpilogueFunctor_,
|
||||
// The index.
|
||||
Index_> {
|
||||
/// Do we output in int8?
|
||||
static bool const kInt8Output =
|
||||
platform::is_same<typename IgemmConfig_::ScalarC, int8_t>::value != 0;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmEpilogueTraits_, bool = GemmEpilogueTraits_::kInt8Output>
|
||||
struct IgemmEpilogue : public GemmEpilogue<GemmEpilogueTraits_> {
|
||||
/// The base class.
|
||||
typedef GemmEpilogue<GemmEpilogueTraits_> Base;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const& params_,
|
||||
typename Base::SharedStorage& shared_storage_,
|
||||
typename Base::Index m_,
|
||||
typename Base::Index n_)
|
||||
: Base(params_, shared_storage_, m_, n_) {}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmEpilogueTraits_>
|
||||
struct IgemmEpilogue<GemmEpilogueTraits_, true> : public GemmEpilogue<GemmEpilogueTraits_> {
|
||||
/// The base class.
|
||||
typedef GemmEpilogue<GemmEpilogueTraits_> Base;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const& params_,
|
||||
typename Base::SharedStorage& shared_storage_,
|
||||
typename Base::Index m_,
|
||||
typename Base::Index n_)
|
||||
: Base(params_, shared_storage_, m_, n_) {}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
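The `IgemmEpilogue` primary template and its `<GemmEpilogueTraits_, true>` specialization are selected through a defaulted `bool` template parameter that reads `kInt8Output` from the traits. The following host-only sketch illustrates that dispatch pattern in isolation. `ToyEpilogueTraits`, `ToyEpilogue`, and the returned values 1 and 4 are hypothetical names and numbers chosen for illustration, and `std::is_same` stands in for the `platform::is_same` used in the source.

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for a traits class that exposes kInt8Output.
template <typename ScalarC_>
struct ToyEpilogueTraits {
  static bool const kInt8Output = std::is_same<ScalarC_, int8_t>::value;
};

// Primary template: chosen when Traits_::kInt8Output is false.
template <typename Traits_, bool = Traits_::kInt8Output>
struct ToyEpilogue {
  // Illustrative value only.
  static int scalars_per_access() { return 1; }
};

// Partial specialization: chosen when Traits_::kInt8Output is true.
template <typename Traits_>
struct ToyEpilogue<Traits_, true> {
  // Illustrative value only.
  static int scalars_per_access() { return 4; }
};

// Helpers that instantiate both paths.
inline int toy_access_for_int32() {
  return ToyEpilogue<ToyEpilogueTraits<int> >::scalars_per_access();
}
inline int toy_access_for_int8() {
  return ToyEpilogue<ToyEpilogueTraits<int8_t> >::scalars_per_access();
}
```

Because the second template argument is defaulted from the traits, callers write only `ToyEpilogue<Traits>` and the compiler picks the specialization, which appears to be the same reason the source declares `bool = GemmEpilogueTraits_::kInt8Output`.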
|
||||
@ -1,161 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements tile iterators to partition the thread block tile into 2D subtiles and
|
||||
efficiently load each. Applies permute transformation to construct 'interleaved K-strided'
|
||||
data layout in which 4-element dot products from the same K index are arranged in consecutive
|
||||
locations within shared memory.
|
||||
|
||||
Supports efficient loads from shared memory to target the DP4A instruction.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/matrix_traits.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GemmOperand::Kind kOperand_,
|
||||
MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename Tile_,
|
||||
typename Threads_,
|
||||
int kAccessSize_>
|
||||
struct IgemmGlobalTileTraits : public GemmGlobalTileTraits<
|
||||
// Which GEMM operand?
|
||||
kOperand_,
|
||||
// The layout.
|
||||
kLayout_,
|
||||
// The scalar.
|
||||
Scalar_,
|
||||
// The tile.
|
||||
Tile_,
|
||||
// The threads.
|
||||
Threads_,
|
||||
// The number of scalars per LDG/STG.
|
||||
kAccessSize_> {
|
||||
/// The base class.
|
||||
typedef GemmGlobalTileTraits<kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_> Base;
|
||||
/// The threads.
|
||||
typedef typename Base::Threads Threads;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<Base::Threads::kH * 4, 1, Base::Threads::kW, Base::kAccessSize> Delta;
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<Base::Tile::kH / Base::Threads::kH / 4,
|
||||
4,
|
||||
Base::Tile::kW / Base::Threads::kW,
|
||||
Base::Tile::kC / Base::kAccessSize>
|
||||
Iterations;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
|
||||
int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
|
||||
|
||||
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
/// The threads strides.
|
||||
typedef Shape<1, 4, Base::Tile::kC> ThreadsDelta;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Deprecated. Please use IgemmGlobalTileTraits instead.
|
||||
|
||||
template <GemmOperand::Kind kOperand_,
|
||||
MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename Tile_,
|
||||
typename Threads_,
|
||||
int kAccessSize_>
|
||||
struct IgemmContiguousGlobalTileTraits
|
||||
: public IgemmGlobalTileTraits<kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TileTraits_, typename Index_ = int>
|
||||
struct IgemmGlobalIteratorAb : public GemmGlobalIteratorAb<TileTraits_, Index_> {
|
||||
/// The base class.
|
||||
typedef GemmGlobalIteratorAb<TileTraits_, Index_> Base;
|
||||
/// The functor to compute the thread offset.
|
||||
typedef typename TileTraits_::ThreadOffset ThreadOffset;
|
||||
|
||||
/// Constructor.
|
||||
CUTLASS_DEVICE IgemmGlobalIteratorAb(typename Base::Params const& _params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: Base(_params, bounds, block, thread_offset_func), in_residue_(false), mask_(0xffffffff) {
|
||||
// The number of elements read in a single iteration.
|
||||
int const kBlock = TileTraits_::Tile::kW * TileTraits_::kAccessSize;
|
||||
// The residue.
|
||||
int const kResidue = (int)(bounds[1] % kBlock);
|
||||
|
||||
// Compute the number of elements that are valid.
|
||||
int const left = kResidue - Base::thread_offset[2];
|
||||
if (left > 0 && left < 4) {
|
||||
mask_ = (1u << (8 * left)) - 1u;
|
||||
}
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
Base::get(value, d, h, w, c);
|
||||
if (in_residue_) {
|
||||
reinterpret_cast<uint32_t&>(value) &= mask_;
|
||||
}
|
||||
}
|
||||
|
||||
/// Move to residue portion.
|
||||
CUTLASS_DEVICE void move_to_residue(typename Base::Index k) {
|
||||
Base::move_to_residue(k);
|
||||
in_residue_ = true;
|
||||
}
|
||||
|
||||
/// Move back to the beginning of the first tile.
|
||||
CUTLASS_DEVICE void rollback() {
|
||||
Base::rollback();
|
||||
in_residue_ = false;
|
||||
}
|
||||
|
||||
/// Are we in the residue?
|
||||
bool in_residue_;
|
||||
/// The mask to clean up the values.
|
||||
uint32_t mask_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
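The `IgemmGlobalIteratorAb` constructor builds a byte mask so that, in the residue, a packed 32-bit load keeps only the valid `int8_t` lanes. The sketch below reproduces just that arithmetic on the host. The function names `residue_mask` and `apply_residue_mask` are hypothetical, and the sketch assumes the little-endian packing used on NVIDIA GPUs, where the first element occupies the lowest byte.

```cpp
#include <cstdint>

// Mask that keeps the `left` lowest bytes of a packed 32-bit word.
// Mirrors the source: only 1 <= left <= 3 changes the mask; every other
// value (including left <= 0) leaves the default all-ones mask in place.
inline uint32_t residue_mask(int left) {
  uint32_t mask = 0xffffffffu;
  if (left > 0 && left < 4) {
    mask = (1u << (8 * left)) - 1u;
  }
  return mask;
}

// Zeroes the bytes of `packed` that lie beyond the valid range.
inline uint32_t apply_residue_mask(uint32_t packed, int left) {
  return packed & residue_mask(left);
}
```

For example, with two valid elements the word `0x44332211` becomes `0x00002211`, so the two out-of-range lanes contribute zero to any later dot product.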
|
||||
@ -1,89 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements matrix multiply accumulate operation of 8-bit integer data using DP4A
|
||||
instruction.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
|
||||
#include <cutlass/gemm/thread_multiply_add.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Template performing matrix multiply-add operation within a thread
|
||||
template <typename AccumulatorsPerThread_, typename ThreadsPerWarp_>
|
||||
struct ThreadMultiplyAdd<AccumulatorsPerThread_, ThreadsPerWarp_, int8_t, int8_t, int> {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<4, 1, 1> InstructionShape;
|
||||
/// The number of accumulators per thread.
|
||||
typedef AccumulatorsPerThread_ AccumulatorsPerThread;
|
||||
/// The number of threads per warp.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of accumulators per warp.
|
||||
typedef typename ShapeMul<AccumulatorsPerThread, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
|
||||
/// The type for A.
|
||||
typedef int8_t ScalarA;
|
||||
/// The fragment for A.
|
||||
typedef Fragment<ScalarA, AccumulatorsPerThread::kW * 4> FragmentA;
|
||||
/// The type for B.
|
||||
typedef int8_t ScalarB;
|
||||
/// The fragment for B.
|
||||
typedef Fragment<ScalarB, AccumulatorsPerThread::kH * 4> FragmentB;
|
||||
/// The type for C and D.
|
||||
typedef int ScalarC;
|
||||
/// The accumulators.
|
||||
typedef Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW> Accumulators;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ThreadMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b + c.
|
||||
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
|
||||
FragmentB const& b,
|
||||
Accumulators const& c,
|
||||
Accumulators& d) {
|
||||
// The inputs.
|
||||
int const* a_int = reinterpret_cast<int const*>(&a[0]);
|
||||
int const* b_int = reinterpret_cast<int const*>(&b[0]);
|
||||
|
||||
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
|
||||
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
|
||||
asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
|
||||
: "=r"(d[j * AccumulatorsPerThread::kW + i])
|
||||
: "r"(a_int[i]), "r"(b_int[j]), "r"(c[j * AccumulatorsPerThread::kW + i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
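The `multiply_add` member above issues one `dp4a.s32.s32` per accumulator, pairing packed word `a_int[i]` with `b_int[j]`. As a rough host-side reference for what that computes, the sketch below emulates a signed 4-way dot product with accumulate and the same loop nest. `dp4a_reference` and `multiply_add_reference` are hypothetical names, and this is an emulation under the assumption of little-endian lane packing, not the CUTLASS implementation.

```cpp
#include <cstdint>

// Host reference for a signed dp4a: treats a and b as four packed int8 lanes
// and returns c + sum over k of a[k] * b[k].
inline int32_t dp4a_reference(uint32_t a, uint32_t b, int32_t c) {
  int32_t acc = c;
  for (int k = 0; k < 4; ++k) {
    int8_t ak = static_cast<int8_t>((a >> (8 * k)) & 0xffu);
    int8_t bk = static_cast<int8_t>((b >> (8 * k)) & 0xffu);
    acc += static_cast<int32_t>(ak) * static_cast<int32_t>(bk);
  }
  return acc;
}

// Same loop nest as the source: d[j * kW + i] = dp4a(a[i], b[j], c[j * kW + i]).
template <int kH, int kW>
inline void multiply_add_reference(uint32_t const (&a)[kW],
                                   uint32_t const (&b)[kH],
                                   int32_t const (&c)[kH * kW],
                                   int32_t (&d)[kH * kW]) {
  for (int j = 0; j < kH; ++j) {
    for (int i = 0; i < kW; ++i) {
      d[j * kW + i] = dp4a_reference(a[i], b[j], c[j * kW + i]);
    }
  }
}
```

Each accumulator therefore advances by four positions along K per instruction, which matches the `Shape<4, 1, 1>` instruction shape declared in the specialization.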
|
||||
@ -1,115 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Transposes a fragment of data containing packed 8-bit integer elements.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GlobalIterator_>
|
||||
struct IgemmSwizzle {
|
||||
/// The global iterator.
|
||||
typedef GlobalIterator_ GlobalIterator;
|
||||
/// The source fragment.
|
||||
typedef typename GlobalIterator::Fragment Fragment;
|
||||
/// The shape of the source fragment.
|
||||
typedef typename GlobalIterator::FragmentShape FragmentShape;
|
||||
|
||||
/// The source fragment.
|
||||
typedef Fragment InputFragment;
|
||||
/// The destination fragment.
|
||||
typedef Fragment OutputFragment;
|
||||
|
||||
/// The src/dst must be int8 fragments.
|
||||
static_assert((platform::is_same<typename Fragment::Element, int8_t>::value), "Works on int8");
|
||||
|
||||
/// The number of elements must be a multiple of 4.
|
||||
static_assert(FragmentShape::kH % 4 == 0 && ShapeCount<FragmentShape>::kWc % 4 == 0,
|
||||
"Not multiple of 4");
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE IgemmSwizzle() {}
|
||||
|
||||
/// Transform a fragment.
|
||||
CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) {
|
||||
// Expose src/dst as int arrays.
|
||||
int const* src_int = reinterpret_cast<int const*>(&src[0]);
|
||||
int* dst_int = reinterpret_cast<int*>(&dst[0]);
|
||||
|
||||
// Transpose the data.
|
||||
for (int d = 0; d < FragmentShape::kD; ++d) {
|
||||
for (int h = 0; h < FragmentShape::kH / 4; ++h) {
|
||||
for (int w = 0; w < ShapeCount<FragmentShape>::kWc / 4; ++w) {
|
||||
int const i0 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
|
||||
(4 * h + 0) * (ShapeCount<FragmentShape>::kWc / 4) + w;
|
||||
int const i1 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
|
||||
(4 * h + 1) * (ShapeCount<FragmentShape>::kWc / 4) + w;
|
||||
int const i2 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
|
||||
(4 * h + 2) * (ShapeCount<FragmentShape>::kWc / 4) + w;
|
||||
int const i3 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
|
||||
(4 * h + 3) * (ShapeCount<FragmentShape>::kWc / 4) + w;
|
||||
|
||||
int a0 = src_int[i0];
|
||||
int a1 = src_int[i1];
|
||||
int a2 = src_int[i2];
|
||||
int a3 = src_int[i3];
|
||||
|
||||
int b0, b1, b2, b3, c0;
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(b0) : "r"(a0), "r"(a1));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(c0) : "r"(a2), "r"(a3));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(b0), "r"(c0));
|
||||
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(b1) : "r"(a0), "r"(a1));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(c0) : "r"(a2), "r"(a3));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b1) : "r"(b1), "r"(c0));
|
||||
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(b2) : "r"(a0), "r"(a1));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(c0) : "r"(a2), "r"(a3));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b2) : "r"(b2), "r"(c0));
|
||||
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(b3) : "r"(a0), "r"(a1));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(c0) : "r"(a2), "r"(a3));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b3) : "r"(b3), "r"(c0));
|
||||
|
||||
dst_int[i0] = b0;
|
||||
dst_int[i1] = b1;
|
||||
dst_int[i2] = b2;
|
||||
dst_int[i3] = b3;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
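The `prmt.b32` sequence in `IgemmSwizzle::transform` can be read as a 4x4 byte transpose across four packed words. The host-only sketch below emulates the default mode of `prmt` for the selectors used above and rebuilds one transposed word. `prmt_reference` and `transposed_word` are hypothetical names, and the emulation deliberately ignores the sign-replication bit of each selector nibble, which none of the selectors in the source set.

```cpp
#include <cstdint>

// Emulates prmt.b32 in default mode for selector nibbles 0..7.
// Bytes 0..3 come from a, bytes 4..7 come from b; nibble i of the selector
// picks the source byte for result byte i. Sign replication is not modeled.
inline uint32_t prmt_reference(uint32_t a, uint32_t b, uint32_t selector) {
  uint64_t pool = (static_cast<uint64_t>(b) << 32) | a;
  uint32_t result = 0;
  for (int i = 0; i < 4; ++i) {
    uint32_t index = (selector >> (4 * i)) & 0x7u;
    uint32_t byte = static_cast<uint32_t>((pool >> (8 * index)) & 0xffu);
    result |= byte << (8 * i);
  }
  return result;
}

// Gathers byte r of a0..a3 into one word, following the three-step pattern
// in the source (selectors 0x0040, 0x0051, 0x0062, 0x0073, then 0x5410).
inline uint32_t transposed_word(uint32_t a0, uint32_t a1, uint32_t a2,
                                uint32_t a3, uint32_t r) {
  uint32_t selector = 0x0040u + 0x11u * r;
  uint32_t lo = prmt_reference(a0, a1, selector);  // byte r of a0, a1 in lanes 0, 1
  uint32_t hi = prmt_reference(a2, a3, selector);  // byte r of a2, a3 in lanes 0, 1
  return prmt_reference(lo, hi, 0x5410u);          // lanes 0, 1 of lo then of hi
}
```

With inputs `0x03020100`, `0x13121110`, `0x23222120`, `0x33323130`, word 0 of the result is `0x30201000`: the first byte of each input, which is the grouping a `dp4a`-style dot product along K expects.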
|
||||
@ -1,539 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural properties of mixed-precision integer GEMM. Multiplicands are assumed
|
||||
to be packed 8-bit integers, accumulators are assumed to be 32-bit signed integers, and output
|
||||
formats vary.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/gemm_epilogue.h>
|
||||
#include <cutlass/gemm/gemm_epilogue_traits.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/gemm/gemm_shared_tile.h>
|
||||
#include <cutlass/gemm/gemm_traits.h>
|
||||
#include <cutlass/gemm/igemm_epilogue.h>
|
||||
#include <cutlass/gemm/igemm_global_tile.h>
|
||||
#include <cutlass/gemm/igemm_multiply_add.h>
|
||||
#include <cutlass/gemm/igemm_swizzle.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// The output type.
|
||||
typename ScalarD_,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_>
|
||||
struct IgemmConfig
|
||||
: public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
int8_t,
|
||||
/// The scalar type for B.
|
||||
int8_t,
|
||||
/// The scalar type for C.
|
||||
ScalarD_,
|
||||
/// The scalar type for D.
|
||||
ScalarD_,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, int8_t, int8_t, int>,
|
||||
/// The number of scalars per LDG for A.
|
||||
4,
|
||||
/// The number of scalars per STS for A.
|
||||
4,
|
||||
/// The number of scalars per LDS for A.
|
||||
16,
|
||||
/// The number of scalars per LDG for B.
|
||||
4,
|
||||
/// The number of scalars per STS for B.
|
||||
4,
|
||||
/// The number of scalars per LDS for B.
|
||||
16,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
1,
|
||||
/// The number of scalars per STS for D.
|
||||
4,
|
||||
/// The number of scalars per LDS for D.
|
||||
1,
|
||||
/// The number of stages in shared memory.
|
||||
2,
|
||||
/// Enable the code path that deals with the residue in epilogue.
|
||||
true> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename OutputTile_, typename AccumulatorsPerThread_>
|
||||
struct IgemmConfig<OutputTile_, int8_t, AccumulatorsPerThread_>
|
||||
: public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
int8_t,
|
||||
/// The scalar type for B.
|
||||
int8_t,
|
||||
/// The scalar type for C.
|
||||
int8_t,
|
||||
/// The scalar type for D.
|
||||
int8_t,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, int8_t, int8_t, int>,
|
||||
/// The number of scalars per LDG for A.
|
||||
4,
|
||||
/// The number of scalars per STS for A.
|
||||
4,
|
||||
/// The number of scalars per LDS for A.
|
||||
16,
|
||||
/// The number of scalars per LDG for B.
|
||||
4,
|
||||
/// The number of scalars per STS for B.
|
||||
4,
|
||||
/// The number of scalars per LDS for B.
|
||||
16,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
4,
|
||||
/// The number of scalars per STS for D.
|
||||
4,
|
||||
/// The number of scalars per LDS for D.
|
||||
4,
|
||||
/// The number of stages in shared memory.
|
||||
2,
|
||||
/// Enable the code path that deals with the residue in epilogue.
|
||||
true> {};
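The scalars-per-access constants in the two `IgemmConfig` variants are easier to read once converted to access widths. The sketch below does that conversion at compile time. `access_bits` is a hypothetical helper, and the mapping to instruction widths such as LDG.32 or LDS.128 is an interpretation based on the comments in this file rather than something the config states.

```cpp
#include <cstdint>

// Width in bits of one memory access that moves `scalars` elements of Scalar.
template <typename Scalar>
constexpr int access_bits(int scalars) {
  return static_cast<int>(sizeof(Scalar)) * 8 * scalars;
}

// A and B: 4 x int8 per LDG and STS, 16 x int8 per LDS.
static_assert(access_bits<int8_t>(4) == 32, "4 x int8 is a 32-bit access");
static_assert(access_bits<int8_t>(16) == 128, "16 x int8 is a 128-bit access");

// C and D in the generic config: 1 scalar per LDG/STG (32 bits when the
// output type is int). In the int8 specialization: 4 x int8 per access.
static_assert(access_bits<int>(1) == 32, "1 x int is a 32-bit access");
static_assert(access_bits<int8_t>(4) == access_bits<int>(1),
              "both output paths move 32 bits per access");
```

Read this way, the `int8_t` specialization appears to raise the per-access scalar count for C and D from 1 to 4 so that global accesses stay 32 bits wide.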
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_, typename Index_>
|
||||
struct IgemmTileTraitsHelperA : public GemmTileTraitsHelperA<kLayout_, GemmConfig_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_, typename Index_>
|
||||
struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_, Index_>
|
||||
: public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
/// The base config.
|
||||
typedef GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> Base;
|
||||
|
||||
/// The number of scalars per STS for A.
|
||||
static int const kScalarsPerStsA = 16;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for A^N.
|
||||
typedef IgemmGlobalTileTraits<
|
||||
GemmOperand::kA,
|
||||
// The layout.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer is int8 const.
|
||||
int8_t const,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kW>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
// The iterator.
|
||||
typedef GemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer is int8.
|
||||
int8_t,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kW * 4>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
kScalarsPerStsA>
|
||||
SharedStoreTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_, typename Index_>
|
||||
struct IgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Index_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef int8_t Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef int8_t MultiplyAddScalar;
|
||||
|
||||
/// The number of scalars per STS for A.
|
||||
static int const kScalarsPerStsA = 16;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for A^T.
|
||||
typedef IgemmGlobalTileTraits<
|
||||
GemmOperand::kA,
|
||||
// The layout.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer is int8 const.
|
||||
int8_t const,
|
||||
// The tile has size MxK in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
// The iterator.
|
||||
typedef IgemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^N.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits<
|
||||
// The pointer is int8.
|
||||
int8_t,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kW * 4>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS.
|
||||
kScalarsPerStsA,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
16>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for A^N.
|
||||
typedef GemmSharedLoadTileATraits<
|
||||
// The pointer is int8 const.
|
||||
int8_t const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
16,
|
||||
// The skew.
|
||||
SharedStoreTileTraits::kSkew>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_, typename Index_>
|
||||
struct IgemmTileTraitsHelperB : public GemmTileTraitsHelperB<kLayout_, GemmConfig_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_, typename Index_>
|
||||
struct IgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Index_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef int8_t Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef int8_t MultiplyAddScalar;
|
||||
|
||||
/// The number of scalars per STS for B.
|
||||
static int const kScalarsPerStsB = 16;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for B^T.
|
||||
typedef IgemmGlobalTileTraits<
|
||||
GemmOperand::kB,
|
||||
// The layout.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer is float const.
|
||||
int8_t const,
|
||||
// The tile has size NxK in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgB>
|
||||
GlobalTileTraits;
|
||||
|
||||
// The iterator.
|
||||
typedef IgemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits<
|
||||
// The pointer is int8.
|
||||
int8_t,
|
||||
// The tile has size KxN in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kH * 4>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS.
|
||||
kScalarsPerStsB,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
16>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for B^N.
|
||||
typedef GemmSharedLoadTileBTraits<
|
||||
// The pointer is float const.
|
||||
int8_t const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
16,
|
||||
// The skew.
|
||||
SharedStoreTileTraits::kSkew>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmConfig_, typename Index_>
struct IgemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_, Index_>
    : public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
  /// The base config.
  typedef GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> Base;

  /// The number of scalars per LDG/STS/LDS for B.
  static int const kScalarsPerStsB = 16;

  /// The traits class to build the iterator to load data from global memory for B^T.
  typedef IgemmGlobalTileTraits<
      GemmOperand::kB,
      // The layout.
      MatrixLayout::kRowMajor,
      // The pointer is int8_t const.
      int8_t const,
      // The tile has size KxN in GEMM's terminology.
      Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kH>,
      // The threads are distributed as warps x 32 (the traits may reorganize).
      Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
      // The number of scalars per LDG (LDG.32 or LDG.128, etc).
      GemmConfig_::kScalarsPerLdgB>
      GlobalTileTraits;

  // The iterator.
  typedef GemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;

  /// The traits class to build the iterator to store data to shared memory for B^N.
  typedef GemmSharedStoreTileAbTraits<
      // The pointer is int8_t.
      int8_t,
      // The tile has size KxN in GEMM's terminology.
      Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kH * 4>,
      // The threads are distributed as warps x 32 (the traits may reorganize).
      typename GlobalTileTraits::Threads,
      // The number of scalars per STS (STS.32 or STS.128, etc).
      kScalarsPerStsB>
      SharedStoreTileTraits;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
struct IgemmTransformerA {};

template <typename Iterator_>
struct IgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
  typedef Copy<typename Iterator_::Fragment> Transformer;
};

template <typename Iterator_>
struct IgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
  typedef IgemmSwizzle<Iterator_> Transformer;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
struct IgemmTransformerB {};

template <typename Iterator_>
struct IgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
  typedef Copy<typename Iterator_::Fragment> Transformer;
};

template <typename Iterator_>
struct IgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
  typedef IgemmSwizzle<Iterator_> Transformer;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// The layout for A.
    MatrixLayout::Kind kLayoutA_,
    /// The layout for B.
    MatrixLayout::Kind kLayoutB_,
    /// The output tile.
    typename OutputTile_,
    /// The output type.
    typename ScalarD_,
    /// The functor to do the math in the epilogue.
    typename EpilogueFunctor_,
    /// The number of accumulators per thread.
    typename AccumulatorsPerThread_ = Shape<32, 8, 8>,
    /// The index.
    typename Index_ = int>
struct IgemmTraitsHelper {
  /// The IGEMM config.
  typedef IgemmConfig<OutputTile_, ScalarD_, AccumulatorsPerThread_> GemmConfig;
  /// The GEMM config for A.
  typedef IgemmTileTraitsHelperA<kLayoutA_, GemmConfig, Index_> GemmTileTraitsHelperA;
  /// The GEMM config for B.
  typedef IgemmTileTraitsHelperB<kLayoutB_, GemmConfig, Index_> GemmTileTraitsHelperB;

  /// The iterator to load A from global memory.
  typedef typename GemmTileTraitsHelperA::GlobalLoadIterator GlobalLoadIteratorA;

  /// The default transformer for A.
  typedef typename IgemmTransformerA<GemmTileTraitsHelperA::kLayout,
                                     GlobalLoadIteratorA>::Transformer GlobalTransformerA;
  /// The iterator to store A to shared memory.
  typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
                            typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
                            IteratorAdvance::kH,
                            MemorySpace::kShared>
      SharedStoreIteratorA;
  /// The stream to load A from global memory to shared memory.
  typedef GlobalLoadStream<GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA>
      GlobalLoadStreamA;

  /// The iterator to load B from global memory.
  typedef typename GemmTileTraitsHelperB::GlobalLoadIterator GlobalLoadIteratorB;

  // The default transformer for B.
  typedef typename IgemmTransformerB<GemmTileTraitsHelperB::kLayout,
                                     GlobalLoadIteratorB>::Transformer GlobalTransformerB;
  /// The iterator to store B to shared memory.
  typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
                            typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
                            IteratorAdvance::kH,
                            MemorySpace::kShared>
      SharedStoreIteratorB;
  /// The stream to load B from global memory to shared memory.
  typedef GlobalLoadStream<GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB>
      GlobalLoadStreamB;

  /// The iterator to load A from shared memory.
  typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
                           typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
                           IteratorAdvance::kH,
                           MemorySpace::kShared>
      SharedLoadIteratorA;
  /// The stream to load A from shared memory.
  typedef SharedLoadStream<SharedLoadIteratorA, Copy<typename SharedLoadIteratorA::Fragment> >
      SharedLoadStreamA;
  /// The iterator to load B from shared memory.
  typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
                           typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
                           IteratorAdvance::kH,
                           MemorySpace::kShared>
      SharedLoadIteratorB;
  /// The stream to load B from shared memory.
  typedef SharedLoadStream<SharedLoadIteratorB, Copy<typename SharedLoadIteratorB::Fragment> >
      SharedLoadStreamB;

  /// The multiply-add functor.
  typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
  /// The object to clear accumulators.
  typedef ClearAccumulators<typename MultiplyAdd::ScalarC> ClearAccumulators;

  /// The epilogue.
  typedef IgemmEpilogue<IgemmEpilogueTraits<GemmConfig, EpilogueFunctor_> > Epilogue;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename ScalarD_>
struct IgemmEpilogueScalar {
  typedef float Scalar;
};

template <>
struct IgemmEpilogueScalar<int> {
  typedef int Scalar;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// The layout for A.
    MatrixLayout::Kind kLayoutA_,
    /// The layout for B.
    MatrixLayout::Kind kLayoutB_,
    /// The output tile.
    typename OutputTile_ = Shape<32, 128, 128>,
    /// The output type.
    typename ScalarD_ = int,
    /// The functor to do the math in the epilogue.
    typename EpilogueFunctor_ = LinearScaling<typename IgemmEpilogueScalar<ScalarD_>::Scalar>,
    /// The number of accumulators per thread.
    typename AccumulatorsPerThread_ = Shape<32, 8, 8>,
    /// The index.
    typename Index_ = int,
    /// The helper class.
    typename Helper_ = IgemmTraitsHelper<kLayoutA_,
                                         kLayoutB_,
                                         OutputTile_,
                                         ScalarD_,
                                         EpilogueFunctor_,
                                         AccumulatorsPerThread_,
                                         Index_> >
struct IgemmTraits : public GemmTraits<
                         // The config.
                         typename Helper_::GemmConfig,
                         // The stream to load A from global memory to shared memory.
                         typename Helper_::GlobalLoadStreamA,
                         // The stream to load B from global memory to shared memory.
                         typename Helper_::GlobalLoadStreamB,
                         // The stream to load A from shared memory.
                         typename Helper_::SharedLoadStreamA,
                         // The stream to load B from shared memory.
                         typename Helper_::SharedLoadStreamB,
                         // The epilogue.
                         typename Helper_::Epilogue,
                         // The block swizzle to reorganize the grid.
                         IdentityBlockSwizzle,
                         // The index.
                         Index_,
                         // The tool used to clear accumulators.
                         typename Helper_::ClearAccumulators> {};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass
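The `IgemmEpilogueScalar` trait in the removed file above selects the scalar type used by the default epilogue functor: `int` output keeps `int` scaling factors, and every other output type falls back to `float`. The following is a minimal standalone sketch of that selection pattern. `EpilogueScalarSketch` is a hypothetical name used only for illustration and is not part of CUTLASS.

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in mirroring the shape of IgemmEpilogueScalar:
// the primary template picks float for any output type...
template <typename ScalarD>
struct EpilogueScalarSketch {
  typedef float Scalar;
};

// ...and the full specialization keeps int when the output type is int.
template <>
struct EpilogueScalarSketch<int> {
  typedef int Scalar;
};

// Compile-time checks of the selection.
static_assert(std::is_same<EpilogueScalarSketch<int>::Scalar, int>::value,
              "int output keeps an int epilogue scalar");
static_assert(std::is_same<EpilogueScalarSketch<std::int8_t>::Scalar, float>::value,
              "non-int output falls back to a float epilogue scalar");
```

A trait of this form lets a default template argument such as `LinearScaling<typename IgemmEpilogueScalar<ScalarD_>::Scalar>` adapt to the output type without the caller spelling out the scalar.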
@ -1,85 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Implements the BLAS linear scaling function alpha*AB + beta*C
*/
#pragma once

#include <cutlass/fragment_multiply_add.h>

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Functor to compute linear combination of fragments
template <typename Scalar_, typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_> >
struct LinearScaling {
  // The scalar.
  typedef Scalar_ Scalar;
  // The adapter.
  typedef FragmentMultiplyAdd_ FragmentMultiplyAdd;

  /// The parameters.
  struct Params {
    /// The alpha/beta scaling params.
    Scalar alpha, beta;

    /// Initialize the parameters.
    template <typename GemmDesc_>
    CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
      alpha = desc.alpha;
      beta = desc.beta;
      return 0;
    }
  };

  /// Ctor.
  CUTLASS_DEVICE LinearScaling(Params const& params) : alpha(params.alpha), beta(params.beta) {}

  /// Evaluate the functor without a source fragment: the accumulators are scaled by alpha.
  template <typename FragmentA_, typename FragmentB_>
  CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_& output) {
    FragmentMultiplyAdd mad;
    mad.multiply(alpha, accum, output);
  }

  /// Evaluate the functor with a source fragment: `old` is scaled by beta, then alpha * accum is added.
  template <typename FragmentA_, typename FragmentB_>
  CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_ const& old, FragmentB_& output) {
    FragmentMultiplyAdd mad;
    FragmentB_ tmp;
    mad.multiply(beta, old, tmp);
    mad.multiply_add(alpha, accum, tmp, output);
  }

  /// The alpha/beta scaling factors.
  Scalar alpha, beta;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass
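The `LinearScaling` functor in the removed file above applies the BLAS-style update `alpha * accum + beta * old` element by element on fragments. As a rough host-side illustration of that arithmetic only, the sketch below uses `std::array` in place of CUTLASS fragments. `LinearScalingSketch` is a hypothetical name, and the assumption that `multiply` and `multiply_add` act element-wise is inferred from how the functor calls them rather than from the `FragmentMultiplyAdd` source.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical host-only analogue of the epilogue functor; std::array stands in
// for a CUTLASS fragment and the loops stand in for FragmentMultiplyAdd.
template <typename Scalar, std::size_t N>
struct LinearScalingSketch {
  Scalar alpha, beta;

  // No source operand: output = alpha * accum.
  void evaluate(std::array<Scalar, N> const& accum, std::array<Scalar, N>& output) const {
    for (std::size_t i = 0; i < N; ++i) {
      output[i] = alpha * accum[i];
    }
  }

  // With source operand: output = alpha * accum + beta * old.
  void evaluate(std::array<Scalar, N> const& accum,
                std::array<Scalar, N> const& old,
                std::array<Scalar, N>& output) const {
    for (std::size_t i = 0; i < N; ++i) {
      Scalar tmp = beta * old[i];          // mirrors mad.multiply(beta, old, tmp)
      output[i] = alpha * accum[i] + tmp;  // mirrors mad.multiply_add(alpha, accum, tmp, output)
    }
  }
};

// Small self-check with values that are exact in binary floating point.
inline void check_linear_scaling_sketch() {
  LinearScalingSketch<float, 2> f = {2.0f, 0.5f};
  std::array<float, 2> accum = {{1.0f, 3.0f}};
  std::array<float, 2> old = {{4.0f, 8.0f}};
  std::array<float, 2> out = {{0.0f, 0.0f}};
  f.evaluate(accum, old, out);
  assert(out[0] == 4.0f && out[1] == 10.0f);
}
```

The two overloads correspond to the `beta == 0` case, where the source matrix need not be read, and the general case.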
@ -1,127 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines structural properties of single-precision GEMM.
*/
#pragma once

#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/gemm_epilogue.h>
#include <cutlass/gemm/gemm_epilogue_traits.h>
#include <cutlass/gemm/gemm_global_tile.h>
#include <cutlass/gemm/gemm_shared_tile.h>
#include <cutlass/gemm/gemm_traits.h>
#include <cutlass/gemm/thread_multiply_add.h>

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// The tile size for the GEMM KxNxM.
    typename OutputTile_,
    /// The number of accumulators per thread.
    typename AccumulatorsPerThread_,
    /// The number of scalars per LDG for A.
    int kScalarsPerLdgA_ = 1,
    /// The number of scalars per LDG for B.
    int kScalarsPerLdgB_ = 1>
struct SgemmConfig
    : public GemmConfig<
          /// The scalar type for A.
          float,
          /// The scalar type for B.
          float,
          /// The scalar type for C.
          float,
          /// The scalar type for D.
          float,
          /// The tile size for the GEMM KxNxM.
          OutputTile_,
          /// The functor to do the math in the main loop.
          ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, float, float, float>,
          /// The number of scalars per LDG for A.
          kScalarsPerLdgA_,
          /// The number of scalars per STS for A.
          kScalarsPerLdgA_,
          /// The number of scalars per LDS for A.
          4,
          /// The number of scalars per LDG for B.
          kScalarsPerLdgB_,
          /// The number of scalars per STS for B.
          kScalarsPerLdgB_,
          /// The number of scalars per LDS for B.
          4,
          /// The number of scalars per LDG for C and STG for D.
          1,
          /// The number of scalars per STS for D.
          4,
          /// The number of scalars per LDS for D.
          1,
          /// The number of stages in shared memory.
          2> {};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// The layout for A.
    MatrixLayout::Kind kLayoutA_,
    /// The layout for B.
    MatrixLayout::Kind kLayoutB_,
    /// The output tile.
    typename OutputTile_ = Shape<8, 128, 128>,
    /// The functor to use in the epilogue.
    typename EpilogueFunctor_ = LinearScaling<float>,
    /// The number of accumulators per thread.
    typename AccumulatorsPerThread_ = Shape<8, 8, 8>,
    /// The number of floats loaded in one LDG for A.
    int kScalarsPerLdgA_ = 1,
    /// The number of floats loaded in one LDG for B.
    int kScalarsPerLdgB_ = 1,
    /// The index.
    typename Index_ = int,
    /// The SGEMM config.
    typename GemmConfig_ =
        SgemmConfig<OutputTile_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_>,
    /// The traits class for the epilogue.
    typename GemmEpilogueTraits_ =
        SimplifiedGemmEpilogueTraits<GemmConfig_, EpilogueFunctor_, Index_> >
struct SgemmTraits : public SimplifiedGemmTraits<
                         // The layout for A.
                         kLayoutA_,
                         // The layout for B.
                         kLayoutB_,
                         // The config.
                         GemmConfig_,
                         // The epilogue.
                         GemmEpilogue<GemmEpilogueTraits_>,
                         // The index.
                         Index_> {};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass
@ -1,84 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template implementing matrix multiply-add operations on fragments.
*/
#pragma once

#include <cutlass/fragment.h>

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Template performing matrix multiply-add operation within a thread
template <typename AccumulatorsPerThread_,
          typename ThreadsPerWarp_,
          typename ScalarA_,
          typename ScalarB_,
          typename ScalarC_>
struct ThreadMultiplyAdd {
  /// The shape of the instruction.
  typedef Shape<1, 1, 1, 1> InstructionShape;
  /// The number of accumulators per thread.
  typedef AccumulatorsPerThread_ AccumulatorsPerThread;
  /// The number of threads per warp.
  typedef ThreadsPerWarp_ ThreadsPerWarp;
  /// The number of accumulators per warp.
  typedef typename ShapeMul<AccumulatorsPerThread, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
  /// The type for A.
  typedef ScalarA_ ScalarA;
  /// The fragment for A.
  typedef Fragment<ScalarA, AccumulatorsPerThread::kW> FragmentA;
  /// The type for B.
  typedef ScalarB_ ScalarB;
  /// The fragment for B.
  typedef Fragment<ScalarB, AccumulatorsPerThread::kH> FragmentB;
  /// The type for C and D.
  typedef ScalarC_ ScalarC;
  /// The accumulators.
  typedef Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16> Accumulators;

  /// Ctor.
  CUTLASS_DEVICE ThreadMultiplyAdd() {}

  /// Multiply-add: d = a*b + c, computed as an outer product of the A and B fragments.
  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
                                   FragmentB const& b,
                                   Accumulators const& c,
                                   Accumulators& d) {
    for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
      for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
        d[j * AccumulatorsPerThread::kW + i] = a[i] * b[j] + c[j * AccumulatorsPerThread::kW + i];
      }
    }
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace gemm
}  // namespace cutlass
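The `multiply_add` loop in the removed file above is a per-thread outer product: each element of the B fragment multiplies every element of the A fragment, and the result is accumulated into a flat array indexed as `j * kW + i`. The sketch below reproduces only that indexing on the host. `multiply_add_sketch` is a hypothetical helper, and `std::array` stands in for the CUTLASS `Fragment` type.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical host-only version of the per-thread outer-product update:
// d[j * kW + i] = a[i] * b[j] + c[j * kW + i].
template <std::size_t kW, std::size_t kH>
void multiply_add_sketch(std::array<float, kW> const& a,
                         std::array<float, kH> const& b,
                         std::array<float, kH * kW> const& c,
                         std::array<float, kH * kW>& d) {
  for (std::size_t j = 0; j < kH; ++j) {
    for (std::size_t i = 0; i < kW; ++i) {
      d[j * kW + i] = a[i] * b[j] + c[j * kW + i];
    }
  }
}

// Small self-check: a has kW = 2 elements, b has kH = 3, so there are 6 accumulators.
inline void check_multiply_add_sketch() {
  std::array<float, 2> a = {{1.0f, 2.0f}};
  std::array<float, 3> b = {{3.0f, 4.0f, 5.0f}};
  std::array<float, 6> c = {{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}};
  std::array<float, 6> d = {{0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}};
  multiply_add_sketch(a, b, c, d);
  assert(d[0] == 4.0f && d[1] == 7.0f);   // row j = 0: 1*3+1, 2*3+1
  assert(d[4] == 6.0f && d[5] == 11.0f);  // row j = 2: 1*5+1, 2*5+1
}
```

With this layout, the `kW` accumulators that share one `b[j]` value are contiguous in `d`.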
@ -1,161 +0,0 @@
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines structural properties of WMMA GEMM's epilogue phase.
*/
#pragma once

#include <cutlass/wmma_matrix.h>
#ifdef CUTLASS_USE_WMMA_API

#include <cutlass/convert.h>
#include <cutlass/coord.h>
#include <cutlass/gemm/gemm_global_stream.h>
#include <cutlass/gemm/gemm_shared_stream.h>
#include <cutlass/gemm/linear_scaling.h>
#include <cutlass/gemm/wmma_gemm_global_tile.h>
#include <cutlass/gemm/wmma_gemm_shared_tile.h>
#include <cutlass/reshape_tile.h>
#include <cutlass/tile_iterator.h>

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename GemmConfig_, typename EpilogueFunctor_, typename Index_ = int>
struct WmmaGemmEpilogueTraitsHelper {
  /// The scalar.
  typedef typename EpilogueFunctor_::Scalar Scalar;
  /// The output tile.
  typedef typename GemmConfig_::OutputTile OutputTile;

  /// The number of WMMAs in the H dimension.
  static int const kWmmasPerH =
      GemmConfig_::AccumulatorsPerWarp::kH / GemmConfig_::InstructionShape::kH;
  /// The number of iterations in the epilogue. That's the number of "horizontal" WMMAs.
  typedef Shape<1, 1, kWmmasPerH> Iterations;
  // The iteration strides in the H/W dimension.
  typedef Shape<0, 0, 0> Delta;
  /// The functor to do the math in the epilogue.
  typedef EpilogueFunctor_ Functor;

  /// The traits class to build the iterator to store to shared memory for D.
  typedef WmmaGemmSharedStoreTileDTraits<
      // The output layout.
      MatrixLayout::kColumnMajor,
      // The pointer type is the functor's scalar.
      typename Functor::Scalar,
      // The output tile size.
      typename GemmConfig_::OutputTile,
      // The number of warps.
      typename GemmConfig_::Warps,
      // The shape of the instruction.
      typename GemmConfig_::InstructionShape>
      SharedStoreTileTraits;

  typedef WmmaMatrix<GemmOperand::kC,
                     MatrixLayout::kColumnMajor,
                     Scalar,
                     typename GemmConfig_::InstructionShape>
      WmmaMatrix;

  /// The iterator to store D to shared memory.
  typedef TileStoreIterator<SharedStoreTileTraits,
                            typename SharedStoreTileTraits::Scalar,
                            IteratorAdvance::kH,
                            MemorySpace::kShared,
                            Index_,
                            WmmaMatrix,
                            IteratorFragment::kWmmaMatrix>
      SharedStoreIteratorD;

  /// The shared store transformer for D.
  typedef Copy<typename SharedStoreIteratorD::Fragment> SharedStoreTransformerD;

  /// The traits class to build the iterator to load from shared memory for D.
  typedef WmmaGemmSharedLoadTileDTraits<
      // The pointer.
      typename Functor::Scalar,
      // The tile size.
      typename SharedStoreIteratorD::Tile,
      // The number of threads.
      Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
      // The number of scalars per LDS.
      GemmConfig_::kScalarsPerLdsD>
      SharedLoadTileTraits;

  /// The iterator to load D from shared memory.
  typedef TileLoadIterator<SharedLoadTileTraits,
                           typename SharedLoadTileTraits::Scalar,
                           IteratorAdvance::kH,
                           MemorySpace::kShared>
      SharedLoadIteratorD;

/// The traits class to build the iterator to load data from global memory for C^N.
|
||||
typedef WmmaGemmGlobalIteratorCdTraits<
|
||||
// The pointer is float const.
|
||||
typename GemmConfig_::ScalarC const,
|
||||
// The tile has size (N / Iterations)xM in GEMM's terminology.
|
||||
Shape<1,
|
||||
GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
|
||||
GemmConfig_::OutputTile::kW>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgC>
|
||||
GlobalLoadTileTraits;
|
||||
|
||||
/// The iterator to load C.
|
||||
typedef WmmaGemmGlobalIteratorCd<GlobalLoadTileTraits, Index_> GlobalLoadIteratorC;
|
||||
/// The transformer for C.
|
||||
typedef Copy<typename GlobalLoadIteratorC::Fragment> GlobalTransformerC;
|
||||
|
||||
/// The traits class to build the iterator to store data to global memory for D^N.
|
||||
typedef WmmaGemmGlobalIteratorCdTraits<
|
||||
// The pointer is float.
|
||||
typename GemmConfig_::ScalarD,
|
||||
// The tile has size (N / Iterations)xM in GEMM's terminology.
|
||||
Shape<1,
|
||||
GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
|
||||
GemmConfig_::OutputTile::kW>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per STG (STG.32 or STG.128, etc).
|
||||
GemmConfig_::kScalarsPerStgD>
|
||||
GlobalStoreTileTraits;
|
||||
|
||||
/// The iterator to store D.
|
||||
typedef WmmaGemmGlobalIteratorCd<GlobalStoreTileTraits, Index_> GlobalStoreIteratorD;
|
||||
/// The transformer for D.
|
||||
typedef Copy<typename GlobalStoreIteratorD::Fragment> GlobalTransformerD;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
#endif // defined CUTLASS_USE_WMMA_API
|
||||
@ -1,211 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines tile iterator traits for loading a thread-block-level tile from global memory.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, typename Tile_, typename Threads_, int kAccessSize_>
|
||||
struct WmmaGemmGlobalIteratorCdTraits : public GemmGlobalTileTraits<GemmOperand::kC,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Scalar_,
|
||||
Tile_,
|
||||
Threads_,
|
||||
kAccessSize_> {
|
||||
/// The base class.
|
||||
typedef GemmGlobalTileTraits<GemmOperand::kC,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Scalar_,
|
||||
Tile_,
|
||||
Threads_,
|
||||
kAccessSize_>
|
||||
Base;
|
||||
|
||||
/// Override the strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Base::Delta::kW, Base::Delta::kC> Delta;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int thread_offset_h = threadIdx.x / Base::Threads::kW;
|
||||
int thread_offset_w = threadIdx.x % Base::Threads::kW * Base::ThreadsDelta::kW;
|
||||
|
||||
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TileTraits_, typename Index_ = int>
|
||||
struct WmmaGemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
typename TileTraits_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kGlobal,
|
||||
Index_> {
|
||||
/// This class.
|
||||
typedef WmmaGemmGlobalIteratorCd<TileTraits_, Index_> This_;
|
||||
/// The traits.
|
||||
typedef TileTraits_ Traits;
|
||||
/// The base class.
|
||||
typedef TileIteratorBase<Traits,
|
||||
typename TileTraits_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kGlobal,
|
||||
Index_>
|
||||
Base;
|
||||
/// Override the strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Base::Delta::kW, Base::Delta::kC> ImmediateOffsetStrides;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
|
||||
|
||||
/// The scalar.
|
||||
typedef typename TileTraits_::Scalar Scalar;
|
||||
/// The pointer.
|
||||
typedef typename TileTraits_::Pointer Pointer;
|
||||
/// The threads.
|
||||
typedef typename TileTraits_::Threads Threads;
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
/// The thread offset functor.
|
||||
typedef typename TileTraits_::ThreadOffset ThreadOffset;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
/// The pointer.
|
||||
Pointer pointer;
|
||||
/// The stride in the H dimension to setup the thread in the block.
|
||||
Index stride_h;
|
||||
/// The strides to increment the pointer.
|
||||
Index inc_h, inc_advance;
|
||||
/// The column offset to compute the predicate for the columns.
|
||||
Index predicate_offset;
|
||||
/// The strides to increment the predicate offset.
|
||||
Index predicate_inc_h, predicate_inc_advance;
|
||||
|
||||
/// Setup the params.
|
||||
CUTLASS_HOST_DEVICE int initialize(
|
||||
Pointer pointer, Index ld, Index n, Index epilogue_stride_w, Index epilogue_delta_w) {
|
||||
// The pointer.
|
||||
this->pointer = pointer;
|
||||
// Setup the base stride. One "group of threads" per column.
|
||||
stride_h = ld;
|
||||
// Each thread outputs 1 column per iteration.
|
||||
inc_h = ld * TileTraits_::Threads::kH;
|
||||
inc_advance = inc_h + epilogue_stride_w;
|
||||
|
||||
predicate_offset = n;
|
||||
predicate_inc_h = TileTraits_::Threads::kH;
|
||||
predicate_inc_advance = predicate_inc_h + epilogue_delta_w;
|
||||
|
||||
// It worked.
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
Params params;
|
||||
|
||||
Coord<4> thread_offset;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE WmmaGemmGlobalIteratorCd() {}
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE WmmaGemmGlobalIteratorCd(Params const& params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block,
|
||||
int const pointer_offset = 0,
|
||||
int const pred_offset = 0,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
|
||||
: params(params) {
|
||||
thread_offset = thread_offset_func();
|
||||
// Each warp works on a different column of the tile.
|
||||
int const h = thread_offset[1] + block[1];
|
||||
// Each lane writes a different element.
|
||||
int const w = thread_offset[2] + block[2];
|
||||
// Setup the pointer.
|
||||
this->params.pointer += ((h * params.stride_h + w) + pointer_offset);
|
||||
|
||||
// Prepare the vector of predicates.
|
||||
for (int i = 0; i < Base::Iterations::kW; ++i) {
|
||||
predicates.set(i, w + i * Base::Delta::kW < bounds[2]);
|
||||
}
|
||||
this->params.predicate_offset -= (h + pred_offset);
|
||||
}
|
||||
|
||||
/// The load accessor.
|
||||
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Load<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::load(value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Increment the pointer in the C dimension.
|
||||
CUTLASS_DEVICE void inc_c() {}
|
||||
/// Increment the pointer in the W dimension.
|
||||
CUTLASS_DEVICE void inc_w() {}
|
||||
/// Increment the pointer in the H dimension.
|
||||
CUTLASS_DEVICE void inc_h() {
|
||||
params.pointer += params.inc_h;
|
||||
params.predicate_offset -= params.predicate_inc_h;
|
||||
}
|
||||
/// Increment the pointer in the D dimension.
|
||||
CUTLASS_DEVICE void inc_d() {}
|
||||
/// Increment the pointer to move to the next iteration.
|
||||
CUTLASS_DEVICE void inc_advance() {
|
||||
params.pointer += params.inc_advance;
|
||||
params.predicate_offset -= params.predicate_inc_advance;
|
||||
}
|
||||
|
||||
/// The store accessor.
|
||||
CUTLASS_DEVICE void set(typename Base::AccessType const& value, int d, int h, int w, int c) {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, 0);
|
||||
Store<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::store(
|
||||
value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Test the predicate.
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
|
||||
return predicates.at(w) && params.predicate_offset > 0;
|
||||
}
|
||||
|
||||
/// The predicates for the row.
|
||||
cutlass::PredicateVector<Base::Iterations::kW> predicates;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
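The thread-offset and predicate logic of `WmmaGemmGlobalIteratorCd` in the file above can be restated on the host as a minimal sketch. This is plain C++, not the removed CUTLASS code: `thread_offset_hw`, `is_valid_access`, and all of their parameters are hypothetical names standing in for `Base::Threads::kW`, `Base::ThreadsDelta::kW`, `bounds[2]`, and `params.predicate_offset`.

```cpp
#include <cassert>

// Hypothetical host-side restatement of ThreadOffset::operator()().
struct OffsetHW {
  int h;  // row of "thread groups" (threadIdx.x / Threads::kW in the source)
  int w;  // element offset within the row
};

// threads_w stands in for Base::Threads::kW,
// threads_delta_w stands in for Base::ThreadsDelta::kW.
inline OffsetHW thread_offset_hw(int thread_id, int threads_w, int threads_delta_w) {
  assert(threads_w > 0);
  OffsetHW offset;
  offset.h = thread_id / threads_w;
  offset.w = (thread_id % threads_w) * threads_delta_w;
  return offset;
}

// Hypothetical restatement of valid(): the access is allowed when the
// W coordinate is inside the bound AND the remaining-row counter
// (n minus the rows already consumed) is still positive.
inline bool is_valid_access(int w, int bound_w, int h, int n) {
  int const predicate_offset = n - h;  // mirrors params.predicate_offset -= h
  return (w < bound_w) && (predicate_offset > 0);
}
```

With 32 threads per row and 4 scalars per access, thread 37 would land at `h = 1`, `w = 20` under these assumptions.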
@ -1,108 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements a warp-level matrix multiply-accumulate operation using the CUDA WMMA API.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/wmma_matrix.h>
|
||||
#ifdef CUTLASS_USE_WMMA_API
|
||||
#include <cutlass/fragment.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <MatrixLayout::Kind kLayoutA_,
|
||||
typename ScalarA_,
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
typename ScalarB_,
|
||||
MatrixLayout::Kind kLayoutC_,
|
||||
typename ScalarC_,
|
||||
typename AccumulatorsPerWarp_,
|
||||
typename InstructionShape_>
|
||||
struct WmmaGemmMultiplyAdd {
|
||||
/// The shape of the instruction.
|
||||
typedef InstructionShape_ InstructionShape;
|
||||
/// The number of threads per warp. That's a dummy configuration.
|
||||
typedef Shape<1, InstructionShape_::kH, InstructionShape_::kW> ThreadsPerWarp;
|
||||
/// The dimensions.
|
||||
typedef AccumulatorsPerWarp_ AccumulatorsPerWarp;
|
||||
/// The type for A.
|
||||
typedef ScalarA_ ScalarA;
|
||||
/// The type for B.
|
||||
typedef ScalarB_ ScalarB;
|
||||
/// The type for C and D.
|
||||
typedef ScalarC_ ScalarC;
|
||||
/// The number of iterations.
|
||||
typedef typename ShapeDiv<AccumulatorsPerWarp, InstructionShape>::Shape Iterations;
|
||||
|
||||
/// The element for A.
|
||||
typedef WmmaMatrix<GemmOperand::kA, kLayoutA_, ScalarA, InstructionShape> ElementA;
|
||||
/// The fragment for A.
|
||||
typedef Fragment<ElementA, Iterations::kW> FragmentA;
|
||||
|
||||
/// The element for B.
|
||||
typedef WmmaMatrix<GemmOperand::kB, kLayoutB_, ScalarB, InstructionShape> ElementB;
|
||||
/// The fragment for B.
|
||||
typedef Fragment<ElementB, Iterations::kH> FragmentB;
|
||||
|
||||
/// The element for C.
|
||||
typedef WmmaMatrix<GemmOperand::kC, kLayoutC_, ScalarC, InstructionShape> ElementC;
|
||||
/// The fragment for C.
|
||||
typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}
|
||||
|
||||
/// Multiply-add: d = a*b + c.
|
||||
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
|
||||
FragmentB const& b,
|
||||
Accumulators const& c,
|
||||
Accumulators& d) {
|
||||
for (int j = 0; j < Iterations::kH; ++j) {
|
||||
for (int i = 0; i < Iterations::kW; ++i) {
|
||||
// The input elements.
|
||||
ElementA const& elt_a = a[i];
|
||||
ElementB const& elt_b = b[j];
|
||||
ElementC const& elt_c = c[j * Iterations::kW + i];
|
||||
|
||||
// The output element.
|
||||
ElementC& elt_d = d[j * Iterations::kW + i];
|
||||
|
||||
// The wmma instruction.
|
||||
nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
#endif // defined CUTLASS_USE_WMMA_API
|
||||
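The loop structure of `WmmaGemmMultiplyAdd::multiply_add` above (outer loop over `Iterations::kH`, inner loop over `Iterations::kW`, accumulator index `j * kW + i`) can be illustrated with a host-only sketch. Assumption: each WMMA fragment is replaced by a single `float`, so `a[i] * b[j] + c[...]` stands in for `nvcuda::wmma::mma_sync`; `multiply_add_sketch` is a hypothetical name, not a CUTLASS function.

```cpp
#include <array>
#include <cstddef>

// Host-only sketch: scalars stand in for WMMA fragments.
// kH and kW play the role of Iterations::kH and Iterations::kW.
template <std::size_t kH, std::size_t kW>
void multiply_add_sketch(std::array<float, kW> const& a,
                         std::array<float, kH> const& b,
                         std::array<float, kH * kW> const& c,
                         std::array<float, kH * kW>& d) {
  for (std::size_t j = 0; j < kH; ++j) {
    for (std::size_t i = 0; i < kW; ++i) {
      // Same indexing as the source: one accumulator per (j, i) pair.
      std::size_t const idx = j * kW + i;
      // Stand-in for mma_sync(elt_d, elt_a, elt_b, elt_c): d = a * b + c.
      d[idx] = a[i] * b[j] + c[idx];
    }
  }
}
```

The sketch only shows how the A fragments are indexed by `i`, the B fragments by `j`, and the accumulators by the flattened pair; it says nothing about the fragment layouts the WMMA API actually uses.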
@ -1,240 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines iterator traits for efficiently loading and storing fragments to and from shared
|
||||
memory, specialized for WMMA GEMM.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/wmma_matrix.h>
|
||||
#ifdef CUTLASS_USE_WMMA_API
|
||||
|
||||
#include <cutlass/gemm/gemm_operand.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
template <class>
|
||||
struct Debug {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename Tile_,
|
||||
typename Warps_,
|
||||
int kWarpStride_,
|
||||
typename Iterations_,
|
||||
typename Delta_,
|
||||
typename WmmaShape_>
|
||||
struct WmmaGemmSharedLoadTileATraits {
|
||||
/// The operand.
|
||||
static GemmOperand::Kind const kOperand = GemmOperand::kA;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = kLayout_;
|
||||
/// The scalar.
|
||||
typedef Scalar_ Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar const* Pointer;
|
||||
/// The access size
|
||||
static int const kAccessSize = 1;
|
||||
/// The tile with skew.
|
||||
typedef Tile_ Tile;
|
||||
/// The number of warps.
|
||||
typedef Warps_ Warps;
|
||||
/// The stride between warps.
|
||||
static int const kWarpStride = kWarpStride_;
|
||||
/// The number of iterations.
|
||||
typedef Iterations_ Iterations;
|
||||
/// The strides between iterations.
|
||||
typedef Delta_ Delta;
|
||||
/// The strides used to compute immediate offsets.
|
||||
typedef Delta_ ImmediateOffsetStrides;
|
||||
/// The shape of the WMMA instruction.
|
||||
typedef WmmaShape_ WmmaShape;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
/// ThreadOffset
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
// The warp id.
|
||||
int const warp = threadIdx.x / kWarpSize;
|
||||
// The offset.
|
||||
int const offset = warp % Warps::kW * kWarpStride;
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename Tile_,
|
||||
typename Warps_,
|
||||
int kWarpStride_,
|
||||
typename Iterations_,
|
||||
typename Delta_,
|
||||
typename WmmaShape_>
|
||||
struct WmmaGemmSharedLoadTileBTraits {
|
||||
/// The operand.
|
||||
static GemmOperand::Kind const kOperand = GemmOperand::kB;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = kLayout_;
|
||||
/// The scalar.
|
||||
typedef Scalar_ Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar const* Pointer;
|
||||
/// The access size
|
||||
static int const kAccessSize = 1;
|
||||
/// The tile with skew.
|
||||
typedef Tile_ Tile;
|
||||
/// The number of warps.
|
||||
typedef Warps_ Warps;
|
||||
/// The stride between warps.
|
||||
static int const kWarpStride = kWarpStride_;
|
||||
/// The number of iterations.
|
||||
typedef Iterations_ Iterations;
|
||||
/// The strides between iterations.
|
||||
typedef Delta_ Delta;
|
||||
/// The strides used to compute immediate offsets.
|
||||
typedef Delta_ ImmediateOffsetStrides;
|
||||
/// The shape of the WMMA instruction.
|
||||
typedef WmmaShape_ WmmaShape;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
/// ThreadOffset
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
// The warp id.
|
||||
int const warp = threadIdx.x / kWarpSize;
|
||||
// The offset.
|
||||
int const offset = warp / Warps::kW * kWarpStride;
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename OutputTile_,
|
||||
typename Warps_,
|
||||
typename WmmaShape_,
|
||||
int kSkew_ = 0>
|
||||
struct WmmaGemmSharedStoreTileDTraits {
|
||||
/// The operand.
|
||||
static GemmOperand::Kind const kOperand = GemmOperand::kC;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = kLayout_;
|
||||
/// The scalar.
|
||||
typedef Scalar_ Scalar;
|
||||
/// The access size
|
||||
static int const kAccessSize = 1;
|
||||
/// The pointer.
|
||||
typedef Scalar* Pointer;
|
||||
/// The number of warps.
|
||||
typedef Warps_ Warps;
|
||||
/// The shape of the WMMA instruction.
|
||||
typedef WmmaShape_ WmmaShape;
|
||||
/// The skew.
|
||||
static int const kSkew = kSkew_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
/// The tile with skew.
|
||||
typedef Shape<1, Warps_::kH * WmmaShape_::kH, OutputTile_::kW + kSkew_> Tile;
|
||||
/// The number of iterations needed to store the tile.
|
||||
typedef Shape<1, 1, OutputTile_::kW / Warps::kW / WmmaShape_::kW> Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Warps::kW * WmmaShape_::kW, 0> Delta;
|
||||
/// The strides used to compute immediate offsets.
|
||||
typedef Shape<0, 0, Warps::kW * WmmaShape_::kW, 0> ImmediateOffsetStrides;
|
||||
|
||||
/// ThreadOffset
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
// The warp id.
|
||||
int const warp = threadIdx.x / kWarpSize;
|
||||
// The starting column.
|
||||
int const h = warp / Warps::kW * WmmaShape::kH;
|
||||
// The starting offset in the W dimension.
|
||||
int const w = warp % Warps::kW * WmmaShape::kW;
|
||||
// The offset.
|
||||
int const offset = h * Tile::kW + w;
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerLds_>
|
||||
struct WmmaGemmSharedLoadTileDTraits {
|
||||
/// The scalar.
|
||||
typedef Scalar_ Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar const* Pointer;
|
||||
/// The access size
|
||||
static int const kAccessSize = kScalarsPerLds_;
|
||||
/// The tile.
|
||||
typedef typename ReshapeTile<Tile_, kScalarsPerLds_>::Tile Tile;
|
||||
/// The threads.
|
||||
typedef typename ReshapeThreads<Tile, Threads_>::Threads Threads;
|
||||
/// The thread strides.
|
||||
typedef Shape<1, Tile::kW * Tile::kC, Tile::kC> ThreadsStrides;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_> Delta;
|
||||
/// The strides used to compute immediate offsets.
|
||||
typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_>
|
||||
ImmediateOffsetStrides;
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<1, Tile::kH / Threads::kH, Tile::kW / Threads::kW, Tile::kC / kScalarsPerLds_>
|
||||
Iterations;
|
||||
|
||||
/// ThreadOffset
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
// The offset.
|
||||
int const offset = ComputeThreadOffsetFromStrides<Threads, ThreadsStrides>::get();
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
#endif // defined CUTLASS_USE_WMMA_API
|
||||
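The three warp-based `ThreadOffset` functors in the file above (for the A load, the B load, and the D store) differ only in how the warp index is split. A minimal host-side sketch of that arithmetic follows. Assumptions: `warp_offsets` and every parameter name are hypothetical; they stand in for `kWarpSize`, `Warps::kW`, `kWarpStride`, `WmmaShape::kH`, `WmmaShape::kW`, and `Tile::kW`, and the function is not part of CUTLASS.

```cpp
#include <cassert>

// Hypothetical container for the W-dimension offsets the functors return.
struct WarpOffsets {
  int a;  // WmmaGemmSharedLoadTileATraits: (warp % Warps::kW) * kWarpStride
  int b;  // WmmaGemmSharedLoadTileBTraits: (warp / Warps::kW) * kWarpStride
  int d;  // WmmaGemmSharedStoreTileDTraits: h * Tile::kW + w
};

inline WarpOffsets warp_offsets(int thread_id, int warp_size, int warps_w,
                                int warp_stride_a, int warp_stride_b,
                                int wmma_h, int wmma_w, int tile_w) {
  assert(warp_size > 0 && warps_w > 0);
  int const warp = thread_id / warp_size;

  WarpOffsets offsets;
  // A is indexed by the warp's position along W.
  offsets.a = (warp % warps_w) * warp_stride_a;
  // B is indexed by the warp's position along H.
  offsets.b = (warp / warps_w) * warp_stride_b;
  // D combines both positions into one linear offset in the skewed tile.
  int const h = (warp / warps_w) * wmma_h;
  int const w = (warp % warps_w) * wmma_w;
  offsets.d = h * tile_w + w;
  return offsets;
}
```

For example, thread 100 with 32-thread warps is warp 3; with two warps along W it sits at position 1 in both dimensions, which is the case exercised when checking the sketch.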
@ -1,574 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural properties of GEMM targeting WMMA API in CUDA.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/wmma_matrix.h>
|
||||
#ifdef CUTLASS_USE_WMMA_API
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/gemm_epilogue.h>
|
||||
#include <cutlass/gemm/gemm_epilogue_traits.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/gemm/gemm_shared_tile.h>
|
||||
#include <cutlass/gemm/gemm_traits.h>
|
||||
#include <cutlass/gemm/wmma_gemm_epilogue_traits.h>
|
||||
#include <cutlass/gemm/wmma_gemm_global_tile.h>
|
||||
#include <cutlass/gemm/wmma_gemm_multiply_add.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// The output type.
|
||||
typename ScalarC_,
|
||||
/// The accumulator type.
|
||||
typename Accumulator_,
|
||||
/// The number of accumulators per warp.
|
||||
typename AccumulatorsPerWarp_,
|
||||
/// The shape of the WMMA instruction.
|
||||
typename InstructionShape_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_>
|
||||
struct WmmaGemmConfig : public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
half,
|
||||
/// The scalar type for B.
|
||||
half,
|
||||
/// The scalar type for C.
|
||||
ScalarC_,
|
||||
/// The scalar type for D.
|
||||
ScalarC_,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
WmmaGemmMultiplyAdd<kLayoutA_,
|
||||
half,
|
||||
kLayoutB_,
|
||||
half,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Accumulator_,
|
||||
AccumulatorsPerWarp_,
|
||||
InstructionShape_>,
|
||||
/// The number of scalars per LDG for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDS for A.
|
||||
8,
|
||||
/// The number of scalars per LDG for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
8,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
16 / sizeof(ScalarC_),
|
||||
/// The number of scalars per STS for D.
|
||||
16 / sizeof(ScalarC_),
|
||||
/// The number of scalars per LDS for D.
|
||||
16 / sizeof(ScalarC_),
|
||||
/// The number of stages in shared memory.
|
||||
1> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperA {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_>
|
||||
: public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
/// The base config.
|
||||
typedef GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> Base;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkew = 16 / sizeof(typename Base::MultiplyAddScalar);
|
||||
/// The shared tile size.
|
||||
typedef Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kD,
|
||||
GemmConfig_::OutputTile::kW + kSkew>
|
||||
Tile;
|
||||
|
||||
/// WMMA matrix
|
||||
typedef WmmaMatrix<GemmOperand::kA,
|
||||
MatrixLayout::kColumnMajor,
|
||||
typename Base::MultiplyAddScalar,
|
||||
typename GemmConfig_::InstructionShape>
|
||||
WmmaMatrix;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer.
|
||||
typename Base::MultiplyAddScalar,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Tile,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename Base::GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsA>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The number of elements loaded in one LDG.
|
||||
static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
|
||||
/// The number of scalars loaded per iteration.
|
||||
static int const kScalarsPerIteration = Tile::kW * GemmConfig_::InstructionShape::kD;
|
||||
/// The traits class to build the iterator to load from shared memory for A.
|
||||
typedef WmmaGemmSharedLoadTileATraits<
|
||||
// The layout of the matrix.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer.
|
||||
typename Base::MultiplyAddScalar,
|
||||
// The output tile size.
|
||||
Tile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The strides between warps.
|
||||
GemmConfig_::InstructionShape::kW,
|
||||
// The number of iterations to load the data.
|
||||
Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
|
||||
// The stride between iterations.
|
||||
Shape<kScalarsPerIteration, 0, kScalarsPerW, 0>,
|
||||
// The shape of the instruction.
|
||||
typename GemmConfig_::InstructionShape>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarA Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
|
||||
|
||||
/// WMMA matrix
|
||||
typedef WmmaMatrix<GemmOperand::kA,
|
||||
MatrixLayout::kRowMajor,
|
||||
MultiplyAddScalar,
|
||||
typename GemmConfig_::InstructionShape>
|
||||
WmmaMatrix;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for A^T.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's A.
|
||||
GemmOperand::kA,
|
||||
// A is row-major.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer is Scalar const.
|
||||
Scalar const,
|
||||
// The tile has size MxK in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkew = 16 / sizeof(MultiplyAddScalar);
|
||||
/// The tile.
|
||||
typedef Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kW,
|
||||
GemmConfig_::OutputTile::kD + kSkew>
|
||||
Tile;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^T.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size MxK (plus skew) in GEMM's terminology.
|
||||
Tile,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsA>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The number of scalars along M covered by all warps in one iteration.
|
||||
static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
|
||||
/// The traits class to build the iterator to load from shared memory for A.
|
||||
typedef WmmaGemmSharedLoadTileATraits<
|
||||
// The layout of the matrix.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile in shared memory.
|
||||
Tile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The strides between warps.
|
||||
GemmConfig_::InstructionShape::kW * Tile::kW,
|
||||
// The number of iterations to load the data.
|
||||
Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
|
||||
// The stride between iterations.
|
||||
Shape<GemmConfig_::InstructionShape::kD, 0, kScalarsPerW * Tile::kW>,
|
||||
// The shape of the instruction.
|
||||
typename GemmConfig_::InstructionShape>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperB {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_>
|
||||
: public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
/// The base config.
|
||||
typedef GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> Base;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkew = 16 / sizeof(typename Base::MultiplyAddScalar);
|
||||
/// The shared tile size.
|
||||
typedef Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kD,
|
||||
GemmConfig_::OutputTile::kH + kSkew>
|
||||
Tile;
|
||||
|
||||
/// WMMA matrix
|
||||
typedef WmmaMatrix<GemmOperand::kB,
|
||||
MatrixLayout::kRowMajor,
|
||||
typename Base::MultiplyAddScalar,
|
||||
typename GemmConfig_::InstructionShape>
|
||||
WmmaMatrix;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^T.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer.
|
||||
typename Base::MultiplyAddScalar,
|
||||
// The tile has size KxN (plus skew) in GEMM's terminology.
|
||||
Tile,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename Base::GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsB>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The number of scalars along N covered by all warps in one iteration.
|
||||
static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
|
||||
/// The number of scalars loaded per iteration.
|
||||
static int const kScalarsPerIteration = Tile::kW * GemmConfig_::InstructionShape::kD;
|
||||
/// The traits class to build the iterator to load from shared memory for B.
|
||||
typedef WmmaGemmSharedLoadTileBTraits<
|
||||
// The layout of the matrix.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer.
|
||||
typename Base::MultiplyAddScalar,
|
||||
// The tile in shared memory.
|
||||
Tile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The strides between warps.
|
||||
GemmConfig_::InstructionShape::kH,
|
||||
// The number of iterations to load the data.
|
||||
Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
|
||||
// The stride between iterations.
|
||||
Shape<kScalarsPerIteration, 0, kScalarsPerW, 0>,
|
||||
// The shape of the instruction.
|
||||
typename GemmConfig_::InstructionShape>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarB Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
|
||||
|
||||
/// WMMA matrix
|
||||
typedef WmmaMatrix<GemmOperand::kB,
|
||||
MatrixLayout::kColumnMajor,
|
||||
MultiplyAddScalar,
|
||||
typename GemmConfig_::InstructionShape>
|
||||
WmmaMatrix;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for B^N.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's B.
|
||||
GemmOperand::kB,
|
||||
// B is column-major.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer is Scalar const.
|
||||
Scalar const,
|
||||
// The tile has size NxK in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgB>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkew = 16 / sizeof(MultiplyAddScalar);
|
||||
/// The tile.
|
||||
typedef Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kH,
|
||||
GemmConfig_::OutputTile::kD + kSkew>
|
||||
Tile;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size NxK (plus skew) in GEMM's terminology.
|
||||
Tile,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsB>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The number of scalars along N covered by all warps in one iteration.
|
||||
static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
|
||||
/// The traits class to build the iterator to load from shared memory for B.
|
||||
typedef WmmaGemmSharedLoadTileBTraits<
|
||||
// The layout of the matrix.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile in shared memory.
|
||||
Tile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The strides between warps.
|
||||
GemmConfig_::InstructionShape::kH * Tile::kW,
|
||||
// The number of iterations to load the data.
|
||||
Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
|
||||
// The stride between iterations.
|
||||
Shape<GemmConfig_::InstructionShape::kD, 0, kScalarsPerW * Tile::kW>,
|
||||
// The shape of the instruction.
|
||||
typename GemmConfig_::InstructionShape>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
typename OutputTile_,
|
||||
/// The output type.
|
||||
typename ScalarC_,
|
||||
/// The accumulator type.
|
||||
typename Accumulator_,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_,
|
||||
/// The number of accumulators per warp.
|
||||
typename AccumulatorsPerWarp_,
|
||||
/// The shape of the WMMA instruction.
|
||||
typename InstructionShape_,
|
||||
/// The number of scalars loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_,
|
||||
/// The number of scalars loaded in one LDG for B.
|
||||
int kScalarsPerLdgB_,
|
||||
/// The index.
|
||||
typename Index_>
|
||||
struct WmmaGemmTraitsHelper {
|
||||
/// The WMMA GEMM config.
|
||||
typedef WmmaGemmConfig<kLayoutA_,
|
||||
kLayoutB_,
|
||||
OutputTile_,
|
||||
ScalarC_,
|
||||
Accumulator_,
|
||||
AccumulatorsPerWarp_,
|
||||
InstructionShape_,
|
||||
kScalarsPerLdgA_,
|
||||
kScalarsPerLdgB_>
|
||||
GemmConfig;
|
||||
|
||||
/// The GEMM config for A.
|
||||
typedef WmmaGemmTileTraitsHelperA<kLayoutA_, GemmConfig> GemmTileTraitsHelperA;
|
||||
/// The GEMM config for B.
|
||||
typedef WmmaGemmTileTraitsHelperB<kLayoutB_, GemmConfig> GemmTileTraitsHelperB;
|
||||
|
||||
/// The iterator to load A from global memory.
|
||||
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA::GlobalTileTraits, Index_>
|
||||
GlobalLoadIteratorA;
|
||||
/// The default transformer for A.
|
||||
typedef Copy<typename GlobalLoadIteratorA::Fragment> GlobalTransformerA;
|
||||
/// The iterator to store A to shared memory.
|
||||
typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
|
||||
typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorA;
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA>
|
||||
GlobalLoadStreamA;
|
||||
|
||||
/// The iterator to load B from global memory.
|
||||
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB::GlobalTileTraits, Index_>
|
||||
GlobalLoadIteratorB;
|
||||
/// The default transformer for B.
|
||||
typedef Copy<typename GlobalLoadIteratorB::Fragment> GlobalTransformerB;
|
||||
/// The iterator to store B to shared memory.
|
||||
typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
|
||||
typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorB;
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB>
|
||||
GlobalLoadStreamB;
|
||||
|
||||
/// The iterator to load A from shared memory.
|
||||
typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
|
||||
typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared,
|
||||
Index_,
|
||||
typename GemmTileTraitsHelperA::WmmaMatrix,
|
||||
IteratorFragment::kWmmaMatrix>
|
||||
SharedLoadIteratorA;
|
||||
/// The stream to load A from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorA> SharedLoadStreamA;
|
||||
/// The iterator to load B from shared memory.
|
||||
typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
|
||||
typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared,
|
||||
Index_,
|
||||
typename GemmTileTraitsHelperB::WmmaMatrix,
|
||||
IteratorFragment::kWmmaMatrix>
|
||||
SharedLoadIteratorB;
|
||||
/// The stream to load B from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorB> SharedLoadStreamB;
|
||||
|
||||
/// The functor to do the multiply-add in the main loop.
|
||||
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
|
||||
/// The object to clear accumulators.
|
||||
typedef ClearAccumulators<typename MultiplyAdd::ScalarC> ClearAccumulators;
|
||||
|
||||
/// The helper to create the epilogue traits.
|
||||
typedef WmmaGemmEpilogueTraitsHelper<GemmConfig, EpilogueFunctor_, Index_> EpilogueTraitsHelper;
|
||||
/// The traits class for the epilogue.
|
||||
typedef SimplifiedGemmEpilogueTraits<GemmConfig, EpilogueFunctor_, Index_, EpilogueTraitsHelper>
|
||||
GemmEpilogueTraits;
|
||||
/// The epilogue.
|
||||
typedef GemmEpilogue<GemmEpilogueTraits> Epilogue;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename OutputTile_, typename DefaultShape_ = Shape<64, 32, 64> >
|
||||
struct WmmaGemmAccumulatorsPerWarp {
|
||||
typedef typename ShapeMin<OutputTile_, DefaultShape_>::Shape Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_ = Shape<64, 128, 128>,
|
||||
/// The output type.
|
||||
typename ScalarC_ = float,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<ScalarC_>,
|
||||
/// The accumulator type.
|
||||
typename Accumulator_ = ScalarC_,
|
||||
/// The number of accumulators per warp.
|
||||
typename AccumulatorsPerWarp_ = typename WmmaGemmAccumulatorsPerWarp<OutputTile_>::Shape,
|
||||
/// The shape of the WMMA instruction.
|
||||
typename InstructionShape_ = Shape<16, 16, 16>,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_ = 8,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_ = 8,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The helper class.
|
||||
typename Helper_ = WmmaGemmTraitsHelper<kLayoutA_,
|
||||
kLayoutB_,
|
||||
OutputTile_,
|
||||
ScalarC_,
|
||||
Accumulator_,
|
||||
EpilogueFunctor_,
|
||||
AccumulatorsPerWarp_,
|
||||
InstructionShape_,
|
||||
kScalarsPerLdgA_,
|
||||
kScalarsPerLdgB_,
|
||||
Index_> >
|
||||
struct WmmaGemmTraits : public GemmTraits<
|
||||
// The config.
|
||||
typename Helper_::GemmConfig,
|
||||
// The stream to load A from global memory to shared memory.
|
||||
typename Helper_::GlobalLoadStreamA,
|
||||
// The stream to load B from global memory to shared memory.
|
||||
typename Helper_::GlobalLoadStreamB,
|
||||
// The stream to load A from shared memory.
|
||||
typename Helper_::SharedLoadStreamA,
|
||||
// The stream to load B from shared memory.
|
||||
typename Helper_::SharedLoadStreamB,
|
||||
// The epilogue.
|
||||
typename Helper_::Epilogue,
|
||||
// The block swizzle to reorganize the grid.
|
||||
IdentityBlockSwizzle,
|
||||
// The index.
|
||||
Index_,
|
||||
// The tool used to clear accumulators.
|
||||
typename Helper_::ClearAccumulators> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
#endif // defined CUTLASS_USE_WMMA_API
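The row-major A helper in the removed header pads its shared-memory tile with `kSkew = 16 / sizeof(MultiplyAddScalar)` extra elements, presumably to reduce shared-memory bank conflicts. The following is a minimal sketch of that arithmetic, not the library code: `Shape3`, `skew_elements`, `SkewedTileA`, and `HalfLike` are hypothetical names introduced here, and `std::uint16_t` is only assumed as a stand-in for a 2-byte half-precision scalar.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for cutlass::Shape<kD, kH, kW>; not the library type.
template <int kD_, int kH_, int kW_>
struct Shape3 {
  static constexpr int kD = kD_;
  static constexpr int kH = kH_;
  static constexpr int kW = kW_;
  static constexpr int kCount = kD_ * kH_ * kW_;
};

// Number of elements that make up 16 bytes, mirroring
// `kSkew = 16 / sizeof(MultiplyAddScalar)` from the listing.
template <typename Scalar>
constexpr int skew_elements() {
  return static_cast<int>(16 / sizeof(Scalar));
}

// Shared-memory tile for a row-major A operand: kStages x M x (K + skew).
template <typename Scalar, int kStages, int kM, int kK>
using SkewedTileA = Shape3<kStages, kM, kK + skew_elements<Scalar>()>;

// Assumption: a 2-byte integer stands in for a half-precision scalar.
using HalfLike = std::uint16_t;

static_assert(skew_elements<HalfLike>() == 8, "16 bytes hold eight 2-byte scalars");
static_assert(SkewedTileA<HalfLike, 1, 128, 64>::kW == 72, "K = 64 padded by 8");
```

With the default output tile `Shape<64, 128, 128>` (K = 64, M = 128) and a single stage, each of the 128 rows grows from 64 to 72 elements, so the tile holds 128 * 72 scalars. The stage count of 1 is chosen only for illustration.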
|
||||
@ -1,318 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Free functions for loading and storing to implementations of tile iterator concepts.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment_load_store.h>
|
||||
#include <cutlass/load_store.h>
|
||||
#include <cutlass/predicate_vector.h>
|
||||
#include <cutlass/shape.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Loads a fragment from an input iterator
|
||||
template <typename InputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment) {
|
||||
typename InputIterator::FragmentIterator frag_iterator(fragment);
|
||||
for (int d = 0; d < InputIterator::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
|
||||
if (iterator.valid(d, h, w, c)) {
|
||||
iterator.get(reinterpret_cast<typename InputIterator::AccessType &>(
|
||||
frag_iterator.at(d, h, w, c)),
|
||||
d,
|
||||
h,
|
||||
w,
|
||||
c);
|
||||
}
|
||||
}
|
||||
if (w < InputIterator::Iterations::kW - 1) {
|
||||
iterator.inc_w();
|
||||
}
|
||||
}
|
||||
if (h < InputIterator::Iterations::kH - 1) {
|
||||
iterator.inc_h();
|
||||
}
|
||||
}
|
||||
if (d < InputIterator::Iterations::kD - 1) {
|
||||
iterator.inc_d();
|
||||
}
|
||||
}
|
||||
iterator.inc_advance();
|
||||
}
|
||||
|
||||
/// Loads a fragment from a shared memory input iterator
|
||||
template <typename InputIterator, typename Fragment>
|
||||
CUTLASS_DEVICE void shared_iterator_load(InputIterator &iterator, Fragment &fragment) {
|
||||
typename InputIterator::FragmentIterator frag_iterator(fragment);
|
||||
for (int d = 0; d < InputIterator::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename InputIterator::ImmediateOffsetStrides>::get(
|
||||
d, h, w, c);
|
||||
|
||||
FragmentLoad<InputIterator::kIteratorFragment,
|
||||
InputIterator::Tile::kC,
|
||||
typename InputIterator::Scalar,
|
||||
InputIterator::kMemorySpace,
|
||||
typename InputIterator::FragmentElement,
|
||||
InputIterator::Tile::kW>::load(frag_iterator.at(d, h, w, c),
|
||||
iterator.data(),
|
||||
offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads the slice of a fragment at depth index d from a shared memory input iterator
|
||||
template <typename InputIterator, typename Fragment>
|
||||
CUTLASS_DEVICE void shared_iterator_load(InputIterator &iterator, Fragment &fragment, int d) {
|
||||
typename InputIterator::FragmentIterator frag_iterator(fragment);
|
||||
for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename InputIterator::ImmediateOffsetStrides>::get(
|
||||
d, h, w, c);
|
||||
|
||||
FragmentLoad<InputIterator::kIteratorFragment,
|
||||
InputIterator::Tile::kC,
|
||||
typename InputIterator::Scalar,
|
||||
InputIterator::kMemorySpace,
|
||||
typename InputIterator::FragmentElement,
|
||||
InputIterator::Tile::kW>::load(frag_iterator.at(0, h, w, c),
|
||||
iterator.data(),
|
||||
offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from an input iterator, masked by a predicate iterator
|
||||
template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_load_post_increment(InputIterator &iterator,
|
||||
Fragment &fragment,
|
||||
typename InputIterator::Index offset,
|
||||
ConstPredicateAdapter predicate_adapter) {
|
||||
for (int d = 0; d < InputIterator::Iterations::kD; ++d, iterator.inc_d()) {
|
||||
for (int h = 0; h < InputIterator::Iterations::kH; ++h, iterator.inc_h()) {
|
||||
for (int w = 0; w < InputIterator::Iterations::kW; ++w, iterator.inc_w()) {
|
||||
if (predicate_adapter.at(d, h, w, 0)) {
|
||||
int idx = InputIterator::Tile::kC *
|
||||
(w + InputIterator::Iterations::kW * (h + InputIterator::Iterations::kH * d));
|
||||
|
||||
Load<typename Fragment::Element, InputIterator::Tile::kC, InputIterator::kMemorySpace>::
|
||||
load(reinterpret_cast<typename InputIterator::AccessType &>(fragment[idx]),
|
||||
iterator.data(),
|
||||
offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from an input iterator without predication, advancing the iterator
|
||||
template <typename InputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_load_post_increment(InputIterator &iterator,
|
||||
Fragment &fragment,
|
||||
typename InputIterator::Index offset = 0) {
|
||||
TrivialPredicateTileAdapter pred;
|
||||
iterator_load_post_increment(iterator, fragment, offset, pred);
|
||||
}
|
||||
|
||||
/// Loads a fragment from an input iterator at offset 0, masked by a predicate iterator, advancing the iterator
|
||||
template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_load_post_increment(InputIterator &iterator,
|
||||
Fragment &fragment,
|
||||
ConstPredicateAdapter pred_it) {
|
||||
iterator_load_post_increment(iterator, fragment, 0, pred_it);
|
||||
}
|
||||
|
||||
/// Loads a fragment from a copy of the input iterator, masked by a predicate iterator
template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &_iterator,
|
||||
Fragment &fragment,
|
||||
typename InputIterator::Index offset,
|
||||
ConstPredicateAdapter predicate_adapter) {
|
||||
InputIterator iterator(_iterator);
|
||||
iterator_load_post_increment(iterator, fragment, offset, predicate_adapter);
|
||||
}
|
||||
|
||||
/// Loads a fragment from a copy of the input iterator without predication
|
||||
template <typename InputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &iterator,
|
||||
Fragment &fragment,
|
||||
typename InputIterator::Index offset = 0) {
|
||||
TrivialPredicateTileAdapter pred;
|
||||
iterator_load(iterator, fragment, offset, pred);
|
||||
}
|
||||
|
||||
/// Loads a fragment from a copy of the input iterator at offset 0, masked by a predicate iterator
|
||||
template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &iterator,
|
||||
Fragment &fragment,
|
||||
ConstPredicateAdapter pred_it) {
|
||||
iterator_load(iterator, fragment, 0, pred_it);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Stores a fragment to an output iterator
|
||||
template <typename OutputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment) {
|
||||
typename OutputIterator::FragmentIterator frag_iterator(fragment);
|
||||
for (int d = 0; d < OutputIterator::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < OutputIterator::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < OutputIterator::Iterations::kW; ++w) {
|
||||
if (iterator.valid(d, h, w, 0)) {
|
||||
iterator.set(reinterpret_cast<typename OutputIterator::AccessType const &>(
|
||||
frag_iterator.at(d, h, w, 0)),
|
||||
d,
|
||||
h,
|
||||
w,
|
||||
0);
|
||||
}
|
||||
if (w < OutputIterator::Iterations::kW - 1) {
|
||||
iterator.inc_w();
|
||||
}
|
||||
}
|
||||
if (h < OutputIterator::Iterations::kH - 1) {
|
||||
iterator.inc_h();
|
||||
}
|
||||
}
|
||||
if (d < OutputIterator::Iterations::kD - 1) {
|
||||
iterator.inc_d();
|
||||
}
|
||||
}
|
||||
iterator.inc_advance();
|
||||
}
|
||||
|
||||
/// Stores a fragment to a shared memory output iterator
|
||||
template <typename OutputIterator, typename Fragment>
|
||||
CUTLASS_DEVICE void shared_iterator_store(OutputIterator &iterator, Fragment const &fragment) {
|
||||
typename OutputIterator::FragmentConstIterator frag_iterator(fragment);
|
||||
for (int d = 0; d < OutputIterator::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < OutputIterator::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < OutputIterator::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < OutputIterator::Iterations::kC; ++c) {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename OutputIterator::ImmediateOffsetStrides>::get(
|
||||
d, h, w, c);
|
||||
|
||||
FragmentStore<OutputIterator::kIteratorFragment,
|
||||
OutputIterator::Tile::kC,
|
||||
typename OutputIterator::Scalar,
|
||||
OutputIterator::kMemorySpace,
|
||||
typename OutputIterator::FragmentElement,
|
||||
OutputIterator::Tile::kW>::store(frag_iterator.at(d, h, w, c),
|
||||
iterator.data(),
|
||||
offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Stores a fragment to an output iterator, masked by a predicate iterator
|
||||
template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_store_post_increment(OutputIterator &iterator,
|
||||
Fragment const &fragment,
|
||||
typename OutputIterator::Index offset,
|
||||
ConstPredicateAdapter predicate_adapter) {
|
||||
for (int d = 0; d < OutputIterator::Iterations::kD; ++d, iterator.inc_d()) {
|
||||
for (int h = 0; h < OutputIterator::Iterations::kH; ++h, iterator.inc_h()) {
|
||||
for (int w = 0; w < OutputIterator::Iterations::kW; ++w, iterator.inc_w()) {
|
||||
if (predicate_adapter.at(d, h, w, 0)) {
|
||||
int idx = OutputIterator::Tile::kC *
|
||||
(w + OutputIterator::Iterations::kW * (h + OutputIterator::Iterations::kH * d));
|
||||
|
||||
Store<typename Fragment::Element,
|
||||
OutputIterator::Tile::kC,
|
||||
OutputIterator::kMemorySpace>::
|
||||
store(reinterpret_cast<typename OutputIterator::AccessType const &>(fragment[idx]),
|
||||
iterator.data(),
|
||||
offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores a fragment to an output iterator without predication, advancing the iterator
|
||||
template <typename OutputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_store_post_increment(OutputIterator &iterator,
|
||||
Fragment const &fragment,
|
||||
typename OutputIterator::Index offset = 0) {
|
||||
TrivialPredicateTileAdapter pred;
|
||||
iterator_store_post_increment(iterator, fragment, offset, pred);
|
||||
}
|
||||
|
||||
/// Stores a fragment to an output iterator at offset 0, masked by a predicate iterator, advancing the iterator
|
||||
template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_store_post_increment(OutputIterator &iterator,
|
||||
Fragment const &fragment,
|
||||
ConstPredicateAdapter pred_it) {
|
||||
iterator_store_post_increment(iterator, fragment, 0, pred_it);
|
||||
}
|
||||
|
||||
/// Stores a fragment to a copy of the output iterator, masked by a predicate iterator
|
||||
template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &_iterator,
|
||||
Fragment const &fragment,
|
||||
typename OutputIterator::Index offset,
|
||||
ConstPredicateAdapter predicate_adapter) {
|
||||
OutputIterator iterator(_iterator);
|
||||
iterator_store_post_increment(iterator, fragment, offset, predicate_adapter);
|
||||
}
|
||||
|
||||
/// Stores a fragment to a copy of the output iterator without predication
|
||||
template <typename OutputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &iterator,
|
||||
Fragment const &fragment,
|
||||
typename OutputIterator::Index offset = 0) {
|
||||
TrivialPredicateTileAdapter pred;
|
||||
iterator_store(iterator, fragment, offset, pred);
|
||||
}
|
||||
|
||||
/// Stores a fragment to a copy of the output iterator at offset 0, masked by a predicate iterator
|
||||
template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &iterator,
|
||||
Fragment const &fragment,
|
||||
ConstPredicateAdapter pred_it) {
|
||||
iterator_store(iterator, fragment, 0, pred_it);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
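The predicated load and store loops in the removed header walk the d, h, w iteration space and address the fragment with `idx = Tile::kC * (w + Iterations::kW * (h + Iterations::kH * d))`, skipping any access whose predicate is false. The sketch below illustrates only that indexing and masking on the host. `Iterations`, `fragment_index`, and `predicated_load` are hypothetical names, and the source is indexed directly instead of being reached through an incrementing iterator as in the listing.

```cpp
#include <cassert>
#include <vector>

// Hypothetical iteration counts standing in for InputIterator::Iterations.
struct Iterations {
  static constexpr int kD = 1;
  static constexpr int kH = 2;
  static constexpr int kW = 4;
};

// Linear fragment index from the predicated loops in the listing:
// idx = kC * (w + kW * (h + kH * d)), where kC is the elements per access.
constexpr int fragment_index(int d, int h, int w, int kC, int kH, int kW) {
  return kC * (w + kW * (h + kH * d));
}

// Copies src[idx] into fragment[idx] wherever pred(d, h, w) is true and
// leaves the other entries untouched. One element per access (kC = 1).
template <typename Pred>
void predicated_load(const std::vector<int>& src, std::vector<int>& fragment, Pred pred) {
  for (int d = 0; d < Iterations::kD; ++d) {
    for (int h = 0; h < Iterations::kH; ++h) {
      for (int w = 0; w < Iterations::kW; ++w) {
        if (pred(d, h, w)) {
          int idx = fragment_index(d, h, w, 1, Iterations::kH, Iterations::kW);
          fragment[idx] = src[idx];
        }
      }
    }
  }
}
```

Calling `predicated_load` with a predicate that is true only for `h == 1` fills fragment entries 4 through 7 and leaves entries 0 through 3 at their previous values, which is the behaviour a boundary mask relies on.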
|
||||
@ -1,222 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines abstractions for efficiently loading and storing vectors to memory.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/vector.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Enum to specify which memory space data resides in.
|
||||
*/
|
||||
struct MemorySpace {
|
||||
enum Kind {
|
||||
kGeneric, // Data accessed through pointer dereferencing
|
||||
kShared, // Data resides in shared memory
|
||||
kGlobal // Data resides in global memory
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_,
|
||||
int Lanes_,
|
||||
MemorySpace::Kind Memory_,
|
||||
bool = (Lanes_ > 1),
|
||||
size_t = (sizeof(Scalar_) * Lanes_)>
|
||||
struct Load {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
|
||||
dst = reinterpret_cast<AccessType const*>(&pointer[offset])[0];
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Load<Scalar_, Lanes_, Memory_, true, 4> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
|
||||
dst.registers[0] = reinterpret_cast<uint32_t const*>(&pointer[offset])[0];
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Load<Scalar_, Lanes_, Memory_, true, 8> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
|
||||
uint2 tmp = reinterpret_cast<uint2 const*>(&pointer[offset])[0];
|
||||
dst.registers[0] = tmp.x;
|
||||
dst.registers[1] = tmp.y;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <MemorySpace::Kind Memory_>
|
||||
struct Load<double, 2, Memory_, true, 16> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<double, 2>::Type AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, double const* pointer, int offset) {
|
||||
double2 tmp = reinterpret_cast<double2 const*>(&pointer[offset])[0];
|
||||
dst[0] = tmp.x;
|
||||
dst[1] = tmp.y;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ < 10
|
||||
// WAR bug in NVCC where the upper and lower half of the register end up being the same
|
||||
template <MemorySpace::Kind Memory_>
|
||||
struct Load<half, 8, Memory_, true, 16> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<half, 8>::Type AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, half const* pointer, int offset) {
|
||||
int2 tmp = reinterpret_cast<int2 const*>(&pointer[offset])[0];
|
||||
dst.registers[0] = tmp.x;
|
||||
dst.registers[1] = tmp.y;
|
||||
|
||||
tmp = reinterpret_cast<int2 const*>(&pointer[offset + 4])[0];
|
||||
dst.registers[2] = tmp.x;
|
||||
dst.registers[3] = tmp.y;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Load<Scalar_, Lanes_, Memory_, true, 16> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
|
||||
uint4 tmp = reinterpret_cast<uint4 const*>(&pointer[offset])[0];
|
||||
dst.registers[0] = tmp.x;
|
||||
dst.registers[1] = tmp.y;
|
||||
dst.registers[2] = tmp.z;
|
||||
dst.registers[3] = tmp.w;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_,
|
||||
int Lanes_,
|
||||
MemorySpace::Kind Memory_,
|
||||
bool = (Lanes_ > 1),
|
||||
size_t = (sizeof(Scalar_) * Lanes_)>
|
||||
struct Store {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
|
||||
pointer[offset] = src;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Store<Scalar_, Lanes_, Memory_, true, 4> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
|
||||
uint32_t* addr = reinterpret_cast<uint32_t*>(&pointer[offset]);
|
||||
addr[0] = src.registers[0];
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Store<Scalar_, Lanes_, Memory_, true, 8> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
|
||||
uint2* addr = reinterpret_cast<uint2*>(&pointer[offset]);
|
||||
addr[0] = make_uint2(src.registers[0], src.registers[1]);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <MemorySpace::Kind Memory_>
|
||||
struct Store<double, 2, Memory_, true, 16> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<double, 2>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& src, double* pointer, int offset) {
|
||||
double2* addr = reinterpret_cast<double2*>(&pointer[offset]);
|
||||
addr[0] = make_double2(src[0], src[1]);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Store<Scalar_, Lanes_, Memory_, true, 16> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
|
||||
uint4* addr = reinterpret_cast<uint4*>(&pointer[offset]);
|
||||
addr[0] = make_uint4(src.registers[0], src.registers[1], src.registers[2], src.registers[3]);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
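The `Load` and `Store` templates above choose an implementation from the total access width, `sizeof(Scalar_) * Lanes_`, which is passed as a defaulted template argument and matched by partial specializations for 4, 8, and 16 bytes. A minimal host-side sketch of that dispatch follows. The names `LoadSketch`, `kWidthBytes`, and `load_sketch_demo` are hypothetical, plain arrays stand in for `Vectorize<>::Type`, and `std::memcpy` replaces the `reinterpret_cast` vector loads so the sketch carries no alignment requirement; it is not the CUTLASS implementation. It assumes a 4-byte `float`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical analogue of Load: the last (defaulted) template argument is the
// total access width in bytes; partial specializations on it pick a wider copy.
template <typename Scalar, int Lanes,
          bool = (Lanes > 1),
          std::size_t = sizeof(Scalar) * Lanes>
struct LoadSketch {
  static constexpr int kWidthBytes = 0;  // 0 marks the generic fallback
  static void load(Scalar* dst, Scalar const* pointer, int offset) {
    for (int i = 0; i < Lanes; ++i) dst[i] = pointer[offset + i];
  }
};

// 8-byte accesses: a single 64-bit copy (stands in for the uint2 load).
template <typename Scalar, int Lanes>
struct LoadSketch<Scalar, Lanes, true, 8> {
  static constexpr int kWidthBytes = 8;
  static void load(Scalar* dst, Scalar const* pointer, int offset) {
    std::uint64_t tmp;
    std::memcpy(&tmp, pointer + offset, sizeof(tmp));
    std::memcpy(dst, &tmp, sizeof(tmp));
  }
};

// 16-byte accesses: two 64-bit words (stands in for the uint4 load).
template <typename Scalar, int Lanes>
struct LoadSketch<Scalar, Lanes, true, 16> {
  static constexpr int kWidthBytes = 16;
  static void load(Scalar* dst, Scalar const* pointer, int offset) {
    std::uint64_t tmp[2];
    std::memcpy(tmp, pointer + offset, sizeof(tmp));
    std::memcpy(dst, tmp, sizeof(tmp));
  }
};

// The width alone selects the specialization.
static_assert(LoadSketch<float, 1>::kWidthBytes == 0, "single lane: generic path");
static_assert(LoadSketch<float, 2>::kWidthBytes == 8, "2 x 4 bytes: 8-byte path");
static_assert(LoadSketch<float, 4>::kWidthBytes == 16, "4 x 4 bytes: 16-byte path");
static_assert(LoadSketch<float, 3>::kWidthBytes == 0, "12 bytes: no specialization");

// Loads lanes [4, 8) of a small buffer through the 16-byte path.
inline float load_sketch_demo() {
  float src[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  float dst[4] = {0, 0, 0, 0};
  LoadSketch<float, 4>::load(dst, src, 4);
  assert(dst[0] == 4.0f && dst[3] == 7.0f);
  return dst[0] + dst[1] + dst[2] + dst[3];
}
```

Because the width is a defaulted argument, callers only name the scalar type and lane count; adding a new access width means adding one specialization without touching call sites.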
@ -1,58 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines a type for restructuring a tile.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/shape.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// The following functor reshapes a tile of data. The goal is to have at least kAccessSize in
|
||||
// the inner-most dimension. If the user respects that constraint, there is nothing to be done. If
|
||||
// that's not the case, this functor will correct that and "extract" the right number of elements
|
||||
// from the next dimension.
|
||||
|
||||
template <typename Tile_, int kAccessSize_, bool = (Tile_::kC < kAccessSize_)>
|
||||
struct ReshapeTile {
|
||||
typedef Tile_ Tile;
|
||||
};
|
||||
|
||||
template <typename Tile_, int kAccessSize_>
|
||||
struct ReshapeTile<Tile_, kAccessSize_, true> {
|
||||
// Make sure the W dimension of the tile is large enough.
|
||||
static_assert(Tile_::kW >= kAccessSize_, "The W dimension is too small");
|
||||
// Make sure the dimension can be divided by the number of scalars.
|
||||
static_assert(Tile_::kW % kAccessSize_ == 0, "Not supported");
|
||||
// Collapse the W dimension.
|
||||
typedef Shape<Tile_::kD, Tile_::kH, Tile_::kW / kAccessSize_, kAccessSize_> Tile;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
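To make the effect of `ReshapeTile` concrete, the sketch below repeats the two templates from the file above in host-only form and applies them to two tiles. `Shape` is reduced to its four constants, and the typedef names `Narrow`, `Reshaped`, `Wide`, and `Unchanged` are hypothetical. Note that the specialization discards the original `kC`, so the sketch assumes, as the source appears to, that a reshaped tile starts with `kC == 1`.

```cpp
#include <cassert>

// Reduced copy of cutlass::Shape: four compile-time extents.
template <int kD_ = 1, int kH_ = 1, int kW_ = 1, int kC_ = 1>
struct Shape {
  static int const kD = kD_;
  static int const kH = kH_;
  static int const kW = kW_;
  static int const kC = kC_;
};

// Primary template: the innermost dimension is already wide enough.
template <typename Tile_, int kAccessSize_, bool = (Tile_::kC < kAccessSize_)>
struct ReshapeTile {
  typedef Tile_ Tile;
};

// Specialization: move kAccessSize_ scalars from W into the innermost dimension.
template <typename Tile_, int kAccessSize_>
struct ReshapeTile<Tile_, kAccessSize_, true> {
  static_assert(Tile_::kW >= kAccessSize_, "The W dimension is too small");
  static_assert(Tile_::kW % kAccessSize_ == 0, "Not supported");
  typedef Shape<Tile_::kD, Tile_::kH, Tile_::kW / kAccessSize_, kAccessSize_> Tile;
};

// A 1x4x8 tile of single scalars, accessed 4 scalars at a time.
typedef Shape<1, 4, 8, 1> Narrow;
typedef ReshapeTile<Narrow, 4>::Tile Reshaped;   // becomes Shape<1, 4, 2, 4>

// A tile whose innermost dimension already holds 4 scalars is left alone.
typedef Shape<1, 4, 8, 4> Wide;
typedef ReshapeTile<Wide, 4>::Tile Unchanged;    // stays Shape<1, 4, 8, 4>

static_assert(Reshaped::kW == 2 && Reshaped::kC == 4, "W collapsed into C");
static_assert(Unchanged::kW == 8 && Unchanged::kC == 4, "no reshape needed");

// Total scalar count is preserved by the reshape when the input has kC == 1.
inline int reshaped_scalar_count() {
  return Reshaped::kD * Reshaped::kH * Reshaped::kW * Reshaped::kC;
}
```

The reshape keeps the number of scalars (here 32) and only changes how they are grouped, which is what lets a loader issue one 4-wide access per element.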
305
cutlass/shape.h
@ -1,305 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines Shape implementing the Layout concept for representing a 4D hypercube of objects.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup layout_concept Layout Concept
|
||||
* @{
|
||||
* @par Implementations of \ref layout_concept are used to describe a cube with DxHxW elements and C
|
||||
scalars per element.
|
||||
A HxW slice of a cube is called an image and a cube consists of D images.
|
||||
*
|
||||
* @par Notations
|
||||
* Let Layout be an implementation of the \ref layout_concept.
|
||||
*
|
||||
* @par Valid Expressions
|
||||
* - <b>Layout::D</b> specifies the depth of a cube
|
||||
* - <b>Layout::H</b> specifies the height of a cube
|
||||
* - <b>Layout::W</b> specifies the width of a cube
|
||||
* - <b>Layout::C</b> specifies the number of channels of each element in a cube
|
||||
* - <b>Layout::W_c</b> specifies the number of scalars of each row in one image of a cube.
|
||||
* - <b>Layout::H_w</b> specifies the number of elements in an image slice.
|
||||
* - <b>Layout::H_w_c</b> specifies the number of scalars in an image slice.
|
||||
* - <b>Layout::D_h_w</b> specifies the number of elements in a cube.
|
||||
* - <b>Layout::D_h_w_c</b> specifies the number of scalars in a cube.
|
||||
* - <b>Layout::Strides</b> is a \ref layout_concept specifying the strides.
|
||||
* @}
|
||||
*/
|
||||
|
||||
/**
|
||||
* @brief A Shape implementing \ref layout_concept describing the dimensions of a cube.
|
||||
* @concept{layout_concept}
|
||||
*/
|
||||
template <int kD_ = 1, int kH_ = 1, int kW_ = 1, int kC_ = 1>
|
||||
struct Shape {
|
||||
/// The depth of the cube.
|
||||
static int const kD = kD_;
|
||||
/// The height of the cube.
|
||||
static int const kH = kH_;
|
||||
/// The width of the cube.
|
||||
static int const kW = kW_;
|
||||
/// The number of scalars per element.
|
||||
static int const kC = kC_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Compute derived counts of a \ref layout_concept based class
|
||||
*/
|
||||
template <typename Shape>
|
||||
struct ShapeCount {
|
||||
/// The number of elements per row.
|
||||
static int const kWc = Shape::kW * Shape::kC;
|
||||
/// The number of pixels per image.
|
||||
static int const kHw = Shape::kH * Shape::kW;
|
||||
/// The number of elements per image.
|
||||
static int const kHwc = Shape::kH * kWc;
|
||||
/// The number of pixels per cube.
|
||||
static int const kDhw = Shape::kD * kHw;
|
||||
/// The number of elements in the 4D space.
|
||||
static int const kDhwc = Shape::kD * kHwc;
|
||||
/// The number of elements in the 4D space.
|
||||
static int const kCount = kDhwc;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename A_, int kScale_>
|
||||
struct ShapeScale {
|
||||
typedef Shape<A_::kD * kScale_, A_::kH * kScale_, A_::kW * kScale_, A_::kC * kScale_> Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename A_, typename B_>
|
||||
struct ShapeAdd {
|
||||
typedef Shape<A_::kD + B_::kD, A_::kH + B_::kH, A_::kW + B_::kW, A_::kC + B_::kC> Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename A_, typename B_>
|
||||
struct ShapeSub {
|
||||
typedef Shape<A_::kD - B_::kD, A_::kH - B_::kH, A_::kW - B_::kW, A_::kC - B_::kC> Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename A_, typename B_>
|
||||
struct ShapeMul {
|
||||
typedef Shape<A_::kD * B_::kD, A_::kH * B_::kH, A_::kW * B_::kW, A_::kC * B_::kC> Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename A_, typename B_>
|
||||
struct ShapeDiv {
|
||||
typedef Shape<A_::kD / B_::kD, A_::kH / B_::kH, A_::kW / B_::kW, A_::kC / B_::kC> Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename A_, typename B_>
|
||||
struct ShapeMax {
|
||||
typedef Shape<(A_::kD > B_::kD ? A_::kD : B_::kD),
|
||||
(A_::kH > B_::kH ? A_::kH : B_::kH),
|
||||
(A_::kW > B_::kW ? A_::kW : B_::kW),
|
||||
(A_::kC > B_::kC ? A_::kC : B_::kC)>
|
||||
Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename A_, typename B_>
|
||||
struct ShapeMin {
|
||||
typedef Shape<(A_::kD < B_::kD ? A_::kD : B_::kD),
|
||||
(A_::kH < B_::kH ? A_::kH : B_::kH),
|
||||
(A_::kW < B_::kW ? A_::kW : B_::kW),
|
||||
(A_::kC < B_::kC ? A_::kC : B_::kC)>
|
||||
Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Shape_, int kElementsPerAccess>
|
||||
struct ShapeStrides {
|
||||
typedef Shape<Shape_::kH * Shape_::kW * Shape_::kC,
|
||||
Shape_::kW * Shape_::kC,
|
||||
Shape_::kC,
|
||||
kElementsPerAccess>
|
||||
Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Compute the offset for the given coordinates in a cube
|
||||
* @tparam Shape_ \ref layout_concept describing the dimensions of the cube.
|
||||
*/
|
||||
template <typename Shape_>
|
||||
struct ComputeOffsetFromShape {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
|
||||
// clang-format off
|
||||
return d * Shape_::kH * Shape_::kW * Shape_::kC +
|
||||
h * Shape_::kW * Shape_::kC +
|
||||
w * Shape_::kC +
|
||||
c;
|
||||
// clang-format on
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Compute the offset for the given coordinates in a cube with a depth of 1
|
||||
* @tparam kSh Elements in the H dimension
|
||||
* @tparam kSw Elements in the W dimension
|
||||
* @tparam kSc Number of scalars per element, i.e. the distance between two consecutive elements in scalars
|
||||
*/
|
||||
template <int kSh_, int kSw_, int kSc_>
|
||||
struct ComputeOffsetFromShape<Shape<1, kSh_, kSw_, kSc_> > {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
|
||||
return h * kSw_ * kSc_ + w * kSc_ + c;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Compute the offset for the given coordinates in a cube with one channel and a depth of 1
|
||||
* @tparam kSh Elements in the H dimension
|
||||
* @tparam kSw Elements in the W dimension
|
||||
*/
|
||||
template <int kSh_, int kSw_>
|
||||
struct ComputeOffsetFromShape<Shape<1, kSh_, kSw_, 1> > {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) { return h * kSw_ + w; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Compute the offset for the given coordinates in a cube
|
||||
* @tparam Strides_ \ref layout_concept where each dimension of the cube specifies the corresponding stride.
|
||||
*/
|
||||
template <typename Strides_>
|
||||
struct ComputeOffsetFromStrides {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
|
||||
return d * Strides_::kD + h * Strides_::kH + w * Strides_::kW + c * Strides_::kC;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Compute the offset for the given coordinates in a cube with a depth of 1
|
||||
* @tparam S_h Stride in the H dimension in scalars
|
||||
* @tparam S_w Stride in the W dimension in scalars
|
||||
* @tparam S_c Stride between two scalars.
|
||||
*/
|
||||
template <int S_h_, int S_w_, int S_c_>
|
||||
struct ComputeOffsetFromStrides<Shape<1, S_h_, S_w_, S_c_> > {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
|
||||
return h * S_h_ + w * S_w_ + c * S_c_;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Compute the offset for the given coordinates in a cube with one channel and a depth of 1
|
||||
* @tparam S_h Stride in the H dimension in scalars
|
||||
* @tparam S_w Stride in the W dimension in scalars
|
||||
*/
|
||||
template <int S_h_, int S_w_>
|
||||
struct ComputeOffsetFromStrides<Shape<1, S_h_, S_w_, 1> > {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) { return h * S_h_ + w * S_w_; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Decompose threadIdx.x into the coordinates of a cube whose dimensions are specified by Threads_.
|
||||
* Afterwards compute the offset of those coordinates using Strides_
|
||||
* @tparam Threads_ The dimension of the cube the threadIdx.x value is mapped on
|
||||
* @tparam Strides_ The strides to use when computing the offsets based on the coordinates of the cube.
|
||||
*/
|
||||
template <typename Threads_, typename Strides_>
|
||||
struct ComputeThreadOffsetFromStrides {
|
||||
static CUTLASS_DEVICE int get() {
|
||||
// Decompose the thread index.
|
||||
int c = threadIdx.x % Threads_::kC;
|
||||
int w = threadIdx.x / Threads_::kC % Threads_::kW;
|
||||
int h = threadIdx.x / Threads_::kC / Threads_::kW % Threads_::kH;
|
||||
int d = threadIdx.x / Threads_::kC / Threads_::kW / Threads_::kH;
|
||||
|
||||
// Compute the offset.
|
||||
return d * Strides_::kD + h * Strides_::kH + w * Strides_::kW + c * Strides_::kC;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/**
|
||||
*@brief Specialization for D=1
|
||||
*/
|
||||
template <int T_h_, int T_w_, int T_c_, int S_h_, int S_w_, int S_c_>
|
||||
struct ComputeThreadOffsetFromStrides<Shape<1, T_h_, T_w_, T_c_>, Shape<1, S_h_, S_w_, S_c_> > {
|
||||
static CUTLASS_DEVICE int get() {
|
||||
// Decompose the thread index.
|
||||
int c = threadIdx.x % T_c_;
|
||||
int w = threadIdx.x / T_c_ % T_w_;
|
||||
int h = threadIdx.x / T_c_ / T_w_ % T_h_;
|
||||
|
||||
// Compute the offset.
|
||||
return h * S_h_ + w * S_w_ + c * S_c_;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
*@brief Specialization for D=1 and C=1
|
||||
*/
|
||||
template <int T_h_, int T_w_, int S_h_, int S_w_>
|
||||
struct ComputeThreadOffsetFromStrides<Shape<1, T_h_, T_w_, 1>, Shape<1, S_h_, S_w_, 1> > {
|
||||
static CUTLASS_DEVICE int get() {
|
||||
// Decompose the thread index.
|
||||
int w = threadIdx.x % T_w_;
|
||||
int h = threadIdx.x / T_w_;
|
||||
|
||||
// Compute the offset.
|
||||
return h * S_h_ + w * S_w_;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
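The offset and thread-decomposition helpers in `shape.h` are inverse operations over a packed D x H x W x C cube: one maps a coordinate to a linear offset, the other splits a linear index back into a coordinate with `c` varying fastest. The host-side sketch below checks that relationship. It replaces `threadIdx.x` with an ordinary `int` argument, and the names `Coord4`, `count`, `offset_from_shape`, and `decompose` are hypothetical, not CUTLASS API.

```cpp
#include <cassert>

// Reduced copy of cutlass::Shape: four compile-time extents.
template <int kD_ = 1, int kH_ = 1, int kW_ = 1, int kC_ = 1>
struct Shape {
  static int const kD = kD_;
  static int const kH = kH_;
  static int const kW = kW_;
  static int const kC = kC_;
};

// Hypothetical plain coordinate holder.
struct Coord4 {
  int d, h, w, c;
};

// Number of scalars in the cube (what ShapeCount calls kCount).
template <typename Shape_>
int count() {
  return Shape_::kD * Shape_::kH * Shape_::kW * Shape_::kC;
}

// Packed offset of (d, h, w, c), following ComputeOffsetFromShape::get.
template <typename Shape_>
int offset_from_shape(int d, int h, int w, int c) {
  return d * Shape_::kH * Shape_::kW * Shape_::kC +
         h * Shape_::kW * Shape_::kC +
         w * Shape_::kC +
         c;
}

// Splits a linear index into (d, h, w, c), following the arithmetic in
// ComputeThreadOffsetFromStrides::get with tid in place of threadIdx.x.
template <typename Threads_>
Coord4 decompose(int tid) {
  Coord4 r;
  r.c = tid % Threads_::kC;
  r.w = tid / Threads_::kC % Threads_::kW;
  r.h = tid / Threads_::kC / Threads_::kW % Threads_::kH;
  r.d = tid / Threads_::kC / Threads_::kW / Threads_::kH;
  return r;
}

// Returns true if decompose followed by offset_from_shape is the identity
// for every linear index in the cube.
template <typename Shape_>
bool round_trips() {
  for (int tid = 0; tid < count<Shape_>(); ++tid) {
    Coord4 r = decompose<Shape_>(tid);
    if (offset_from_shape<Shape_>(r.d, r.h, r.w, r.c) != tid) return false;
  }
  return true;
}
```

For `Shape<2, 3, 4, 5>` the last coordinate (1, 2, 3, 4) maps to 1*60 + 2*20 + 3*5 + 4 = 119, one less than the 120 scalars in the cube. The real `ComputeThreadOffsetFromStrides` then applies arbitrary strides to the decomposed coordinate rather than the packed ones used here.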
@ -1,151 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines a structure containing strides, bounds, and a pointer to tensor data.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <typeinfo>
|
||||
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/vector.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure modeling a pointer and stride into a tensor
|
||||
template <typename Storage_, int Rank_>
|
||||
class TensorRef {
|
||||
public:
|
||||
/// Data type of individual access
|
||||
typedef Storage_ Storage;
|
||||
|
||||
/// Rank of tensor
|
||||
static int const Rank = Rank_;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Pointer to storage element
|
||||
Storage* ptr_;
|
||||
|
||||
/// Stride information
|
||||
Coord<Rank> stride_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef() : ptr_(nullptr) {}
|
||||
|
||||
/// Constructs from a pointer and stride
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef(Storage* ptr, Coord<Rank> stride) : ptr_(ptr), stride_(stride) {}
|
||||
|
||||
/// Updates the pointer and stride of a TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
void reset(Storage* ptr = nullptr, Coord<Rank> stride = Coord<Rank>(0)) {
|
||||
ptr_ = ptr;
|
||||
stride_ = stride;
|
||||
}
|
||||
|
||||
/// Conversion function
|
||||
template <typename T>
|
||||
TensorRef<T, Rank> convert() {
|
||||
Coord<Rank> converted_stride;
|
||||
for (int i = 0; i < Rank - 1; ++i) {
|
||||
converted_stride[i] = stride_[i] * Extent<Storage>::kValue / Extent<T>::kValue;
|
||||
}
|
||||
converted_stride[Rank - 1] = stride_[Rank - 1];
|
||||
|
||||
return TensorRef<T, Rank>(reinterpret_cast<T*>(ptr_), converted_stride);
|
||||
}
|
||||
|
||||
/// Returns true if the TensorRef may be safely accessed
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool good() const { return ptr_ != nullptr; }
|
||||
|
||||
/// Returns the pointer to referenced data
|
||||
CUTLASS_HOST_DEVICE
|
||||
Storage* data() const { return ptr_; }
|
||||
|
||||
/// Returns the stride of the tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<Rank> const& stride() const { return stride_; }
|
||||
|
||||
/// Returns the stride of the tensor in the given dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
int const& stride(int dim) const { return stride_.at(dim); }
|
||||
|
||||
/// Returns the maximum stride element as the 'leading dimension'
|
||||
CUTLASS_HOST_DEVICE
|
||||
int leading_dim() const { return __NV_STD_MAX(stride_[1], stride_[2]); }
|
||||
|
||||
/// Computes the offset of an index from the origin of the tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
long long offset(Coord<Rank> const& coord) const {
|
||||
return stride_.template dot<long long>(coord);
|
||||
}
|
||||
|
||||
/// Returns a reference to the element at a given Coord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Storage& at(Coord<Rank> const& coord) const { return ptr_[offset(coord)]; }
|
||||
|
||||
/// Element-wise accessor
|
||||
CUTLASS_HOST_DEVICE
Storage& operator[](Coord<Rank> const& coord) const { return at(coord); }
|
||||
|
||||
/// Returns a reference to the element at a given linear index
|
||||
CUTLASS_HOST_DEVICE
|
||||
Storage& at(int idx) const { return ptr_[idx]; }
|
||||
|
||||
/// Element-wise accessor
|
||||
CUTLASS_HOST_DEVICE
Storage& operator[](int idx) const { return at(idx); }
|
||||
|
||||
/// Adds an offset to the pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef& advance(Coord<Rank> const& b) {
|
||||
ptr_ += offset(b);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Returns a TensorRef offset by a given amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef operator+(Coord<Rank> const& b) const { return TensorRef(ptr_ + offset(b), stride_); }
|
||||
|
||||
/// Returns a TensorRef offset by a given amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef operator-(Coord<Rank> const& b) const { return TensorRef(ptr_ - offset(b), stride_); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
@ -1,172 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines a structure containing strides and a pointer to tensor data.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/tensor_ref.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Host-side reference implementation of tensor operations
|
||||
template <typename T>
|
||||
class TensorView : public TensorRef<T, 4> {
|
||||
public:
|
||||
/// Reference and stride
|
||||
typedef TensorRef<T, 4> Base;
|
||||
|
||||
/// Reference and stride
|
||||
typedef Base TensorRef_t;
|
||||
|
||||
/// Reference to constant type
|
||||
typedef TensorRef<T const, 4> ConstTensorRef_t;
|
||||
|
||||
/// Rank of tensor
|
||||
static int const Rank = TensorRef_t::Rank;
|
||||
|
||||
/// Type used to compute the offset of an element to the base of a tensor
|
||||
typedef int Offset_t;
|
||||
|
||||
/// Coordinate into tensor
|
||||
typedef Coord<Rank> Coord_t;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Pointer to pitch-linear memory
|
||||
TensorRef_t ref_;
|
||||
|
||||
/// Dimensions of coordinate (independent of stride)
|
||||
Coord_t size_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Device and Host Methods
|
||||
//
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorView() {}
|
||||
|
||||
/// Constructs a TensorView from a TensorRef and size
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorView(TensorRef_t const& _ref, Coord_t const& _size) : Base(_ref), size_(_size) {}
|
||||
|
||||
/// Returns true if the TensorView is bound to some memory
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool good() const { return ref().good(); }
|
||||
|
||||
/// Returns a pointer to data
|
||||
CUTLASS_HOST_DEVICE
|
||||
T* data() const { return ref().data(); }
|
||||
|
||||
/// Updates the reference and size of a TensorView object
|
||||
CUTLASS_HOST_DEVICE
|
||||
void reset(TensorRef_t const& _ref = TensorRef_t(0), Coord_t const& _size = Coord_t()) {
|
||||
Base::operator=(_ref);
|
||||
size_ = _size;
|
||||
}
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef_t& ref() { return *this; }
|
||||
|
||||
/// Returns a reference to constant data with the same stride
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstTensorRef_t const_ref() { return ConstTensorRef_t(data(), stride()); }
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef_t const& ref() const { return *this; }
|
||||
|
||||
/// Accesses the size
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord_t const& size() const { return size_; }
|
||||
|
||||
/// Accesses the size
|
||||
CUTLASS_HOST_DEVICE
|
||||
int size(int dim) const { return size_.at(dim); }
|
||||
|
||||
/// Accesses the stride
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord_t const& stride() const { return ref().stride(); }
|
||||
|
||||
/// Accesses the stride
|
||||
CUTLASS_HOST_DEVICE
|
||||
int const& stride(int dim) const { return ref().stride(dim); }
|
||||
|
||||
/// Assigns the TensorView
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorView& operator=(TensorView const& _tensor) {
|
||||
Base::operator=(_tensor.ref());
|
||||
size_ = _tensor.size_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Returns the index of an element
|
||||
CUTLASS_HOST_DEVICE
|
||||
Offset_t offset(Coord_t const& coord) const { return ref().offset(coord); }
|
||||
|
||||
/// Determines whether a location is within a tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool contains(Coord_t const& coord) const {
|
||||
for (int dim = 0; dim < Rank; ++dim) {
|
||||
if (coord.at(dim) >= size_.at(dim)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Element-wise accessor
|
||||
CUTLASS_HOST_DEVICE
|
||||
T& at(Coord_t const& coord) const { return ref().at(coord); }
|
||||
|
||||
/// Element-wise accessor
|
||||
T& operator[](Coord<Rank> const& coord) const { return at(coord); }
|
||||
|
||||
/// Element-wise accessor
|
||||
CUTLASS_HOST_DEVICE
|
||||
T& at(Offset_t idx) const { return ref().at(idx); }
|
||||
|
||||
/// Returns a TensorView given location and size quantities
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorView<T> subview(Coord_t const& location, Coord_t size) const {
|
||||
return TensorView<T>(ref() + location, size.clamp(size_ - location));
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
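The `contains()` and `subview()` methods above both compare coordinates against the view's extent. The sketch below restates that logic on plain arrays. `Coord4`, `contains`, and `subview_extent` are hypothetical names, and the clamping behaviour is an assumption about what `size.clamp(size_ - location)` does: each requested extent is limited to what remains past `location`, and never below zero.

```cpp
#include <array>
#include <cstddef>

typedef std::array<int, 4> Coord4;  // hypothetical stand-in for Coord<4>

// True when every coordinate is below the extent in its dimension. As in the
// method shown in the diff, only the upper bound is checked.
inline bool contains(Coord4 const& size, Coord4 const& coord) {
  for (std::size_t dim = 0; dim < size.size(); ++dim) {
    if (coord[dim] >= size[dim]) {
      return false;
    }
  }
  return true;
}

// Extent of a subview that starts at `location`: the requested extent,
// limited to what is left of the parent view in each dimension.
inline Coord4 subview_extent(Coord4 const& size, Coord4 const& location,
                             Coord4 const& requested) {
  Coord4 result;
  for (std::size_t dim = 0; dim < size.size(); ++dim) {
    int remaining = size[dim] - location[dim];
    int extent = requested[dim] < remaining ? requested[dim] : remaining;
    result[dim] = extent < 0 ? 0 : extent;  // assumed lower clamp at zero
  }
  return result;
}
```

Because only the upper bound is tested, a caller that may produce negative coordinates would need its own check before relying on `contains`.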
|
||||
@ -1,899 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines the Tile Traits concept and iterators for loading and storing to tiles
|
||||
efficiently.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
#include <cutlass/load_store.h>
|
||||
#include <cutlass/predicate_vector.h>
|
||||
#include <cutlass/vector.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup tile_traits_concept Tile Traits Concept
|
||||
@{
|
||||
|
||||
\ref tile_traits_concept is a type defining the shape of a tile and the distribution of accesses
|
||||
by individual entities, either threads or other.
|
||||
|
||||
@par Tile Traits Concept
|
||||
Types satisfying \ref tile_traits_concept define the following members
|
||||
- <b>Tile</b> - a type satisfying \ref layout_concept describing the dimensions of the tile
|
||||
- <b>Delta</b> - a type satisfying \ref layout_concept describing the increments between accesses
|
||||
along each dimension
|
||||
- <b>Iterations</b> - a type satisfying \ref layout_concept describing the number of accesses
|
||||
along each dimension
|
||||
- <b>Offset</b> - the type of a <i>functor</i> computing the offset of each participating entity
|
||||
as a Coord<4>.
|
||||
@}
|
||||
*/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specifies dimension in which post-increment accesses advance
|
||||
struct IteratorAdvance {
|
||||
enum Kind { kD, kH, kW };
|
||||
};
|
||||
|
||||
/// Specifies whether iterator storage fragment consists of Scalar values or WMMA matrix
|
||||
struct IteratorFragment {
|
||||
enum Kind { kScalar, kWmmaMatrix };
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief A template defining \ref tile_traits_concept
|
||||
* @concept{tile_traits_concept}
|
||||
*/
|
||||
template <typename Tile_,
|
||||
typename Delta_,
|
||||
typename Iterations_,
|
||||
typename ThreadOffset_,
|
||||
int kAccessSize>
|
||||
struct TileTraits {
|
||||
/// Shape of the tile
|
||||
typedef Tile_ Tile;
|
||||
|
||||
/// Number of steps between accesses along each dimension
|
||||
typedef Delta_ Delta;
|
||||
|
||||
/// Number of accesses performed
|
||||
typedef Iterations_ Iterations;
|
||||
|
||||
/// Functor that returns the logical coordinate of each entity's initial offset in the tile
|
||||
typedef ThreadOffset_ ThreadOffset;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Iterator for accessing a stripmined tile in memory
|
||||
template <typename Traits_,
|
||||
typename Scalar_,
|
||||
IteratorAdvance::Kind Advance_ = IteratorAdvance::kH,
|
||||
MemorySpace::Kind MemorySpace = MemorySpace::kGeneric,
|
||||
typename Index_ = int,
|
||||
typename FragmentElement_ = Scalar_,
|
||||
IteratorFragment::Kind IteratorFragment_ = IteratorFragment::kScalar,
|
||||
typename Skew_ = Shape<0, 0, 0, 0> >
|
||||
struct TileIteratorBase {
|
||||
/// concept TileTraits
|
||||
typedef Traits_ Traits;
|
||||
|
||||
/// Scalar element
|
||||
typedef Scalar_ Scalar;
|
||||
|
||||
/// Fragment element
|
||||
typedef FragmentElement_ FragmentElement;
|
||||
|
||||
/// Specifies dimension in which post-increment accesses advance.
|
||||
static IteratorAdvance::Kind const kAdvance = Advance_;
|
||||
|
||||
/// Specifies iterator storage fragment type (Scalar or WmmaMatrix)
|
||||
static IteratorFragment::Kind const kIteratorFragment = IteratorFragment_;
|
||||
|
||||
/// Source or destination memory space
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace;
|
||||
|
||||
/// Index type
|
||||
typedef Index_ Index;
|
||||
|
||||
/// Skew quantity
|
||||
typedef Skew_ Skew;
|
||||
|
||||
/// Tile shape
|
||||
typedef typename Traits::Tile Tile;
|
||||
|
||||
/// Distance along each dimension
|
||||
typedef typename Traits::Delta Delta;
|
||||
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef typename Traits::ImmediateOffsetStrides ImmediateOffsetStrides;
|
||||
|
||||
/// Iterations
|
||||
typedef typename Traits::Iterations Iterations;
|
||||
|
||||
/// Thread offset
|
||||
typedef typename Traits::ThreadOffset ThreadOffset;
|
||||
|
||||
/// The number of scalars accessed per load/store.
|
||||
static int const kAccessSize = Tile::kC;
|
||||
|
||||
/// The elements loaded/stored by one instruction.
|
||||
typedef typename Vectorize<FragmentElement, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The size of storage needed per fragment
|
||||
static int const kFragmentSize =
|
||||
(kIteratorFragment == IteratorFragment::kWmmaMatrix ? 16 : sizeof(AccessType));
|
||||
/// The storage.
|
||||
typedef Fragment<Scalar, ShapeCount<Tile>::kCount, kFragmentSize> Storage;
|
||||
/// The fragment.
|
||||
typedef Fragment<FragmentElement, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
|
||||
/// The fragment iterator.
|
||||
typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
|
||||
/// The fragment const iterator.
|
||||
typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;
|
||||
/// The shape of the fragment.
|
||||
typedef typename FragmentIterator::FragmentShape FragmentShape;
|
||||
|
||||
/// Default predicate mask type
|
||||
typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;
|
||||
|
||||
//
|
||||
// Params struct
|
||||
//
|
||||
|
||||
/// Parameters to the iterator
|
||||
struct Params {
|
||||
Index stride_d;
|
||||
Index stride_h;
|
||||
Index stride_w;
|
||||
|
||||
Index inc_d;
|
||||
Index inc_h;
|
||||
Index inc_w;
|
||||
|
||||
Index inc_advance;
|
||||
|
||||
/// Initializes params
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Index _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
Index _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
Index _inc_advance) {
|
||||
stride_d = _stride_d;
|
||||
stride_h = _stride_h;
|
||||
stride_w = _stride_w;
|
||||
|
||||
inc_d = _inc_d;
|
||||
inc_h = _inc_h;
|
||||
inc_w = _inc_w;
|
||||
inc_advance = _inc_advance;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Index _stride_d, Index _stride_h, Index _stride_w) {
|
||||
stride_d = _stride_d;
|
||||
stride_h = _stride_h;
|
||||
stride_w = _stride_w;
|
||||
|
||||
inc_w = stride_w * Delta::kW;
|
||||
inc_h = stride_h * Delta::kH - stride_w * Delta::kW * (Iterations::kW - 1);
|
||||
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
// Advance in the H dimension.
|
||||
inc_d = 0;
|
||||
} else if (kAdvance == IteratorAdvance::kW) {
|
||||
// Advance in the W dimension.
|
||||
inc_d = stride_w * Tile::kW - stride_h * Tile::kH;
|
||||
} else {
|
||||
// Advance in the D dimension.
|
||||
inc_d = stride_d;
|
||||
}
|
||||
|
||||
inc_advance = 0;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE int initialize() {
|
||||
stride_d = 0;
|
||||
stride_h = 0;
|
||||
stride_w = 1;
|
||||
|
||||
inc_d = inc_h = inc_w = inc_advance = 0;
|
||||
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
/// Is the iterator valid?
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
|
||||
|
||||
//
|
||||
// Static function members
|
||||
//
|
||||
|
||||
/// Initializes a predicate vector
|
||||
template <typename PredicateIterator>
|
||||
CUTLASS_DEVICE static void initialize_predicates(PredicateIterator predicate_it,
|
||||
Coord<3> const &bounds,
|
||||
Coord<3> const &offset = make_Coord(0, 0, 0)) {
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
bool enable_d = (d * Delta::kD + offset[0] < bounds[0]);
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
bool enable_h = (h * Delta::kH + offset[1] < bounds[1]);
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
bool enable_w = (w * Tile::kC * Delta::kW + offset[2] < bounds[2]);
|
||||
predicate_it.set(d, h, w, 0, enable_d && enable_h && enable_w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
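`initialize_predicates()` above enables an access only when its position along each of the D, H and W dimensions falls inside the given bounds. The following sketch replays that loop nest on the host with runtime values in place of the compile-time `Iterations`, `Delta` and `Tile::kC` constants. `Extent3`, `make_predicates` and `access_size` are hypothetical names introduced for illustration.

```cpp
#include <cassert>
#include <vector>

struct Extent3 {
  int d, h, w;  // hypothetical (D, H, W) triple
};

// Builds one guard per access, in the same d/h/w order as the loop nest in
// initialize_predicates(). `access_size` plays the role of Tile::kC.
inline std::vector<bool> make_predicates(Extent3 iterations, Extent3 delta,
                                         int access_size, Extent3 bounds,
                                         Extent3 offset) {
  assert(access_size > 0);
  std::vector<bool> mask;
  for (int d = 0; d < iterations.d; ++d) {
    bool enable_d = (d * delta.d + offset.d < bounds.d);
    for (int h = 0; h < iterations.h; ++h) {
      bool enable_h = (h * delta.h + offset.h < bounds.h);
      for (int w = 0; w < iterations.w; ++w) {
        bool enable_w = (w * access_size * delta.w + offset.w < bounds.w);
        mask.push_back(enable_d && enable_h && enable_w);
      }
    }
  }
  return mask;
}
```

An access is masked off as soon as any one dimension is out of range, so a thread near the edge of a matrix still runs the full loop nest and simply skips the guarded loads.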
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup tile_load_iterator_concept Tile Load Iterator Concept
|
||||
@{
|
||||
|
||||
\ref tile_load_iterator_concept enables loading a tile from addressable memory into a fragment
|
||||
|
||||
@par Tile Load Iterator Concept
|
||||
Types satisfying \ref tile_load_iterator_concept define the following members
|
||||
- <b>PredicateVector</b> - a \ref predicate_vector_concept with sufficient predicate storage for
|
||||
each access implied by the tile traits
|
||||
- <b>Fragment</b> - the destination fragment type satisfying \ref fragment_concept
|
||||
- <b>initialize_predicates(pred_it, bounds, block_offset)</b> - function initializing a predicate
|
||||
vector according to externally specified bounds
|
||||
- <b>load_post_increment(fragment, pred_it)</b> - a method that loads a fragment and increments
|
||||
the iterator to the next tile, guarded by a \ref predicate_iterator_concept
|
||||
- <b>load_post_increment(fragment)</b> - a method that loads a fragment and increments the
|
||||
iterator to the next tile
|
||||
- <b>load(fragment, pred_it)</b> - a const method that loads a fragment, guarded by a \ref
|
||||
predicate_iterator_concept
|
||||
- <b>load(fragment)</b> - a method that loads a fragment
|
||||
|
||||
@}
|
||||
*/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief An iterator implementing \ref tile_load_iterator_concept for loading a tile from memory
|
||||
* @concept{tile_load_iterator_concept}
|
||||
*/
|
||||
template <typename Traits_,
|
||||
typename Scalar_,
|
||||
IteratorAdvance::Kind Advance_ = IteratorAdvance::kH,
|
||||
MemorySpace::Kind MemorySpace = MemorySpace::kGeneric,
|
||||
typename Index_ = int,
|
||||
typename FragmentElement_ = Scalar_,
|
||||
IteratorFragment::Kind IteratorFragment_ = IteratorFragment::kScalar,
|
||||
typename Skew_ = Shape<0, 0, 0, 0> >
|
||||
struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
Scalar_,
|
||||
Advance_,
|
||||
MemorySpace,
|
||||
Index_,
|
||||
FragmentElement_,
|
||||
IteratorFragment_,
|
||||
Skew_> {
|
||||
/// Base class
|
||||
typedef TileIteratorBase<Traits_,
|
||||
Scalar_,
|
||||
Advance_,
|
||||
MemorySpace,
|
||||
Index_,
|
||||
FragmentElement_,
|
||||
IteratorFragment_,
|
||||
Skew_>
|
||||
Base;
|
||||
|
||||
/// concept TileTraits
|
||||
typedef typename Base::Traits Traits;
|
||||
|
||||
/// Scalar element
|
||||
typedef typename Base::Scalar Scalar;
|
||||
|
||||
/// Fragment element
|
||||
typedef typename Base::FragmentElement FragmentElement;
|
||||
|
||||
/// Specifies in which dimension post-increment accesses advance.
|
||||
static IteratorAdvance::Kind const kAdvance = Base::kAdvance;
|
||||
|
||||
/// Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
|
||||
static IteratorFragment::Kind const kIteratorFragment = Base::kIteratorFragment;
|
||||
|
||||
/// Source or destination memory space
|
||||
static MemorySpace::Kind const kMemorySpace = Base::kMemorySpace;
|
||||
|
||||
/// Index type
|
||||
typedef typename Base::Index Index;
|
||||
|
||||
/// Skew quantity
|
||||
typedef typename Base::Skew Skew;
|
||||
|
||||
/// Tile shape
|
||||
typedef typename Base::Tile Tile;
|
||||
|
||||
/// Delta
|
||||
typedef typename Base::Delta Delta;
|
||||
|
||||
/// Iterations
|
||||
typedef typename Base::Iterations Iterations;
|
||||
|
||||
/// ThreadOffset functor
|
||||
typedef typename Base::ThreadOffset ThreadOffset;
|
||||
|
||||
/// Fragment type
|
||||
typedef typename Base::FragmentShape FragmentShape;
|
||||
|
||||
/// Memory access type
|
||||
typedef typename Base::AccessType AccessType;
|
||||
|
||||
/// Fragment definition
|
||||
typedef typename Base::Fragment Fragment;
|
||||
|
||||
/// Fragment iterator definition
|
||||
typedef typename Base::FragmentIterator FragmentIterator;
|
||||
|
||||
/// Fragment const iterator definition
|
||||
typedef typename Base::FragmentConstIterator FragmentConstIterator;
|
||||
|
||||
/// Default predicate mask type
|
||||
typedef typename Base::PredicateVector PredicateVector;
|
||||
|
||||
/// Storage object that may be loaded from
|
||||
typedef typename Base::Storage SharedStorage;
|
||||
|
||||
/// IteratorBase parameters
|
||||
typedef typename Base::Params BaseParams;
|
||||
|
||||
/// Do we require a fence?
|
||||
enum { kRequiresLoadFence = Tile::kD == 1 };
|
||||
|
||||
/// The pointer type
|
||||
typedef Scalar const *Pointer;
|
||||
|
||||
/// Parameters
|
||||
struct Params : public BaseParams {
|
||||
/// Pointer to memory
|
||||
Scalar const *pointer;
|
||||
|
||||
/// Initialize params to access storage object
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(SharedStorage const &storage) {
|
||||
pointer = &storage[0];
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initializes params to access a raw pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar const *ptr, Index stride_d, Index stride_h, Index stride_w) {
|
||||
Base::Params::initialize(stride_d, stride_h, stride_w);
|
||||
pointer = ptr;
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initializes params
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar const *ptr,
|
||||
Index _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
Index _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
Index _inc_advance) {
|
||||
pointer = ptr;
|
||||
Base::Params::initialize(
|
||||
_stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initializes params to default values
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize() { return Base::Params::initialize(); }
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
Params params;
|
||||
|
||||
/// Offset of an individual lane from the start of the tile
|
||||
Coord<4> thread_offset;
|
||||
|
||||
/// Stage argument enables wrapping after some number of tiles have been loaded.
|
||||
int stage;
|
||||
|
||||
//
|
||||
// Static member functions
|
||||
//
|
||||
|
||||
/// Initializes a predicate vector
|
||||
template <typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
|
||||
Coord<3> const &bounds,
|
||||
Coord<3> const &block_offset = make_Coord(0,
|
||||
0,
|
||||
0)) {
|
||||
Base::initialize_predicates(
|
||||
predicate_it,
|
||||
bounds,
|
||||
block_offset + make_Coord(0, thread_offset[1], thread_offset[2] * Tile::kC));
|
||||
}
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileLoadIterator() {}
|
||||
|
||||
/// Constructs a tile load iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileLoadIterator(Params const &_params,
|
||||
Coord<3> const &block_offset = make_Coord(0, 0, 0),
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(_params), stage(0) {
|
||||
thread_offset = thread_offset_func();
|
||||
|
||||
Index block_offset_h = 0;
|
||||
Index block_offset_w = 0;
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
block_offset_h = block_offset[1];
|
||||
block_offset_w = block_offset[2];
|
||||
} else {
|
||||
block_offset_h = block_offset[2];
|
||||
block_offset_w = block_offset[1];
|
||||
}
|
||||
|
||||
params.pointer += block_offset[0] * params.stride_d +
|
||||
(block_offset_h + thread_offset[1]) * params.stride_h +
|
||||
(block_offset_w + thread_offset[2] * Tile::kC) / Tile::kC * params.stride_w;
|
||||
}
|
||||
|
||||
/// Constructs a tile load iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileLoadIterator(Params const &,
|
||||
SharedStorage &shared_storage,
|
||||
Coord<3> const &block_offset = make_Coord(0, 0, 0),
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: stage(0) {
|
||||
int const offset = thread_offset_func()[2];
|
||||
params.pointer = &shared_storage[offset];
|
||||
}
|
||||
|
||||
/// Returns the current pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
Scalar const *data() const { return params.pointer; }
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(AccessType &value, int d, int h, int w, int c) const {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
|
||||
Load<Scalar, Base::kAccessSize, kMemorySpace>::load(value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Increment in the D dimension
|
||||
CUTLASS_HOST_DEVICE void inc_d() { params.pointer += params.inc_d; }
|
||||
|
||||
/// Increment in the H dimension
|
||||
CUTLASS_HOST_DEVICE void inc_h() { params.pointer += params.inc_h; }
|
||||
|
||||
/// Increment in the W dimension
|
||||
CUTLASS_HOST_DEVICE void inc_w() { params.pointer += params.inc_w; }
|
||||
|
||||
/// Increment in the next dimension
|
||||
CUTLASS_HOST_DEVICE void inc_advance() { params.pointer += params.inc_advance; }
|
||||
|
||||
/// Increment the stage.
|
||||
CUTLASS_DEVICE void inc_stage() {
|
||||
if (Tile::kD > 1) {
|
||||
int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
|
||||
if (stage == Tile::kD - 1) {
|
||||
params.pointer -= (Tile::kD - 1) * kStageSize;
|
||||
stage = 0;
|
||||
} else {
|
||||
params.pointer += kStageSize;
|
||||
stage = stage + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
/// Loads a fragment and advances the iterator to the next tile.
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
|
||||
FragmentIterator frag_iterator(fragment);
|
||||
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
|
||||
if (*pred_it) {
|
||||
Load<typename Fragment::Element, Tile::kC, kMemorySpace>::load(
|
||||
reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)), data(), 0);
|
||||
}
|
||||
|
||||
if (w < Iterations::kW - 1) {
|
||||
inc_w();
|
||||
}
|
||||
}
|
||||
if (h < Iterations::kH - 1) {
|
||||
inc_h();
|
||||
}
|
||||
}
|
||||
if (d < Iterations::kD - 1) {
|
||||
inc_d();
|
||||
}
|
||||
}
|
||||
inc_advance();
|
||||
}
|
||||
|
||||
/// Loads a fragment and advances the iterator to the next tile.
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment) {
|
||||
typename PredicateVector::TrivialIterator pred_it;
|
||||
load_post_increment(fragment, pred_it);
|
||||
}
|
||||
|
||||
/// Loads a fragment without advancing the iterator.
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const {
|
||||
TileLoadIterator _load_it(*this);
|
||||
_load_it.load_post_increment(fragment, pred_it);
|
||||
}
|
||||
|
||||
/// Loads a fragment without advancing the iterator.
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void load(Fragment &fragment) const {
|
||||
typename PredicateVector::TrivialIterator pred_it;
|
||||
load(fragment, pred_it);
|
||||
}
|
||||
};
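The three-argument `Params::initialize()` and the loop nest in `load_post_increment()` work as a pair: `inc_w` steps to the next access within a row, and `inc_h` steps to the next row while undoing the `Iterations::kW - 1` W-steps already taken. A small host-side sketch, using hypothetical runtime parameters in place of the `Delta` and `Iterations` constants, shows that the pointer lands on `h * stride_h * delta_h + w * stride_w * delta_w` at each access.

```cpp
#include <cassert>
#include <vector>

struct Increments {
  int inc_w;  // step between consecutive accesses along W
  int inc_h;  // step to the next H row, rewinding the W steps taken
};

// Same formulas as the three-argument Params::initialize() in the diff.
inline Increments make_increments(int stride_h, int stride_w, int delta_h,
                                  int delta_w, int iterations_w) {
  assert(iterations_w > 0);
  Increments inc;
  inc.inc_w = stride_w * delta_w;
  inc.inc_h = stride_h * delta_h - stride_w * delta_w * (iterations_w - 1);
  return inc;
}

// Replays the h/w loops of load_post_increment() and records the pointer
// offset at each access instead of issuing a load.
inline std::vector<int> visited_offsets(Increments inc, int iterations_h,
                                        int iterations_w) {
  std::vector<int> offsets;
  int offset = 0;
  for (int h = 0; h < iterations_h; ++h) {
    for (int w = 0; w < iterations_w; ++w) {
      offsets.push_back(offset);
      if (w < iterations_w - 1) {
        offset += inc.inc_w;
      }
    }
    if (h < iterations_h - 1) {
      offset += inc.inc_h;
    }
  }
  return offsets;
}
```

For example, with `stride_h = 16`, `stride_w = 1`, `delta_h = 2`, `delta_w = 4` and three W iterations, `inc_w` is 4 and `inc_h` is 24, so two rows visit offsets 0, 4, 8 and then 32, 36, 40.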
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup tile_store_iterator_concept Tile Store Iterator Concept
|
||||
@{
|
||||
|
||||
\ref tile_store_iterator_concept enables storing a tile to addressable memory
|
||||
|
||||
@par Tile Store Iterator Concept
|
||||
Types satisfying \ref tile_store_iterator_concept define the following members
|
||||
- <b>PredicateVector</b> - a \ref predicate_vector_concept with sufficient predicate storage for
|
||||
each access implied by the tile traits
|
||||
- <b>Fragment</b> - the source fragment type satisfying \ref fragment_concept
|
||||
- <b>initialize_predicates(pred_it, bounds, block_offset)</b> - function initializing a predicate
|
||||
vector according to externally specified bounds
|
||||
- <b>store_post_increment(fragment, pred_it)</b> - a method that stores a fragment and increments
|
||||
the iterator to the next tile, guarded by a \ref predicate_iterator_concept
|
||||
- <b>store_post_increment(fragment)</b> - a method that stores a fragment and increments the
|
||||
iterator to the next tile
|
||||
- <b>store(fragment, pred_it)</b> - a const method that stores a fragment, guarded by a \ref
|
||||
predicate_iterator_concept
|
||||
- <b>store(fragment)</b> - a method that stores a fragment
|
||||
|
||||
@}
|
||||
*/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief An iterator implementing \ref tile_store_iterator_concept for storing a tile to memory
|
||||
* @concept{tile_store_iterator_concept}
|
||||
*/
|
||||
template <typename Traits_,
|
||||
typename Scalar_,
|
||||
IteratorAdvance::Kind Advance_ = IteratorAdvance::kH,
|
||||
MemorySpace::Kind MemorySpace = MemorySpace::kGeneric,
|
||||
typename Index_ = int,
|
||||
typename FragmentElement_ = Scalar_,
|
||||
IteratorFragment::Kind IteratorFragment_ = IteratorFragment::kScalar,
|
||||
typename Skew_ = Shape<0, 0, 0, 0> >
|
||||
struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
Scalar_,
|
||||
Advance_,
|
||||
MemorySpace,
|
||||
Index_,
|
||||
FragmentElement_,
|
||||
IteratorFragment_,
|
||||
Skew_> {
|
||||
/// Base class
|
||||
typedef TileIteratorBase<Traits_,
|
||||
Scalar_,
|
||||
Advance_,
|
||||
MemorySpace,
|
||||
Index_,
|
||||
FragmentElement_,
|
||||
IteratorFragment_,
|
||||
Skew_>
|
||||
Base;
|
||||
|
||||
/// concept TileTraits
|
||||
typedef typename Base::Traits Traits;
|
||||
|
||||
/// Scalar element
|
||||
typedef typename Base::Scalar Scalar;
|
||||
|
||||
/// Fragment element
|
||||
typedef typename Base::FragmentElement FragmentElement;
|
||||
|
||||
/// Specifies in which dimension post-increment accesses advance.
|
||||
static IteratorAdvance::Kind const kAdvance = Base::kAdvance;
|
||||
|
||||
/// Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
|
||||
static IteratorFragment::Kind const kIteratorFragment = Base::kIteratorFragment;
|
||||
|
||||
/// Source or destination memory space
|
||||
static MemorySpace::Kind const kMemorySpace = Base::kMemorySpace;
|
||||
|
||||
/// Index type
|
||||
typedef typename Base::Index Index;
|
||||
|
||||
/// Skew quantity
|
||||
typedef typename Base::Skew Skew;
|
||||
|
||||
/// Tile shape
|
||||
typedef typename Base::Tile Tile;
|
||||
|
||||
/// Delta
|
||||
typedef typename Base::Delta Delta;
|
||||
|
||||
/// Iterations
|
||||
typedef typename Base::Iterations Iterations;
|
||||
|
||||
/// ThreadOffset functor
|
||||
typedef typename Base::ThreadOffset ThreadOffset;
|
||||
|
||||
/// Fragment type
|
||||
typedef typename Base::FragmentShape FragmentShape;
|
||||
|
||||
/// Memory access type
|
||||
typedef typename Base::AccessType AccessType;
|
||||
|
||||
/// Fragment definition
|
||||
typedef typename Base::Fragment Fragment;
|
||||
|
||||
/// Fragment iterator definition
|
||||
typedef typename Base::FragmentIterator FragmentIterator;
|
||||
|
||||
/// Fragment const iterator definition
|
||||
typedef typename Base::FragmentConstIterator FragmentConstIterator;
|
||||
|
||||
/// Default predicate mask type
|
||||
typedef typename Base::PredicateVector PredicateVector;
|
||||
|
||||
/// Storage object which may be stored to
|
||||
typedef typename Base::Storage SharedStorage;
|
||||
|
||||
/// IteratorBase parameters
|
||||
typedef typename Base::Params BaseParams;
|
||||
|
||||
/// Parameters
|
||||
struct Params : public BaseParams {
|
||||
/// Pointer to memory
|
||||
Scalar *pointer;
|
||||
|
||||
/// Initialize params to access storage object
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(SharedStorage &storage) {
|
||||
pointer = &storage[0];
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initializes params to access a raw pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar *ptr, Index stride_d, Index stride_h, Index stride_w) {
|
||||
Base::Params::initialize(stride_d, stride_h, stride_w);
|
||||
pointer = ptr;
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initializes params
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar *ptr,
|
||||
Index _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
Index _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
Index _inc_advance) {
|
||||
pointer = ptr;
|
||||
Base::Params::initialize(
|
||||
_stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initializes params to default values
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize() { return Base::Params::initialize(); }
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
Params params;
|
||||
|
||||
/// Offset of an individual lane from the start of the tile
|
||||
Coord<4> thread_offset;
|
||||
|
||||
/// The stage.
|
||||
int stage;
|
||||
|
||||
//
|
||||
// Static member functions
|
||||
//
|
||||
|
||||
/// Initializes a predicate vector
|
||||
template <typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
|
||||
Coord<3> const &bounds,
|
||||
Coord<3> const &block_offset = make_Coord(0,
|
||||
0,
|
||||
0)) {
|
||||
Base::initialize_predicates(
|
||||
predicate_it,
|
||||
bounds,
|
||||
block_offset + make_Coord(0, thread_offset[1], thread_offset[2] * Tile::kC));
|
||||
}
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileStoreIterator() {}
|
||||
|
||||
/// Constructs a tile store iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileStoreIterator(Params const &_params,
|
||||
Coord<3> const &block_offset = make_Coord(0, 0, 0),
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(_params), stage(0) {
|
||||
thread_offset = thread_offset_func();
|
||||
|
||||
params.pointer += block_offset[0] * params.stride_d +
|
||||
(block_offset[1] + thread_offset[1]) * params.stride_h +
|
||||
(block_offset[2] + thread_offset[2] * Tile::kC) / Tile::kC * params.stride_w;
|
||||
}
|
||||
|
||||
/// Constructs a tile store iterator that writes to a shared storage object
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileStoreIterator(Params const &,
|
||||
SharedStorage &shared_storage,
|
||||
Coord<3> const &block_offset = make_Coord(0, 0, 0),
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: stage(0) {
|
||||
int const offset = thread_offset_func()[2];
|
||||
params.pointer = &shared_storage[offset];
|
||||
}
|
||||
|
||||
/// Returns the current pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
Scalar *data() const { return params.pointer; }
|
||||
|
||||
/// Increment in the D dimension
|
||||
CUTLASS_HOST_DEVICE void inc_d() { params.pointer += params.inc_d; }
|
||||
|
||||
/// Increment in the H dimension
|
||||
CUTLASS_HOST_DEVICE void inc_h() { params.pointer += params.inc_h; }
|
||||
|
||||
/// Increment in the W dimension
|
||||
CUTLASS_HOST_DEVICE void inc_w() { params.pointer += params.inc_w; }
|
||||
|
||||
/// Increment in the next dimension
|
||||
CUTLASS_HOST_DEVICE void inc_advance() {}
|
||||
|
||||
/// Increment the stage.
|
||||
CUTLASS_DEVICE void inc_stage() {
|
||||
if (Tile::kD > 1) {
|
||||
int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
|
||||
if (stage == Tile::kD - 1) {
|
||||
params.pointer -= (Tile::kD - 1) * kStageSize;
|
||||
stage = 0;
|
||||
} else {
|
||||
params.pointer += kStageSize;
|
||||
stage = stage + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores a value at the offset computed from the (d, h, w, c) indices.
|
||||
CUTLASS_DEVICE void set(AccessType const &value, int d, int h, int w, int c) {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
|
||||
Store<Scalar, Base::kAccessSize, kMemorySpace>::store(value, params.pointer, imm);
|
||||
}
|
||||
|
||||
public:
|
||||
/// Stores a fragment and advances to the next tile.
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment, PredicateIterator pred_it) {
|
||||
FragmentIterator frag_iterator(fragment);
|
||||
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
|
||||
if (*pred_it) {
|
||||
Store<typename Fragment::Element, Tile::kC, kMemorySpace>::store(
|
||||
reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)), data(), 0);
|
||||
}
|
||||
if (w < Iterations::kW - 1) {
|
||||
inc_w();
|
||||
}
|
||||
}
|
||||
if (h < Iterations::kH - 1) {
|
||||
inc_h();
|
||||
}
|
||||
}
|
||||
if (d < Iterations::kD - 1) {
|
||||
inc_d();
|
||||
}
|
||||
}
|
||||
inc_advance();
|
||||
}
|
||||
|
||||
/// Stores a fragment without predication and advances to the next tile.
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment) {
|
||||
typename PredicateVector::TrivialIterator pred_it;
|
||||
store_post_increment(fragment, pred_it);
|
||||
}
|
||||
|
||||
/// Stores a fragment without advancing the iterator.
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void store(Fragment &fragment, PredicateIterator pred_it) const {
|
||||
TileStoreIterator _store_it(*this);
|
||||
_store_it.store_post_increment(fragment, pred_it);
|
||||
}
|
||||
|
||||
/// Stores a fragment without predication and without advancing the iterator.
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void store(Fragment &fragment) const {
|
||||
typename PredicateVector::TrivialIterator pred_it;
|
||||
store(fragment, pred_it);
|
||||
}
|
||||
};
|
||||
}
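The `inc_stage()` method above treats the D dimension of the tile as a ring of stages: the pointer moves forward by one stage until it reaches the last stage, then jumps back to the first. The following is a minimal host-only sketch of that rotation, tracking an element offset instead of a device pointer. The names `StageCursor`, `kStages`, and `kStageSize` are hypothetical and stand in for the iterator, `Tile::kD`, and `Tile::kH * Tile::kW * Tile::kC`.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical host-side model of the stage rotation performed by inc_stage().
// kStages plays the role of Tile::kD and kStageSize that of
// Tile::kH * Tile::kW * Tile::kC; neither name comes from CUTLASS.
template <int kStages, int kStageSize>
struct StageCursor {
  std::ptrdiff_t offset;  // element offset from the start of the buffer
  int stage;              // index of the current stage

  StageCursor() : offset(0), stage(0) {}

  // Advance to the next stage, wrapping back to stage 0 after the last one.
  // With a single stage this is a no-op, as in the original method.
  void inc_stage() {
    if (kStages > 1) {
      if (stage == kStages - 1) {
        offset -= static_cast<std::ptrdiff_t>(kStages - 1) * kStageSize;
        stage = 0;
      } else {
        offset += kStageSize;
        stage = stage + 1;
      }
    }
  }
};
```

Calling `inc_stage()` as many times as there are stages returns the cursor to offset 0, which is the property a multi-stage shared-memory buffer relies on.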
|
||||
@ -1,238 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines tile traits for several tile partitioning arrangements of threads expected to
|
||||
achieve efficient streaming performance.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/tile_iterator.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Basic thread offset function computed from a thread shape
|
||||
template <typename ThreadShape>
|
||||
struct TiledThreadOffset {
|
||||
/// Computes the logical coordinate from thread shape
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
Coord<4> thread_offset;
|
||||
|
||||
int index = threadIdx.x;
|
||||
|
||||
thread_offset[3] = (index % ThreadShape::kC);
|
||||
index = (index / ThreadShape::kC);
|
||||
|
||||
thread_offset[2] = (index % ThreadShape::kW);
|
||||
index = (index / ThreadShape::kW);
|
||||
|
||||
thread_offset[1] = (index % ThreadShape::kH);
|
||||
index = (index / ThreadShape::kH);
|
||||
|
||||
thread_offset[0] = index;
|
||||
|
||||
return thread_offset;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tiling in which the number of threads is greater than the
|
||||
/// contiguous dimension of the tile.
|
||||
template <typename Tile_, int Threads>
|
||||
struct TileTraitsStrideMajor {
|
||||
/// Shape of tile
|
||||
typedef Tile_ Tile;
|
||||
|
||||
/// Number of participating threads
|
||||
static int const kThreads = Threads;
|
||||
|
||||
// Static assertions
|
||||
static_assert(!(ShapeCount<Tile>::kDhw % kThreads),
|
||||
"Tiling undefined if elements not divisible by threads.");
|
||||
|
||||
static_assert(Tile::kW <= kThreads,
|
||||
"This specialization assumes there are more threads than the contiguous dimension "
|
||||
"of the tile.");
|
||||
|
||||
/// Shape of threads
|
||||
typedef Shape<1, kThreads / Tile::kW, Tile::kW, 1> ThreadShape;
|
||||
|
||||
/// Delta along each dimension
|
||||
typedef Shape<1, ThreadShape::kH, 1, 1> Delta;
|
||||
|
||||
/// Number of iterations
|
||||
typedef Shape<1, Tile::kH / ThreadShape::kH, 1, 1> Iterations;
|
||||
|
||||
/// Computes the initial offset
|
||||
typedef TiledThreadOffset<ThreadShape> ThreadOffset;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tiling in which the number of threads is fewer than the tile size
|
||||
/// in the contiguous dimension.
|
||||
template <typename Tile_, int Threads>
|
||||
struct TileTraitsContiguousMajor {
|
||||
/// Shape of tile
|
||||
typedef Tile_ Tile;
|
||||
|
||||
/// Number of participating threads
|
||||
static int const kThreads = Threads;
|
||||
|
||||
// Static assertions
|
||||
static_assert(Tile::kW >= kThreads,
|
||||
"This specialization assumes there are no more threads than the contiguous dimension "
|
||||
"of the tile.");
|
||||
|
||||
static_assert(!(ShapeCount<Tile>::kDhw % kThreads),
|
||||
"Tiling undefined if elements not divisible by threads.");
|
||||
|
||||
static_assert(!(Tile::kW % kThreads),
|
||||
"The contiguous size of the tile must be divisible by the number of threads.");
|
||||
|
||||
/// Thread shape
|
||||
typedef Shape<1, 1, kThreads> ThreadShape;
|
||||
|
||||
/// Delta between each thread's access
|
||||
typedef Shape<1, 1, kThreads> Delta;
|
||||
|
||||
/// Number of iterations
|
||||
typedef Shape<1, Tile::kH, Tile::kW / kThreads> Iterations;
|
||||
|
||||
/// Computes the initial offset
|
||||
typedef TiledThreadOffset<ThreadShape> ThreadOffset;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tiling in which warps rake across the contiguous dimension
|
||||
template <typename Tile_, int Threads>
|
||||
struct TileTraitsWarpRake {
|
||||
/// Shape of tile
|
||||
typedef Tile_ Tile;
|
||||
|
||||
/// Number of participating threads
|
||||
static int const kThreads = Threads;
|
||||
|
||||
/// Hard-coded warp size
|
||||
static int const kWarpSize = 32;
|
||||
|
||||
/// Number of participating warps
|
||||
static int const kWarpCount = kThreads / kWarpSize;
|
||||
|
||||
// Static assertions
|
||||
static_assert(!(ShapeCount<Tile>::kDhw % kThreads),
|
||||
"Tiling undefined if elements not divisible by threads.");
|
||||
|
||||
static_assert(!(kThreads % kWarpSize), "Number of threads must be divisible by the warp size.");
|
||||
|
||||
static_assert(!(Tile::kW % kWarpSize), "Contiguous dimension must be divisible by the warp size");
|
||||
|
||||
/// Warps strip-mined across strided dimension
|
||||
static int const kWarpsStrided = __NV_STD_MIN(kWarpCount, Tile::kH);
|
||||
|
||||
/// Warps strip-mined across contiguous dimension
|
||||
static int const kWarpsContiguous = kWarpCount / kWarpsStrided;
|
||||
|
||||
/// Arrangement of threads
|
||||
typedef Shape<1, kWarpsStrided, kWarpsContiguous * kWarpSize> ThreadShape;
|
||||
|
||||
/// The same warp rakes along the contiguous dimension
|
||||
typedef Shape<1, kWarpsStrided, kWarpSize> Delta;
|
||||
|
||||
/// Number of iterations
|
||||
typedef Shape<1, Tile::kH / Delta::kH, Tile::kW / ThreadShape::kW> Iterations;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
/// Basic thread offset function computed from a thread shape
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int tid = threadIdx.x;
|
||||
int warp = (tid / kWarpSize);
|
||||
int lane = (tid % kWarpSize);
|
||||
|
||||
static int const kWarpSpanContiguous = kWarpSize * Iterations::kW;
|
||||
|
||||
int warp_w = (warp % kWarpsContiguous);
|
||||
int warp_h = (warp / kWarpsContiguous);
|
||||
|
||||
return make_Coord(0, warp_h, lane + kWarpSpanContiguous * warp_w, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Chooses 'best' shape to enable warp raking along contiguous dimension if possible.
|
||||
template <typename Tile_, int Threads>
|
||||
struct TileTraitsStandard {
|
||||
/// Shape of tile
|
||||
typedef Tile_ Tile;
|
||||
|
||||
/// Number of participating threads
|
||||
static int const kThreads = Threads;
|
||||
|
||||
/// Hard-coded warp size
|
||||
static int const kWarpSize = 32;
|
||||
|
||||
/// Number of participating warps
|
||||
static int const kWarpCount = kThreads / kWarpSize;
|
||||
|
||||
// Static assertions
|
||||
static_assert(!(ShapeCount<Tile>::kDhw % kThreads),
|
||||
"Tiling undefined if elements not divisible by threads.");
|
||||
|
||||
/// Choose the stride-major contiguous tiling if the contiguous dimension is
|
||||
/// smaller than the warp size. Otherwise, if it is divisible by the warp size,
|
||||
/// choose the warp rake arrangement.
|
||||
typedef typename platform::conditional <
|
||||
(Tile::kW < kWarpSize),
|
||||
TileTraitsStrideMajor<Tile, Threads>,
|
||||
typename platform::conditional<!(Tile::kW % kWarpSize),
|
||||
TileTraitsWarpRake<Tile, Threads>,
|
||||
TileTraitsContiguousMajor<Tile, Threads> >::type>::
|
||||
type Traits;
|
||||
|
||||
/// Delta between accesses
|
||||
typedef typename Traits::Delta Delta;
|
||||
|
||||
/// Delta between each thread's access
|
||||
/// TODO MTA this is wrong for sure, but Delta is used for stride computation at the moment
|
||||
typedef Delta ImmediateOffsetStrides;
|
||||
|
||||
/// Number of accesses
|
||||
typedef typename Traits::Iterations Iterations;
|
||||
|
||||
/// Thread offset functor
|
||||
typedef typename Traits::ThreadOffset ThreadOffset;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
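Two pieces of the header above are easy to check on the host: the way `TiledThreadOffset` peels a linear thread index into (D, H, W, C) coordinates, and the rule `TileTraitsStandard` uses to pick an arrangement. The sketch below is an illustration only. `ShapeSketch`, `Coord4`, `tiled_thread_offset`, and `choose_arrangement` are hypothetical names, and the thread index is passed as an argument rather than read from `threadIdx.x`.

```cpp
#include <cassert>

// Hypothetical stand-in for cutlass::Shape: extents in (D, H, W, C) order.
template <int kD_, int kH_, int kW_, int kC_ = 1>
struct ShapeSketch {
  static int const kD = kD_;
  static int const kH = kH_;
  static int const kW = kW_;
  static int const kC = kC_;
};

// Plain 4-element coordinate, standing in for Coord<4>.
struct Coord4 {
  int idx[4];
};

// Same decomposition as TiledThreadOffset::operator(), but the linear thread
// index is an explicit argument so the function can run on the host.
template <typename ThreadShape>
Coord4 tiled_thread_offset(int index) {
  Coord4 offset;
  offset.idx[3] = index % ThreadShape::kC;  // fastest-varying dimension
  index /= ThreadShape::kC;
  offset.idx[2] = index % ThreadShape::kW;
  index /= ThreadShape::kW;
  offset.idx[1] = index % ThreadShape::kH;
  index /= ThreadShape::kH;
  offset.idx[0] = index;                    // whatever remains goes to D
  return offset;
}

// Arrangement that the nested platform::conditional selects for a given
// contiguous extent (Tile::kW) and warp size.
enum Arrangement { kStrideMajor, kWarpRake, kContiguousMajor };

inline Arrangement choose_arrangement(int tile_w, int warp_size = 32) {
  if (tile_w < warp_size) return kStrideMajor;
  if (tile_w % warp_size == 0) return kWarpRake;
  return kContiguousMajor;
}
```

For example, with 32 threads over a tile whose contiguous extent is 8, the stride-major thread shape is (1, 4, 8, 1), so thread 13 lands at H = 1, W = 5.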
|
||||
@ -1,131 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* \brief Math utilities
|
||||
*/
|
||||
|
||||
#include <cutlass/util/platform.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/******************************************************************************
|
||||
* Static math utilities
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Statically determine if N is a power-of-two
|
||||
*/
|
||||
template <int N>
|
||||
struct is_pow2 : platform::integral_constant<bool, (N & (N - 1)) == 0> {};
|
||||
|
||||
/**
|
||||
* Statically determine log2(N), rounded down
|
||||
*/
|
||||
template <int N, int CurrentVal = N, int Count = 0>
|
||||
struct log2_down {
|
||||
/// Static logarithm value
|
||||
enum { value = log2_down<N, (CurrentVal >> 1), Count + 1>::value };
|
||||
};
|
||||
|
||||
// Base case
|
||||
template <int N, int Count>
|
||||
struct log2_down<N, 1, Count> {
|
||||
enum { value = Count };
|
||||
};
|
||||
|
||||
/**
|
||||
* Statically determine log2(N), rounded up
|
||||
*/
|
||||
template <int N, int CurrentVal = N, int Count = 0>
|
||||
struct log2_up {
|
||||
/// Static logarithm value
|
||||
enum { value = log2_up<N, (CurrentVal >> 1), Count + 1>::value };
|
||||
};
|
||||
|
||||
// Base case
|
||||
template <int N, int Count>
|
||||
struct log2_up<N, 1, Count> {
|
||||
enum { value = ((1 << Count) < N) ? Count + 1 : Count };
|
||||
};
|
||||
|
||||
/**
|
||||
* Statically estimate sqrt(N) to the nearest power-of-two
|
||||
*/
|
||||
template <int N>
|
||||
struct sqrt_est {
|
||||
enum { value = 1 << (log2_up<N>::value / 2) };
|
||||
};
|
||||
|
||||
/**
|
||||
* For performing a constant-division with a compile-time assertion that the
|
||||
* Divisor evenly-divides the Dividend.
|
||||
*/
|
||||
template <int Dividend, int Divisor>
|
||||
struct divide_assert {
|
||||
enum { value = Dividend / Divisor };
|
||||
|
||||
static_assert((Dividend % Divisor == 0), "Not an even multiple");
|
||||
};
|
||||
|
||||
/******************************************************************************
|
||||
* Rounding
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Round dividend up to the nearest multiple of divisor
|
||||
*/
|
||||
template <typename dividend_t, typename divisor_t>
|
||||
CUTLASS_HOST_DEVICE dividend_t round_nearest(dividend_t dividend, divisor_t divisor) {
|
||||
return ((dividend + divisor - 1) / divisor) * divisor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Greatest common divisor
|
||||
*/
|
||||
template <typename value_t>
|
||||
CUTLASS_HOST_DEVICE value_t gcd(value_t a, value_t b) {
|
||||
for (;;) {
|
||||
if (a == 0) return b;
|
||||
b %= a;
|
||||
if (b == 0) return a;
|
||||
a %= b;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Least common multiple
|
||||
*/
|
||||
template <typename value_t>
|
||||
CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b) {
|
||||
value_t temp = gcd(a, b);
|
||||
|
||||
return temp ? (a / temp * b) : 0;
|
||||
}
|
||||
|
||||
} // namespace cutlass
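The math utilities above can be exercised without CUDA. The sketch below copies the recursion and the rounding helpers into standalone C++ so their results can be checked directly. The `_sketch` suffixes are added here to mark the copies as illustrative, and the `CUTLASS_HOST_DEVICE` qualifier is dropped on the assumption that only host execution is needed.

```cpp
#include <cassert>

// Compile-time floor(log2(N)): shift CurrentVal right until it reaches 1,
// counting the shifts.
template <int N, int CurrentVal = N, int Count = 0>
struct log2_down_sketch {
  enum { value = log2_down_sketch<N, (CurrentVal >> 1), Count + 1>::value };
};
template <int N, int Count>
struct log2_down_sketch<N, 1, Count> {
  enum { value = Count };
};

// Compile-time ceil(log2(N)): same recursion, then bump the count by one
// when 2^Count is still smaller than N.
template <int N, int CurrentVal = N, int Count = 0>
struct log2_up_sketch {
  enum { value = log2_up_sketch<N, (CurrentVal >> 1), Count + 1>::value };
};
template <int N, int Count>
struct log2_up_sketch<N, 1, Count> {
  enum { value = ((1 << Count) < N) ? Count + 1 : Count };
};

static_assert(log2_down_sketch<10>::value == 3, "floor(log2(10)) is 3");
static_assert(log2_up_sketch<10>::value == 4, "ceil(log2(10)) is 4");
static_assert(log2_up_sketch<16>::value == 4, "exact powers of two are unchanged");

// Despite its name, round_nearest rounds *up* to a multiple of divisor.
template <typename dividend_t, typename divisor_t>
dividend_t round_nearest_sketch(dividend_t dividend, divisor_t divisor) {
  return ((dividend + divisor - 1) / divisor) * divisor;
}

// Euclid's algorithm, alternating which operand is reduced.
template <typename value_t>
value_t gcd_sketch(value_t a, value_t b) {
  for (;;) {
    if (a == 0) return b;
    b %= a;
    if (b == 0) return a;
    a %= b;
  }
}

// lcm via gcd; dividing before multiplying keeps the intermediate small.
template <typename value_t>
value_t lcm_sketch(value_t a, value_t b) {
  value_t temp = gcd_sketch(a, b);
  return temp ? (a / temp * b) : 0;
}
```

A typical use of the rounding helper is padding a problem dimension to a tile boundary, for instance rounding 10 up to a multiple of 4 gives 12.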
|
||||
@ -1,122 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* \brief Debugging and logging functionality
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/******************************************************************************
|
||||
* Debug and logging macros
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Formats and prints the given message to stdout
|
||||
*/
|
||||
#if !defined(CUDA_LOG)
|
||||
#if !defined(__CUDA_ARCH__)
|
||||
#define CUDA_LOG(format, ...) printf(format, __VA_ARGS__)
|
||||
#else
|
||||
#define CUDA_LOG(format, ...) \
|
||||
printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \
|
||||
blockIdx.x, \
|
||||
blockIdx.y, \
|
||||
blockIdx.z, \
|
||||
threadIdx.x, \
|
||||
threadIdx.y, \
|
||||
threadIdx.z, \
|
||||
__VA_ARGS__);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Formats and prints the given message to stdout only if DEBUG is defined
|
||||
*/
|
||||
#if !defined(CUDA_LOG_DEBUG)
|
||||
#ifdef DEBUG
|
||||
#define CUDA_LOG_DEBUG(format, ...) CUDA_LOG(format, __VA_ARGS__)
|
||||
#else
|
||||
#define CUDA_LOG_DEBUG(format, ...)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \brief The corresponding error message is printed to \p stderr (or \p stdout in device code)
|
||||
* along with the supplied source context.
|
||||
*
|
||||
* \return The CUDA error.
|
||||
*/
|
||||
__host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl(cudaError_t error,
|
||||
const char* filename,
|
||||
int line) {
|
||||
(void)filename;
|
||||
(void)line;
|
||||
if (error) {
|
||||
#if !defined(__CUDA_ARCH__)
|
||||
fprintf(
|
||||
stderr, "CUDA error %d [%s, %d]: %s\n", error, filename, line, cudaGetErrorString(error));
|
||||
fflush(stderr);
|
||||
#else
|
||||
printf("CUDA error %d [%s, %d]\n", error, filename, line);
|
||||
#endif
|
||||
}
|
||||
return error;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Perror macro
|
||||
*/
|
||||
#ifndef CUDA_PERROR
|
||||
#define CUDA_PERROR(e) cuda_perror_impl((cudaError_t)(e), __FILE__, __LINE__)
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \brief Perror macro with exit
|
||||
*/
|
||||
#ifndef CUDA_PERROR_EXIT
|
||||
#define CUDA_PERROR_EXIT(e) \
|
||||
if (cuda_perror_impl((cudaError_t)(e), __FILE__, __LINE__)) { \
|
||||
exit(1); \
|
||||
}
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \brief Perror macro only if DEBUG is defined
|
||||
*/
|
||||
#ifndef CUDA_PERROR_DEBUG
|
||||
#ifdef DEBUG
|
||||
#define CUDA_PERROR_DEBUG(e) CUDA_PERROR(e)
|
||||
#else
|
||||
#define CUDA_PERROR_DEBUG(e) (e)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace cutlass
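The `CUDA_PERROR` family above follows a common pattern: a function reports the error together with a source location and returns the code unchanged, and a macro supplies `__FILE__` and `__LINE__` at the call site. The sketch below shows that pattern on the host only. It deliberately avoids the CUDA runtime, so `sketch_error_t`, `sketch_perror_impl`, and `SKETCH_PERROR` are hypothetical names and the error code is a plain `int` rather than `cudaError_t`.

```cpp
#include <cassert>
#include <cstdio>

// Hypothetical error code type; the real helper takes a cudaError_t.
typedef int sketch_error_t;

// Prints the error with its source location and passes the code through,
// mirroring the host branch of cuda_perror_impl() minus cudaGetErrorString().
inline sketch_error_t sketch_perror_impl(sketch_error_t error,
                                         const char* filename,
                                         int line) {
  if (error) {
    std::fprintf(stderr, "error %d [%s, %d]\n", error, filename, line);
    std::fflush(stderr);
  }
  return error;
}

// Captures the call site, as CUDA_PERROR does.
#define SKETCH_PERROR(e) sketch_perror_impl((sketch_error_t)(e), __FILE__, __LINE__)
```

Because the macro evaluates to the error code, it can wrap a call inline and still be tested in a condition, which is how `CUDA_PERROR_EXIT` is built on top of it.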
|
||||
229
cutlass/vector.h
@ -1,229 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines a 1D vector of elements held in the registers of each thread.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
|
||||
#include <cuda_fp16.h>
|
||||
#endif
|
||||
|
||||
#include <cutlass/util/platform.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <size_t kAlignment_>
|
||||
struct AlignedStruct {};
|
||||
|
||||
template <>
|
||||
struct __align__(1) AlignedStruct<1>{};
|
||||
template <>
|
||||
struct __align__(2) AlignedStruct<2>{};
|
||||
template <>
|
||||
struct __align__(4) AlignedStruct<4>{};
|
||||
template <>
|
||||
struct __align__(8) AlignedStruct<8>{};
|
||||
template <>
|
||||
struct __align__(16) AlignedStruct<16>{};
|
||||
template <>
|
||||
struct __align__(32) AlignedStruct<32>{};
|
||||
template <>
|
||||
struct __align__(64) AlignedStruct<64>{};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int kLanes_>
|
||||
union Vector {
|
||||
/// The scalar type.
|
||||
typedef Scalar_ Scalar;
|
||||
|
||||
/// The number of elements in the vector.
|
||||
enum { kLanes = kLanes_ };
|
||||
/// The size of the vector.
|
||||
enum { kVectorSize = kLanes * (int)sizeof(Scalar) };
|
||||
/// The number of registers needed to store the vector.
|
||||
enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 };
|
||||
|
||||
// Make sure that the vector type makes sense.
|
||||
static_assert(kVectorSize <= 16, "Vector type is too large");
|
||||
|
||||
/// The aligned storage to make sure we have good alignment.
|
||||
AlignedStruct<kVectorSize> aligned_;
|
||||
/// The associated array of scalars.
|
||||
Scalar scalars[kLanes];
|
||||
/// The data in registers.
|
||||
uint32_t registers[kRegisters];
|
||||
|
||||
/// Accessor to the ith lane.
|
||||
CUTLASS_DEVICE Scalar const& operator[](uint32_t i) const { return scalars[i]; }
|
||||
/// Accessor to the ith lane.
|
||||
CUTLASS_DEVICE Scalar& operator[](uint32_t i) { return scalars[i]; }
|
||||
};
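The size bookkeeping in `Vector` above is simple arithmetic: the byte size is the lane count times the scalar size, and the register count is that size divided by 4, with a floor of one register for vectors narrower than 32 bits. The sketch below isolates just that arithmetic so it can be checked on the host. `VectorFootprint` is a hypothetical name and the struct stores no data.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper that reproduces only the size computations of Vector.
template <typename Scalar, int kLanes>
struct VectorFootprint {
  // Total size of the vector in bytes.
  static int const kVectorSize = kLanes * static_cast<int>(sizeof(Scalar));
  // Number of 32-bit registers needed; sub-word vectors still occupy one.
  static int const kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4;

  // Same limit as the original: at most 16 bytes per vector.
  static_assert(kVectorSize <= 16, "Vector type is too large");
};

static_assert(VectorFootprint<float, 4>::kRegisters == 4, "4 floats fill 4 registers");
static_assert(VectorFootprint<std::uint16_t, 8>::kRegisters == 4, "8 halves fill 4 registers");
static_assert(VectorFootprint<std::uint8_t, 2>::kRegisters == 1, "2 bytes still take 1 register");
```

The 16-byte ceiling matches the widest single memory access available to a thread, which is presumably why the original asserts on it.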
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
|
||||
|
||||
template <int kLanes_>
|
||||
union Vector<half, kLanes_> {
|
||||
/// The scalar type.
|
||||
typedef half Scalar;
|
||||
|
||||
/// The number of elements in the vector.
|
||||
enum { kLanes = kLanes_ };
|
||||
/// The size of the vector.
|
||||
enum { kVectorSize = kLanes * (int)sizeof(Scalar) };
|
||||
/// The number of registers needed to store the vector.
|
||||
enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 };
|
||||
|
||||
// Make sure that the vector type makes sense.
|
||||
static_assert(kVectorSize <= size_t(16), "Vector type is too large");
|
||||
|
||||
/// The aligned storage to make sure we have good alignment.
|
||||
AlignedStruct<kVectorSize> aligned_;
|
||||
/// The associated array of scalars.
|
||||
uint16_t scalars[kLanes];
|
||||
/// The data in registers.
|
||||
uint32_t registers[kRegisters];
|
||||
|
||||
/// Accessor to the ith lane.
|
||||
CUTLASS_DEVICE Scalar const& operator[](uint32_t i) const {
|
||||
return reinterpret_cast<Scalar const&>(scalars[i]);
|
||||
}
|
||||
/// Accessor to the ith lane.
|
||||
CUTLASS_DEVICE Scalar& operator[](uint32_t i) { return reinterpret_cast<Scalar&>(scalars[i]); }
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_>
|
||||
CUTLASS_DEVICE void make_zero(Scalar_& x) {
|
||||
x = Scalar_(0);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Element_, int kLanes_ = 1>
|
||||
struct Vectorize {
|
||||
typedef Vector<Element_, kLanes_> Type;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Element_>
|
||||
struct Vectorize<Element_, 1> {
|
||||
typedef Element_ Type;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int kLanes_>
|
||||
CUTLASS_DEVICE void make_zero(Vector<Scalar_, kLanes_>& vec) {
|
||||
for (int i = 0; i < Vector<Scalar_, kLanes_>::kRegisters; ++i) {
|
||||
vec.registers[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// cutlass::Extent similar to std::extent but applicable to CUTLASS types
|
||||
//
|
||||
|
||||
/// Returns the extent of a scalar or vector
|
||||
template <typename T>
|
||||
struct Extent {
|
||||
static size_t const kValue = 1;
|
||||
};
|
||||
|
||||
/// Returns the number of lanes of a vector if need be
|
||||
template <typename T, int Lanes>
|
||||
struct Extent<Vector<T, Lanes> > {
|
||||
static size_t const kValue = Lanes;
|
||||
};
|
||||
|
||||
/// Returns the number of lanes of a vector if need be
|
||||
template <typename T, int Lanes>
|
||||
struct Extent<Vector<T, Lanes> const> {
|
||||
static size_t const kValue = Lanes;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Traits describing properties of vectors and scalar-as-vectors
|
||||
template <typename T>
|
||||
struct VectorTraits {
|
||||
/// Scalar type
|
||||
typedef T Scalar;
|
||||
|
||||
/// Number of lanes of vector
|
||||
static int const kLanes = 1;
|
||||
|
||||
/// True if the type is actually a cutlass::Vector, otherwise false
|
||||
static bool const IsVector = false;
|
||||
|
||||
/// Type that is always a vector
|
||||
typedef Vector<T, 1> Vector;
|
||||
};
|
||||
|
||||
/// Partial specialization for actual cutlass::Vector
|
||||
template <typename T, int Lanes>
|
||||
struct VectorTraits<Vector<T, Lanes> > {
|
||||
/// Scalar type
|
||||
typedef T Scalar;
|
||||
|
||||
/// Number of lanes of vector
|
||||
static int const kLanes = Lanes;
|
||||
|
||||
/// Type is actually a cutlass::Vector
|
||||
static bool const IsVector = true;
|
||||
|
||||
/// Type that is always a Vector
|
||||
typedef Vector<T, Lanes> Vector;
|
||||
};
|
||||
|
||||
/// Partial specialization for actual cutlass::Vector
|
||||
template <typename T, int Lanes>
|
||||
struct VectorTraits<Vector<T, Lanes> const> {
|
||||
/// Scalar type
|
||||
typedef T Scalar;
|
||||
|
||||
/// Number of lanes of vector
|
||||
static int const kLanes = Lanes;
|
||||
|
||||
/// Type is actually a cutlass::Vector
|
||||
static bool const IsVector = true;
|
||||
|
||||
/// Type that is always a Vector
|
||||
typedef Vector<T, Lanes> Vector;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
@ -1,193 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Abstractions for loading and storing matrices using the CUDA WMMA API.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700)
|
||||
|
||||
// Dependent header files should use the following macro to guard all code using
|
||||
// nvcuda::wmma:: to enable compilation for CUDA Compute Capabilities < sm_70.
|
||||
// Earlier shader models do not support Tensor Cores.
|
||||
#define CUTLASS_USE_WMMA_API
|
||||
|
||||
#include "stdio.h"
|
||||
|
||||
#include <crt/mma.h>
|
||||
#include <cutlass/fragment.h>
|
||||
#include <cutlass/load_store.h>
|
||||
#include <cutlass/matrix_traits.h>
|
||||
#include <cutlass/shape.h>
|
||||
#include <cutlass/vector.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Statically maps cutlass::MatrixLayout => nvcuda::wmma layout tags
|
||||
template <MatrixLayout::Kind kLayout_>
|
||||
struct WmmaLayout {
|
||||
typedef nvcuda::wmma::col_major Layout;
|
||||
};
|
||||
|
||||
/// Statically maps cutlass::MatrixLayout => nvcuda::wmma layout tags
|
||||
template <>
|
||||
struct WmmaLayout<MatrixLayout::kRowMajor> {
|
||||
typedef nvcuda::wmma::row_major Layout;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Adapter to nvcuda::wmma fragment load and store operations
|
||||
template <GemmOperand::Kind kOperand_,
|
||||
MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename WmmaShape_>
|
||||
struct WmmaMatrix {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Adapter to nvcuda::wmma fragment accessors for A operand
|
||||
template <MatrixLayout::Kind kLayout_, typename Scalar_, typename WmmaShape_>
|
||||
struct WmmaMatrix<GemmOperand::kA, kLayout_, Scalar_, WmmaShape_>
|
||||
: public nvcuda::wmma::fragment<
|
||||
/// The nvcuda::wmma operand name.
|
||||
nvcuda::wmma::matrix_a,
|
||||
/// The dimensions.
|
||||
WmmaShape_::kW,
|
||||
WmmaShape_::kH,
|
||||
WmmaShape_::kD,
|
||||
/// The scalar.
|
||||
Scalar_,
|
||||
/// The layout.
|
||||
typename WmmaLayout<kLayout_>::Layout> {
|
||||
/// This type.
|
||||
typedef WmmaMatrix<GemmOperand::kA, kLayout_, Scalar_, WmmaShape_> This_;
|
||||
|
||||
/// Fill-in the element.
|
||||
CUTLASS_DEVICE This_& operator=(Scalar_ const& x) {
|
||||
nvcuda::wmma::fill_fragment(*this, x);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Load from memory.
|
||||
CUTLASS_DEVICE void load(Scalar_ const* pointer, int const stride) {
|
||||
nvcuda::wmma::load_matrix_sync(*this, pointer, stride);
|
||||
}
|
||||
|
||||
/// Store to memory.
|
||||
CUTLASS_DEVICE void store(Scalar_* pointer, int const stride) const {
|
||||
nvcuda::wmma::store_matrix_sync(pointer, *this, stride);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Adapter to nvcuda::wmma fragment accessors for B operand
|
||||
template <MatrixLayout::Kind kLayout_, typename Scalar_, typename WmmaShape_>
|
||||
struct WmmaMatrix<GemmOperand::kB, kLayout_, Scalar_, WmmaShape_>
|
||||
: public nvcuda::wmma::fragment<
|
||||
/// The nvcuda::wmma operand name.
|
||||
nvcuda::wmma::matrix_b,
|
||||
/// The dimensions.
|
||||
WmmaShape_::kW,
|
||||
WmmaShape_::kH,
|
||||
WmmaShape_::kD,
|
||||
/// The scalar.
|
||||
Scalar_,
|
||||
/// The layout.
|
||||
typename WmmaLayout<kLayout_>::Layout> {
|
||||
/// This type.
|
||||
typedef WmmaMatrix<GemmOperand::kB, kLayout_, Scalar_, WmmaShape_> This_;
|
||||
|
||||
/// Fill-in the element.
|
||||
CUTLASS_DEVICE This_& operator=(Scalar_ const& x) {
|
||||
nvcuda::wmma::fill_fragment(*this, x);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Load from memory.
|
||||
CUTLASS_DEVICE void load(Scalar_ const* pointer, int const stride) {
|
||||
nvcuda::wmma::load_matrix_sync(*this, pointer, stride);
|
||||
}
|
||||
|
||||
/// Store to memory.
|
||||
CUTLASS_DEVICE void store(Scalar_* pointer, int const stride) const {
|
||||
nvcuda::wmma::store_matrix_sync(pointer, *this, stride);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Adapter to nvcuda::wmma fragment accessors for C operand
|
||||
template <MatrixLayout::Kind kLayout_, typename Scalar_, typename WmmaShape_>
|
||||
struct WmmaMatrix<GemmOperand::kC, kLayout_, Scalar_, WmmaShape_>
|
||||
: public nvcuda::wmma::fragment<
|
||||
/// The nvcuda::wmma operand name.
|
||||
nvcuda::wmma::accumulator,
|
||||
/// The dimensions.
|
||||
WmmaShape_::kW,
|
||||
WmmaShape_::kH,
|
||||
WmmaShape_::kD,
|
||||
/// The scalar.
|
||||
Scalar_> {
|
||||
/// This type.
|
||||
typedef WmmaMatrix<GemmOperand::kC, kLayout_, Scalar_, WmmaShape_> This_;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = kLayout_;
|
||||
|
||||
/// Fill-in the element.
|
||||
CUTLASS_DEVICE This_& operator=(Scalar_ const& x) {
|
||||
nvcuda::wmma::fill_fragment(*this, x);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Load from memory.
|
||||
CUTLASS_DEVICE void load(Scalar_ const* pointer, int const stride) {
|
||||
bool const kIsRowMajor = kLayout == MatrixLayout::kRowMajor;
|
||||
nvcuda::wmma::load_matrix_sync(
|
||||
*this,
|
||||
pointer,
|
||||
stride,
|
||||
kIsRowMajor ? nvcuda::wmma::mem_row_major : nvcuda::wmma::mem_col_major);
|
||||
}
|
||||
|
||||
/// Store to memory.
|
||||
CUTLASS_DEVICE void store(Scalar_* pointer, int const stride) const {
|
||||
bool const kIsRowMajor = kLayout == MatrixLayout::kRowMajor;
|
||||
nvcuda::wmma::store_matrix_sync(
|
||||
pointer,
|
||||
*this,
|
||||
stride,
|
||||
kIsRowMajor ? nvcuda::wmma::mem_row_major : nvcuda::wmma::mem_col_major);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
#endif // defined CUTLASS_USE_WMMA_API
|
||||
@ -1 +0,0 @@
|
||||
theme: jekyll-theme-minimal
|
||||
145
docs/aligned__buffer_8h.html
Normal file
@ -0,0 +1,145 @@
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head>
|
||||
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
|
||||
<meta name="generator" content="Doxygen 1.8.11"/>
|
||||
<title>CUTLASS: aligned_buffer.h File Reference</title>
|
||||
<link href="tabs.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="jquery.js"></script>
|
||||
<script type="text/javascript" src="dynsections.js"></script>
|
||||
<link href="search/search.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="search/searchdata.js"></script>
|
||||
<script type="text/javascript" src="search/search.js"></script>
|
||||
<script type="text/javascript">
|
||||
$(document).ready(function() { init_search(); });
|
||||
</script>
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
extensions: ["tex2jax.js"],
|
||||
jax: ["input/TeX","output/HTML-CSS"],
|
||||
});
|
||||
</script><script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
|
||||
<link href="doxygen.css" rel="stylesheet" type="text/css" />
|
||||
</head>
|
||||
<body>
|
||||
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
|
||||
<div id="titlearea">
|
||||
<table cellspacing="0" cellpadding="0">
|
||||
<tbody>
|
||||
<tr style="height: 56px;">
|
||||
<td id="projectlogo"><img alt="Logo" src="cutlass-logo-small.png"/></td>
|
||||
<td id="projectalign" style="padding-left: 0.5em;">
|
||||
<div id="projectname">CUTLASS
|
||||
</div>
|
||||
<div id="projectbrief">CUDA Templates for Linear Algebra Subroutines and Solvers</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<!-- end header part -->
|
||||
<!-- Generated by Doxygen 1.8.11 -->
|
||||
<script type="text/javascript">
|
||||
var searchBox = new SearchBox("searchBox", "search",false,'Search');
|
||||
</script>
|
||||
<div id="navrow1" class="tabs">
|
||||
<ul class="tablist">
|
||||
<li><a href="index.html"><span>Main Page</span></a></li>
|
||||
<li><a href="modules.html"><span>Modules</span></a></li>
|
||||
<li><a href="namespaces.html"><span>Namespaces</span></a></li>
|
||||
<li><a href="annotated.html"><span>Classes</span></a></li>
|
||||
<li class="current"><a href="files.html"><span>Files</span></a></li>
|
||||
<li>
|
||||
<div id="MSearchBox" class="MSearchBoxInactive">
|
||||
<span class="left">
|
||||
<img id="MSearchSelect" src="search/mag_sel.png"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
alt=""/>
|
||||
<input type="text" id="MSearchField" value="Search" accesskey="S"
|
||||
onfocus="searchBox.OnSearchFieldFocus(true)"
|
||||
onblur="searchBox.OnSearchFieldFocus(false)"
|
||||
onkeyup="searchBox.OnSearchFieldChange(event)"/>
|
||||
</span><span class="right">
|
||||
<a id="MSearchClose" href="javascript:searchBox.CloseResultsWindow()"><img id="MSearchCloseImg" border="0" src="search/close.png" alt=""/></a>
|
||||
</span>
|
||||
</div>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div id="navrow2" class="tabs2">
|
||||
<ul class="tablist">
|
||||
<li><a href="files.html"><span>File List</span></a></li>
|
||||
<li><a href="globals.html"><span>File Members</span></a></li>
|
||||
</ul>
|
||||
</div>
|
||||
<!-- window showing the filter options -->
|
||||
<div id="MSearchSelectWindow"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
onkeydown="return searchBox.OnSearchSelectKey(event)">
|
||||
</div>
|
||||
|
||||
<!-- iframe showing the search results (closed by default) -->
|
||||
<div id="MSearchResultsWindow">
|
||||
<iframe src="javascript:void(0)" frameborder="0"
|
||||
name="MSearchResults" id="MSearchResults">
|
||||
</iframe>
|
||||
</div>
|
||||
|
||||
<div id="nav-path" class="navpath">
|
||||
<ul>
|
||||
<li class="navelem"><a class="el" href="dir_d44c64559bbebec7f509842c48db8b23.html">include</a></li><li class="navelem"><a class="el" href="dir_6baf2bb612a2f0daa69af3101ede80a1.html">cutlass</a></li> </ul>
|
||||
</div>
|
||||
</div><!-- top -->
|
||||
<div class="header">
|
||||
<div class="summary">
|
||||
<a href="#nested-classes">Classes</a> |
|
||||
<a href="#namespaces">Namespaces</a> </div>
|
||||
<div class="headertitle">
|
||||
<div class="title">aligned_buffer.h File Reference</div> </div>
|
||||
</div><!--header-->
|
||||
<div class="contents">
|
||||
|
||||
<p>AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared memory.
|
||||
<a href="#details">More...</a></p>
|
||||
<div class="textblock"><code>#include "<a class="el" href="cutlass_8h_source.html">cutlass/cutlass.h</a>"</code><br />
|
||||
<code>#include "<a class="el" href="array_8h_source.html">cutlass/array.h</a>"</code><br />
|
||||
</div><div class="textblock"><div class="dynheader">
|
||||
Include dependency graph for aligned_buffer.h:</div>
|
||||
<div class="dyncontent">
|
||||
<div class="center"><img src="aligned__buffer_8h__incl.png" border="0" usemap="#aligned__buffer_8h" alt=""/></div>
|
||||
<map name="aligned__buffer_8h" id="aligned__buffer_8h">
|
||||
</map>
|
||||
</div>
|
||||
</div><div class="textblock"><div class="dynheader">
|
||||
This graph shows which files directly or indirectly include this file:</div>
|
||||
<div class="dyncontent">
|
||||
<div class="center"><img src="aligned__buffer_8h__dep__incl.png" border="0" usemap="#aligned__buffer_8hdep" alt=""/></div>
|
||||
<map name="aligned__buffer_8hdep" id="aligned__buffer_8hdep">
|
||||
</map>
|
||||
</div>
|
||||
</div>
|
||||
<p><a href="aligned__buffer_8h_source.html">Go to the source code of this file.</a></p>
|
||||
<table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="nested-classes"></a>
|
||||
Classes</h2></td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1AlignedBuffer.html">cutlass::AlignedBuffer< T, N, Align ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Modifies semantics of cutlass::Array<> to provide guaranteed alignment. <a href="structcutlass_1_1AlignedBuffer.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table><table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="namespaces"></a>
|
||||
Namespaces</h2></td></tr>
|
||||
<tr class="memitem:namespacecutlass"><td class="memItemLeft" align="right" valign="top">  </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass.html">cutlass</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table>
|
||||
</div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
<hr class="footer"/><address class="footer"><small>
|
||||
Generated by  <a href="http://www.doxygen.org/index.html">
|
||||
<img class="footer" src="doxygen.png" alt="doxygen"/>
|
||||
</a> 1.8.11
|
||||
</small></address>
|
||||
</body>
|
||||
</html>
|
||||
1
docs/aligned__buffer_8h__dep__incl.md5
Normal file
@ -0,0 +1 @@
|
||||
6cbc6b81ede44b5f08afd4f4519d56d1
|
||||
1
docs/aligned__buffer_8h__incl.md5
Normal file
@ -0,0 +1 @@
|
||||
b26c62930ff7668b89f2ee6624e0be3a
|
||||
135
docs/aligned__buffer_8h_source.html
Normal file
File diff suppressed because one or more lines are too long
1113
docs/annotated.html
File diff suppressed because it is too large
156
docs/arch_2mma_8h.html
Normal file
@ -0,0 +1,156 @@
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head>
|
||||
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
|
||||
<meta name="generator" content="Doxygen 1.8.11"/>
|
||||
<title>CUTLASS: mma.h File Reference</title>
|
||||
<link href="tabs.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="jquery.js"></script>
|
||||
<script type="text/javascript" src="dynsections.js"></script>
|
||||
<link href="search/search.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="search/searchdata.js"></script>
|
||||
<script type="text/javascript" src="search/search.js"></script>
|
||||
<script type="text/javascript">
|
||||
$(document).ready(function() { init_search(); });
|
||||
</script>
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
extensions: ["tex2jax.js"],
|
||||
jax: ["input/TeX","output/HTML-CSS"],
|
||||
});
|
||||
</script><script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
|
||||
<link href="doxygen.css" rel="stylesheet" type="text/css" />
|
||||
</head>
|
||||
<body>
|
||||
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
|
||||
<div id="titlearea">
|
||||
<table cellspacing="0" cellpadding="0">
|
||||
<tbody>
|
||||
<tr style="height: 56px;">
|
||||
<td id="projectlogo"><img alt="Logo" src="cutlass-logo-small.png"/></td>
|
||||
<td id="projectalign" style="padding-left: 0.5em;">
|
||||
<div id="projectname">CUTLASS
|
||||
</div>
|
||||
<div id="projectbrief">CUDA Templates for Linear Algebra Subroutines and Solvers</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<!-- end header part -->
|
||||
<!-- Generated by Doxygen 1.8.11 -->
|
||||
<script type="text/javascript">
|
||||
var searchBox = new SearchBox("searchBox", "search",false,'Search');
|
||||
</script>
|
||||
<div id="navrow1" class="tabs">
|
||||
<ul class="tablist">
|
||||
<li><a href="index.html"><span>Main Page</span></a></li>
|
||||
<li><a href="modules.html"><span>Modules</span></a></li>
|
||||
<li><a href="namespaces.html"><span>Namespaces</span></a></li>
|
||||
<li><a href="annotated.html"><span>Classes</span></a></li>
|
||||
<li class="current"><a href="files.html"><span>Files</span></a></li>
|
||||
<li>
|
||||
<div id="MSearchBox" class="MSearchBoxInactive">
|
||||
<span class="left">
|
||||
<img id="MSearchSelect" src="search/mag_sel.png"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
alt=""/>
|
||||
<input type="text" id="MSearchField" value="Search" accesskey="S"
|
||||
onfocus="searchBox.OnSearchFieldFocus(true)"
|
||||
onblur="searchBox.OnSearchFieldFocus(false)"
|
||||
onkeyup="searchBox.OnSearchFieldChange(event)"/>
|
||||
</span><span class="right">
|
||||
<a id="MSearchClose" href="javascript:searchBox.CloseResultsWindow()"><img id="MSearchCloseImg" border="0" src="search/close.png" alt=""/></a>
|
||||
</span>
|
||||
</div>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div id="navrow2" class="tabs2">
|
||||
<ul class="tablist">
|
||||
<li><a href="files.html"><span>File List</span></a></li>
|
||||
<li><a href="globals.html"><span>File Members</span></a></li>
|
||||
</ul>
|
||||
</div>
|
||||
<!-- window showing the filter options -->
|
||||
<div id="MSearchSelectWindow"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
onkeydown="return searchBox.OnSearchSelectKey(event)">
|
||||
</div>
|
||||
|
||||
<!-- iframe showing the search results (closed by default) -->
|
||||
<div id="MSearchResultsWindow">
|
||||
<iframe src="javascript:void(0)" frameborder="0"
|
||||
name="MSearchResults" id="MSearchResults">
|
||||
</iframe>
|
||||
</div>
|
||||
|
||||
<div id="nav-path" class="navpath">
|
||||
<ul>
|
||||
<li class="navelem"><a class="el" href="dir_d44c64559bbebec7f509842c48db8b23.html">include</a></li><li class="navelem"><a class="el" href="dir_6baf2bb612a2f0daa69af3101ede80a1.html">cutlass</a></li><li class="navelem"><a class="el" href="dir_048c1df36ab9c2efbb0733edba6291c9.html">arch</a></li> </ul>
|
||||
</div>
|
||||
</div><!-- top -->
|
||||
<div class="header">
|
||||
<div class="summary">
|
||||
<a href="#nested-classes">Classes</a> |
|
||||
<a href="#namespaces">Namespaces</a> </div>
|
||||
<div class="headertitle">
|
||||
<div class="title">arch/mma.h File Reference</div> </div>
|
||||
</div><!--header-->
|
||||
<div class="contents">
|
||||
|
||||
<p>Templates exposing architecture support for multiply-add operations.
|
||||
<a href="#details">More...</a></p>
|
||||
<div class="textblock"><code>#include "<a class="el" href="array_8h_source.html">cutlass/array.h</a>"</code><br />
|
||||
<code>#include "<a class="el" href="numeric__types_8h_source.html">cutlass/numeric_types.h</a>"</code><br />
|
||||
<code>#include "<a class="el" href="include_2cutlass_2gemm_2gemm_8h_source.html">cutlass/gemm/gemm.h</a>"</code><br />
|
||||
<code>#include "<a class="el" href="arch_2mma__sm50_8h_source.html">cutlass/arch/mma_sm50.h</a>"</code><br />
|
||||
<code>#include "<a class="el" href="arch_2mma__sm60_8h_source.html">cutlass/arch/mma_sm60.h</a>"</code><br />
|
||||
<code>#include "<a class="el" href="arch_2mma__sm61_8h_source.html">cutlass/arch/mma_sm61.h</a>"</code><br />
|
||||
<code>#include "<a class="el" href="mma__sm70_8h_source.html">cutlass/arch/mma_sm70.h</a>"</code><br />
|
||||
<code>#include "<a class="el" href="mma__sm75_8h_source.html">cutlass/arch/mma_sm75.h</a>"</code><br />
|
||||
</div><div class="textblock"><div class="dynheader">
|
||||
Include dependency graph for arch/mma.h:</div>
|
||||
<div class="dyncontent">
|
||||
<div class="center"><img src="arch_2mma_8h__incl.png" border="0" usemap="#mma_8h" alt=""/></div>
|
||||
<map name="mma_8h" id="mma_8h">
|
||||
</map>
|
||||
</div>
|
||||
</div><div class="textblock"><div class="dynheader">
|
||||
This graph shows which files directly or indirectly include this file:</div>
|
||||
<div class="dyncontent">
|
||||
<div class="center"><img src="arch_2mma_8h__dep__incl.png" border="0" usemap="#mma_8hdep" alt=""/></div>
|
||||
<map name="mma_8hdep" id="mma_8hdep">
|
||||
</map>
|
||||
</div>
|
||||
</div>
|
||||
<p><a href="arch_2mma_8h_source.html">Go to the source code of this file.</a></p>
|
||||
<table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="nested-classes"></a>
|
||||
Classes</h2></td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma.html">cutlass::arch::Mma< Shape_, kThreads_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01ElementAb6e65b2cf5ede7f41cb070a767158dee.html">cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation - specialized for 1x1x1x1 matrix multiply operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01ElementAb6e65b2cf5ede7f41cb070a767158dee.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table><table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="namespaces"></a>
|
||||
Namespaces</h2></td></tr>
|
||||
<tr class="memitem:namespacecutlass"><td class="memItemLeft" align="right" valign="top">  </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass.html">cutlass</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:namespacecutlass_1_1arch"><td class="memItemLeft" align="right" valign="top">  </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass_1_1arch.html">cutlass::arch</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table>
|
||||
</div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
<hr class="footer"/><address class="footer"><small>
|
||||
Generated by  <a href="http://www.doxygen.org/index.html">
|
||||
<img class="footer" src="doxygen.png" alt="doxygen"/>
|
||||
</a> 1.8.11
|
||||
</small></address>
|
||||
</body>
|
||||
</html>
|
||||
1
docs/arch_2mma_8h__dep__incl.md5
Normal file
@ -0,0 +1 @@
|
||||
7d16b59e6ba0442b8a275a213d5da3a6
|
||||
1
docs/arch_2mma_8h__incl.md5
Normal file
@ -0,0 +1 @@
|
||||
d1fff3f9d55a262110aa6a456caa91e0
|
||||
122
docs/arch_2mma_8h_source.html
Normal file
File diff suppressed because one or more lines are too long
176
docs/arch_2mma__sm50_8h.html
Normal file
@ -0,0 +1,176 @@
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head>
|
||||
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
|
||||
<meta name="generator" content="Doxygen 1.8.11"/>
|
||||
<title>CUTLASS: mma_sm50.h File Reference</title>
|
||||
<link href="tabs.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="jquery.js"></script>
|
||||
<script type="text/javascript" src="dynsections.js"></script>
|
||||
<link href="search/search.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="search/searchdata.js"></script>
|
||||
<script type="text/javascript" src="search/search.js"></script>
|
||||
<script type="text/javascript">
|
||||
$(document).ready(function() { init_search(); });
|
||||
</script>
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
extensions: ["tex2jax.js"],
|
||||
jax: ["input/TeX","output/HTML-CSS"],
|
||||
});
|
||||
</script><script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
|
||||
<link href="doxygen.css" rel="stylesheet" type="text/css" />
|
||||
</head>
|
||||
<body>
|
||||
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
|
||||
<div id="titlearea">
|
||||
<table cellspacing="0" cellpadding="0">
|
||||
<tbody>
|
||||
<tr style="height: 56px;">
|
||||
<td id="projectlogo"><img alt="Logo" src="cutlass-logo-small.png"/></td>
|
||||
<td id="projectalign" style="padding-left: 0.5em;">
|
||||
<div id="projectname">CUTLASS
|
||||
</div>
|
||||
<div id="projectbrief">CUDA Templates for Linear Algebra Subroutines and Solvers</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<!-- end header part -->
|
||||
<!-- Generated by Doxygen 1.8.11 -->
|
||||
<script type="text/javascript">
|
||||
var searchBox = new SearchBox("searchBox", "search",false,'Search');
|
||||
</script>
|
||||
<div id="navrow1" class="tabs">
|
||||
<ul class="tablist">
|
||||
<li><a href="index.html"><span>Main Page</span></a></li>
|
||||
<li><a href="modules.html"><span>Modules</span></a></li>
|
||||
<li><a href="namespaces.html"><span>Namespaces</span></a></li>
|
||||
<li><a href="annotated.html"><span>Classes</span></a></li>
|
||||
<li class="current"><a href="files.html"><span>Files</span></a></li>
|
||||
<li>
|
||||
<div id="MSearchBox" class="MSearchBoxInactive">
|
||||
<span class="left">
|
||||
<img id="MSearchSelect" src="search/mag_sel.png"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
alt=""/>
|
||||
<input type="text" id="MSearchField" value="Search" accesskey="S"
|
||||
onfocus="searchBox.OnSearchFieldFocus(true)"
|
||||
onblur="searchBox.OnSearchFieldFocus(false)"
|
||||
onkeyup="searchBox.OnSearchFieldChange(event)"/>
|
||||
</span><span class="right">
|
||||
<a id="MSearchClose" href="javascript:searchBox.CloseResultsWindow()"><img id="MSearchCloseImg" border="0" src="search/close.png" alt=""/></a>
|
||||
</span>
|
||||
</div>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div id="navrow2" class="tabs2">
|
||||
<ul class="tablist">
|
||||
<li><a href="files.html"><span>File List</span></a></li>
|
||||
<li><a href="globals.html"><span>File Members</span></a></li>
|
||||
</ul>
|
||||
</div>
|
||||
<!-- window showing the filter options -->
|
||||
<div id="MSearchSelectWindow"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
onkeydown="return searchBox.OnSearchSelectKey(event)">
|
||||
</div>
|
||||
|
||||
<!-- iframe showing the search results (closed by default) -->
|
||||
<div id="MSearchResultsWindow">
|
||||
<iframe src="javascript:void(0)" frameborder="0"
|
||||
name="MSearchResults" id="MSearchResults">
|
||||
</iframe>
|
||||
</div>
|
||||
|
||||
<div id="nav-path" class="navpath">
|
||||
<ul>
|
||||
<li class="navelem"><a class="el" href="dir_d44c64559bbebec7f509842c48db8b23.html">include</a></li><li class="navelem"><a class="el" href="dir_6baf2bb612a2f0daa69af3101ede80a1.html">cutlass</a></li><li class="navelem"><a class="el" href="dir_048c1df36ab9c2efbb0733edba6291c9.html">arch</a></li> </ul>
|
||||
</div>
|
||||
</div><!-- top -->
|
||||
<div class="header">
|
||||
<div class="summary">
|
||||
<a href="#nested-classes">Classes</a> |
|
||||
<a href="#namespaces">Namespaces</a> </div>
|
||||
<div class="headertitle">
|
||||
<div class="title">arch/mma_sm50.h File Reference</div> </div>
|
||||
</div><!--header-->
|
||||
<div class="contents">
|
||||
|
||||
<p>Matrix multiply.
|
||||
<a href="#details">More...</a></p>
|
||||
<div class="textblock"><code>#include "<a class="el" href="arch_2mma_8h_source.html">cutlass/arch/mma.h</a>"</code><br />
|
||||
<code>#include "<a class="el" href="complex_8h_source.html">cutlass/complex.h</a>"</code><br />
|
||||
<code>#include "<a class="el" href="layout_2matrix_8h_source.html">cutlass/layout/matrix.h</a>"</code><br />
|
||||
<code>#include "<a class="el" href="include_2cutlass_2gemm_2gemm_8h_source.html">cutlass/gemm/gemm.h</a>"</code><br />
|
||||
</div><div class="textblock"><div class="dynheader">
|
||||
Include dependency graph for arch/mma_sm50.h:</div>
|
||||
<div class="dyncontent">
|
||||
<div class="center"><img src="arch_2mma__sm50_8h__incl.png" border="0" usemap="#mma__sm50_8h" alt=""/></div>
|
||||
<map name="mma__sm50_8h" id="mma__sm50_8h">
|
||||
</map>
|
||||
</div>
|
||||
</div><div class="textblock"><div class="dynheader">
|
||||
This graph shows which files directly or indirectly include this file:</div>
|
||||
<div class="dyncontent">
|
||||
<div class="center"><img src="arch_2mma__sm50_8h__dep__incl.png" border="0" usemap="#mma__sm50_8hdep" alt=""/></div>
|
||||
<map name="mma__sm50_8hdep" id="mma__sm50_8hdep">
|
||||
</map>
|
||||
</div>
|
||||
</div>
|
||||
<p><a href="arch_2mma__sm50_8h_source.html">Go to the source code of this file.</a></p>
|
||||
<table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="nested-classes"></a>
|
||||
Classes</h2></td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01float_004bb3fd76ca2af7b3210676fa9644d95b.html">cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01float_004bb3fd76ca2af7b3210676fa9644d95b.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01double_0aa57e6a2e6b5da37d10688bf99419a23.html">cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, double, LayoutA, double, LayoutB, double, LayoutC, OpMultiplyAdd ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01double_0aa57e6a2e6b5da37d10688bf99419a23.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01int_00_00b2dff9ce8caad9aff5bc6a355539161.html">cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01int_00_00b2dff9ce8caad9aff5bc6a355539161.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01complex_76f9d24016e1b4167b16f4d7628c9546.html">cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, complex< float >, LayoutA, complex< float >, LayoutB, complex< float >, LayoutC, OpMultiplyAdd ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01complex_76f9d24016e1b4167b16f4d7628c9546.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01complex_f1c9d2ee842455cd0c5b71d56108d468.html">cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, complex< float >, LayoutA, float, LayoutB, complex< float >, LayoutC, OpMultiplyAdd ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01complex_f1c9d2ee842455cd0c5b71d56108d468.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01float_00e3e12e263df6506b8cf06c3f4d478b8e.html">cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, float, LayoutA, complex< float >, LayoutB, complex< float >, LayoutC, OpMultiplyAdd ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01float_00e3e12e263df6506b8cf06c3f4d478b8e.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01complex_30fa42e1ad201df010637cd22fc070a1.html">cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, complex< double >, LayoutA, complex< double >, LayoutB, complex< double >, LayoutC, OpMultiplyAdd ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01complex_30fa42e1ad201df010637cd22fc070a1.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01complex_48b3a43bc03fff93a111ac01abe7e40d.html">cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, complex< double >, LayoutA, double, LayoutB, complex< double >, LayoutC, OpMultiplyAdd ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01complex_48b3a43bc03fff93a111ac01abe7e40d.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01double_070b94670e040ed5855e5b42d5ca8a443.html">cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, double, LayoutA, complex< double >, LayoutB, complex< double >, LayoutC, OpMultiplyAdd ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01double_070b94670e040ed5855e5b42d5ca8a443.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01half__t_4f30ee91f7bb3844ff7579c68d078818.html">cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_011_01_4_00_011_00_01half__t_4f30ee91f7bb3844ff7579c68d078818.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table><table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="namespaces"></a>
|
||||
Namespaces</h2></td></tr>
|
||||
<tr class="memitem:namespacecutlass"><td class="memItemLeft" align="right" valign="top">  </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass.html">cutlass</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:namespacecutlass_1_1arch"><td class="memItemLeft" align="right" valign="top">  </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass_1_1arch.html">cutlass::arch</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table>
|
||||
</div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
<hr class="footer"/><address class="footer"><small>
|
||||
Generated by  <a href="http://www.doxygen.org/index.html">
|
||||
<img class="footer" src="doxygen.png" alt="doxygen"/>
|
||||
</a> 1.8.11
|
||||
</small></address>
|
||||
</body>
|
||||
</html>
|
||||
1
docs/arch_2mma__sm50_8h__dep__incl.md5
Normal file
@ -0,0 +1 @@
|
||||
988e6466c703c4e63c9a889b8c3c54b5
|
||||
1
docs/arch_2mma__sm50_8h__incl.md5
Normal file
@ -0,0 +1 @@
|
||||
03f1613fdffbd6e7575de0d2967d08bf
|
||||
129
docs/arch_2mma__sm50_8h_source.html
Normal file
File diff suppressed because one or more lines are too long
157
docs/arch_2mma__sm60_8h.html
Normal file
@ -0,0 +1,157 @@
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head>
|
||||
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
|
||||
<meta name="generator" content="Doxygen 1.8.11"/>
|
||||
<title>CUTLASS: mma_sm60.h File Reference</title>
|
||||
<link href="tabs.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="jquery.js"></script>
|
||||
<script type="text/javascript" src="dynsections.js"></script>
|
||||
<link href="search/search.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="search/searchdata.js"></script>
|
||||
<script type="text/javascript" src="search/search.js"></script>
|
||||
<script type="text/javascript">
|
||||
$(document).ready(function() { init_search(); });
|
||||
</script>
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
extensions: ["tex2jax.js"],
|
||||
jax: ["input/TeX","output/HTML-CSS"],
|
||||
});
|
||||
</script><script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
|
||||
<link href="doxygen.css" rel="stylesheet" type="text/css" />
|
||||
</head>
|
||||
<body>
|
||||
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
|
||||
<div id="titlearea">
|
||||
<table cellspacing="0" cellpadding="0">
|
||||
<tbody>
|
||||
<tr style="height: 56px;">
|
||||
<td id="projectlogo"><img alt="Logo" src="cutlass-logo-small.png"/></td>
|
||||
<td id="projectalign" style="padding-left: 0.5em;">
|
||||
<div id="projectname">CUTLASS
|
||||
</div>
|
||||
<div id="projectbrief">CUDA Templates for Linear Algebra Subroutines and Solvers</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<!-- end header part -->
|
||||
<!-- Generated by Doxygen 1.8.11 -->
|
||||
<script type="text/javascript">
|
||||
var searchBox = new SearchBox("searchBox", "search",false,'Search');
|
||||
</script>
|
||||
<div id="navrow1" class="tabs">
|
||||
<ul class="tablist">
|
||||
<li><a href="index.html"><span>Main Page</span></a></li>
|
||||
<li><a href="modules.html"><span>Modules</span></a></li>
|
||||
<li><a href="namespaces.html"><span>Namespaces</span></a></li>
|
||||
<li><a href="annotated.html"><span>Classes</span></a></li>
|
||||
<li class="current"><a href="files.html"><span>Files</span></a></li>
|
||||
<li>
|
||||
<div id="MSearchBox" class="MSearchBoxInactive">
|
||||
<span class="left">
|
||||
<img id="MSearchSelect" src="search/mag_sel.png"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
alt=""/>
|
||||
<input type="text" id="MSearchField" value="Search" accesskey="S"
|
||||
onfocus="searchBox.OnSearchFieldFocus(true)"
|
||||
onblur="searchBox.OnSearchFieldFocus(false)"
|
||||
onkeyup="searchBox.OnSearchFieldChange(event)"/>
|
||||
</span><span class="right">
|
||||
<a id="MSearchClose" href="javascript:searchBox.CloseResultsWindow()"><img id="MSearchCloseImg" border="0" src="search/close.png" alt=""/></a>
|
||||
</span>
|
||||
</div>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div id="navrow2" class="tabs2">
|
||||
<ul class="tablist">
|
||||
<li><a href="files.html"><span>File List</span></a></li>
|
||||
<li><a href="globals.html"><span>File Members</span></a></li>
|
||||
</ul>
|
||||
</div>
|
||||
<!-- window showing the filter options -->
|
||||
<div id="MSearchSelectWindow"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
onkeydown="return searchBox.OnSearchSelectKey(event)">
|
||||
</div>
|
||||
|
||||
<!-- iframe showing the search results (closed by default) -->
|
||||
<div id="MSearchResultsWindow">
|
||||
<iframe src="javascript:void(0)" frameborder="0"
|
||||
name="MSearchResults" id="MSearchResults">
|
||||
</iframe>
|
||||
</div>
|
||||
|
||||
<div id="nav-path" class="navpath">
|
||||
<ul>
|
||||
<li class="navelem"><a class="el" href="dir_d44c64559bbebec7f509842c48db8b23.html">include</a></li><li class="navelem"><a class="el" href="dir_6baf2bb612a2f0daa69af3101ede80a1.html">cutlass</a></li><li class="navelem"><a class="el" href="dir_048c1df36ab9c2efbb0733edba6291c9.html">arch</a></li> </ul>
|
||||
</div>
|
||||
</div><!-- top -->
|
||||
<div class="header">
|
||||
<div class="summary">
|
||||
<a href="#nested-classes">Classes</a> |
|
||||
<a href="#namespaces">Namespaces</a> </div>
|
||||
<div class="headertitle">
|
||||
<div class="title">arch/mma_sm60.h File Reference</div> </div>
|
||||
</div><!--header-->
|
||||
<div class="contents">
|
||||
|
||||
<p>Matrix multiply.
|
||||
<a href="#details">More...</a></p>
|
||||
<div class="textblock"><code>#include &lt;cuda_fp16.h&gt;</code><br />
|
||||
<code>#include "<a class="el" href="arch_2mma_8h_source.html">cutlass/arch/mma.h</a>"</code><br />
|
||||
<code>#include "<a class="el" href="layout_2matrix_8h_source.html">cutlass/layout/matrix.h</a>"</code><br />
|
||||
</div><div class="textblock"><div class="dynheader">
|
||||
Include dependency graph for arch/mma_sm60.h:</div>
|
||||
<div class="dyncontent">
|
||||
<div class="center"><img src="arch_2mma__sm60_8h__incl.png" border="0" usemap="#mma__sm60_8h" alt=""/></div>
|
||||
<map name="mma__sm60_8h" id="mma__sm60_8h">
|
||||
</map>
|
||||
</div>
|
||||
</div><div class="textblock"><div class="dynheader">
|
||||
This graph shows which files directly or indirectly include this file:</div>
|
||||
<div class="dyncontent">
|
||||
<div class="center"><img src="arch_2mma__sm60_8h__dep__incl.png" border="0" usemap="#mma__sm60_8hdep" alt=""/></div>
|
||||
<map name="mma__sm60_8hdep" id="mma__sm60_8hdep">
|
||||
</map>
|
||||
</div>
|
||||
</div>
|
||||
<p><a href="arch_2mma__sm60_8h_source.html">Go to the source code of this file.</a></p>
|
||||
<table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="nested-classes"></a>
|
||||
Classes</h2></td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_012_00_011_00_011_01_4_00_011_00_01half__t_8cf78649807b93684f3d431bfa34ee28.html">cutlass::arch::Mma< gemm::GemmShape< 2, 1, 1 >, 1, half_t, LayoutA, half_t, LayoutB, half_t, LayoutC, OpMultiplyAdd ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_012_00_011_00_011_01_4_00_011_00_01half__t_8cf78649807b93684f3d431bfa34ee28.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_012_00_011_01_4_00_011_00_01half__t_f3dc2e59f857ada163d1e0781ea8f391.html">cutlass::arch::Mma< gemm::GemmShape< 1, 2, 1 >, 1, half_t, LayoutA, half_t, LayoutB, half_t, layout::RowMajor, OpMultiplyAdd ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_012_00_011_01_4_00_011_00_01half__t_f3dc2e59f857ada163d1e0781ea8f391.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_012_00_012_00_011_01_4_00_011_00_01half__t_ccde11d1bbbdab3702772ce44eb9729a.html">cutlass::arch::Mma< gemm::GemmShape< 2, 2, 1 >, 1, half_t, layout::ColumnMajor, half_t, layout::RowMajor, half_t, layout::ColumnMajor, OpMultiplyAdd ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_012_00_012_00_011_01_4_00_011_00_01half__t_ccde11d1bbbdab3702772ce44eb9729a.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_012_00_012_00_011_01_4_00_011_00_01half__t_c07cc6439298fa5486a719e577be2538.html">cutlass::arch::Mma< gemm::GemmShape< 2, 2, 1 >, 1, half_t, layout::ColumnMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_012_00_012_00_011_01_4_00_011_00_01half__t_c07cc6439298fa5486a719e577be2538.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table><table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="namespaces"></a>
|
||||
Namespaces</h2></td></tr>
|
||||
<tr class="memitem:namespacecutlass"><td class="memItemLeft" align="right" valign="top">  </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass.html">cutlass</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:namespacecutlass_1_1arch"><td class="memItemLeft" align="right" valign="top">  </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass_1_1arch.html">cutlass::arch</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table>
|
||||
</div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
<hr class="footer"/><address class="footer"><small>
|
||||
Generated by  <a href="http://www.doxygen.org/index.html">
|
||||
<img class="footer" src="doxygen.png" alt="doxygen"/>
|
||||
</a> 1.8.11
|
||||
</small></address>
|
||||
</body>
|
||||
</html>
|
||||
1
docs/arch_2mma__sm60_8h__dep__incl.md5
Normal file
@ -0,0 +1 @@
|
||||
ba69b14e3936946092854211499ae9fa
|
||||
1
docs/arch_2mma__sm60_8h__incl.md5
Normal file
@ -0,0 +1 @@
|
||||
e820099c55f2397639bb210d76ec4c05
|
||||
123
docs/arch_2mma__sm60_8h_source.html
Normal file
File diff suppressed because one or more lines are too long
149
docs/arch_2mma__sm61_8h.html
Normal file
@ -0,0 +1,149 @@
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head>
|
||||
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
|
||||
<meta name="generator" content="Doxygen 1.8.11"/>
|
||||
<title>CUTLASS: mma_sm61.h File Reference</title>
|
||||
<link href="tabs.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="jquery.js"></script>
|
||||
<script type="text/javascript" src="dynsections.js"></script>
|
||||
<link href="search/search.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="search/searchdata.js"></script>
|
||||
<script type="text/javascript" src="search/search.js"></script>
|
||||
<script type="text/javascript">
|
||||
$(document).ready(function() { init_search(); });
|
||||
</script>
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
extensions: ["tex2jax.js"],
|
||||
jax: ["input/TeX","output/HTML-CSS"],
|
||||
});
|
||||
</script><script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
|
||||
<link href="doxygen.css" rel="stylesheet" type="text/css" />
|
||||
</head>
|
||||
<body>
|
||||
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
|
||||
<div id="titlearea">
|
||||
<table cellspacing="0" cellpadding="0">
|
||||
<tbody>
|
||||
<tr style="height: 56px;">
|
||||
<td id="projectlogo"><img alt="Logo" src="cutlass-logo-small.png"/></td>
|
||||
<td id="projectalign" style="padding-left: 0.5em;">
|
||||
<div id="projectname">CUTLASS
|
||||
</div>
|
||||
<div id="projectbrief">CUDA Templates for Linear Algebra Subroutines and Solvers</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<!-- end header part -->
|
||||
<!-- Generated by Doxygen 1.8.11 -->
|
||||
<script type="text/javascript">
|
||||
var searchBox = new SearchBox("searchBox", "search",false,'Search');
|
||||
</script>
|
||||
<div id="navrow1" class="tabs">
|
||||
<ul class="tablist">
|
||||
<li><a href="index.html"><span>Main Page</span></a></li>
|
||||
<li><a href="modules.html"><span>Modules</span></a></li>
|
||||
<li><a href="namespaces.html"><span>Namespaces</span></a></li>
|
||||
<li><a href="annotated.html"><span>Classes</span></a></li>
|
||||
<li class="current"><a href="files.html"><span>Files</span></a></li>
|
||||
<li>
|
||||
<div id="MSearchBox" class="MSearchBoxInactive">
|
||||
<span class="left">
|
||||
<img id="MSearchSelect" src="search/mag_sel.png"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
alt=""/>
|
||||
<input type="text" id="MSearchField" value="Search" accesskey="S"
|
||||
onfocus="searchBox.OnSearchFieldFocus(true)"
|
||||
onblur="searchBox.OnSearchFieldFocus(false)"
|
||||
onkeyup="searchBox.OnSearchFieldChange(event)"/>
|
||||
</span><span class="right">
|
||||
<a id="MSearchClose" href="javascript:searchBox.CloseResultsWindow()"><img id="MSearchCloseImg" border="0" src="search/close.png" alt=""/></a>
|
||||
</span>
|
||||
</div>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div id="navrow2" class="tabs2">
|
||||
<ul class="tablist">
|
||||
<li><a href="files.html"><span>File List</span></a></li>
|
||||
<li><a href="globals.html"><span>File Members</span></a></li>
|
||||
</ul>
|
||||
</div>
|
||||
<!-- window showing the filter options -->
|
||||
<div id="MSearchSelectWindow"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
onkeydown="return searchBox.OnSearchSelectKey(event)">
|
||||
</div>
|
||||
|
||||
<!-- iframe showing the search results (closed by default) -->
|
||||
<div id="MSearchResultsWindow">
|
||||
<iframe src="javascript:void(0)" frameborder="0"
|
||||
name="MSearchResults" id="MSearchResults">
|
||||
</iframe>
|
||||
</div>
|
||||
|
||||
<div id="nav-path" class="navpath">
|
||||
<ul>
|
||||
<li class="navelem"><a class="el" href="dir_d44c64559bbebec7f509842c48db8b23.html">include</a></li><li class="navelem"><a class="el" href="dir_6baf2bb612a2f0daa69af3101ede80a1.html">cutlass</a></li><li class="navelem"><a class="el" href="dir_048c1df36ab9c2efbb0733edba6291c9.html">arch</a></li> </ul>
|
||||
</div>
|
||||
</div><!-- top -->
|
||||
<div class="header">
|
||||
<div class="summary">
|
||||
<a href="#nested-classes">Classes</a> |
|
||||
<a href="#namespaces">Namespaces</a> </div>
|
||||
<div class="headertitle">
|
||||
<div class="title">arch/mma_sm61.h File Reference</div> </div>
|
||||
</div><!--header-->
|
||||
<div class="contents">
|
||||
|
||||
<p>Matrix multiply.
|
||||
<a href="#details">More...</a></p>
|
||||
<div class="textblock"><code>#include "<a class="el" href="layout_2matrix_8h_source.html">cutlass/layout/matrix.h</a>"</code><br />
|
||||
</div><div class="textblock"><div class="dynheader">
|
||||
Include dependency graph for arch/mma_sm61.h:</div>
|
||||
<div class="dyncontent">
|
||||
<div class="center"><img src="arch_2mma__sm61_8h__incl.png" border="0" usemap="#mma__sm61_8h" alt=""/></div>
|
||||
<map name="mma__sm61_8h" id="mma__sm61_8h">
|
||||
</map>
|
||||
</div>
|
||||
</div><div class="textblock"><div class="dynheader">
|
||||
This graph shows which files directly or indirectly include this file:</div>
|
||||
<div class="dyncontent">
|
||||
<div class="center"><img src="arch_2mma__sm61_8h__dep__incl.png" border="0" usemap="#mma__sm61_8hdep" alt=""/></div>
|
||||
<map name="mma__sm61_8hdep" id="mma__sm61_8hdep">
|
||||
</map>
|
||||
</div>
|
||||
</div>
|
||||
<p><a href="arch_2mma__sm61_8h_source.html">Go to the source code of this file.</a></p>
|
||||
<table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="nested-classes"></a>
|
||||
Classes</h2></td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_014_01_4_00_011_00_01int8__t_a1ef6624fc8c10126f17f4ee88283d72.html">cutlass::arch::Mma< gemm::GemmShape< 1, 1, 4 >, 1, int8_t, LayoutA, int8_t, LayoutB, int, LayoutC, OpMultiplyAdd ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_014_01_4_00_011_00_01int8__t_a1ef6624fc8c10126f17f4ee88283d72.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_012_01_4_00_011_00_01int16__t8c4bac365710598317a69c489f7239db.html">cutlass::arch::Mma< gemm::GemmShape< 1, 1, 2 >, 1, int16_t, layout::RowMajor, int16_t, layout::ColumnMajor, int, LayoutC, OpMultiplyAdd ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Matrix multiply-add operation. <a href="structcutlass_1_1arch_1_1Mma_3_01gemm_1_1GemmShape_3_011_00_011_00_012_01_4_00_011_00_01int16__t8c4bac365710598317a69c489f7239db.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table><table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="namespaces"></a>
|
||||
Namespaces</h2></td></tr>
|
||||
<tr class="memitem:namespacecutlass"><td class="memItemLeft" align="right" valign="top">  </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass.html">cutlass</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:namespacecutlass_1_1arch"><td class="memItemLeft" align="right" valign="top">  </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass_1_1arch.html">cutlass::arch</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table>
|
||||
</div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
<hr class="footer"/><address class="footer"><small>
|
||||
Generated by  <a href="http://www.doxygen.org/index.html">
|
||||
<img class="footer" src="doxygen.png" alt="doxygen"/>
|
||||
</a> 1.8.11
|
||||
</small></address>
|
||||
</body>
|
||||
</html>
|
||||
1
docs/arch_2mma__sm61_8h__dep__incl.md5
Normal file
@ -0,0 +1 @@
|
||||
1faaf1631d5f0e44d6cc6c7121e6972e
|
||||
1
docs/arch_2mma__sm61_8h__incl.md5
Normal file
@ -0,0 +1 @@
|
||||
8cce8aef2d98c4082d68734b538253c7
|
||||
119
docs/arch_2mma__sm61_8h_source.html
Normal file
File diff suppressed because one or more lines are too long
147
docs/arch_8h.html
Normal file
@ -0,0 +1,147 @@
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head>
|
||||
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
|
||||
<meta name="generator" content="Doxygen 1.8.11"/>
|
||||
<title>CUTLASS: arch.h File Reference</title>
|
||||
<link href="tabs.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="jquery.js"></script>
|
||||
<script type="text/javascript" src="dynsections.js"></script>
|
||||
<link href="search/search.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="search/searchdata.js"></script>
|
||||
<script type="text/javascript" src="search/search.js"></script>
|
||||
<script type="text/javascript">
|
||||
$(document).ready(function() { init_search(); });
|
||||
</script>
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
extensions: ["tex2jax.js"],
|
||||
jax: ["input/TeX","output/HTML-CSS"],
|
||||
});
|
||||
</script><script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
|
||||
<link href="doxygen.css" rel="stylesheet" type="text/css" />
|
||||
</head>
|
||||
<body>
|
||||
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
|
||||
<div id="titlearea">
|
||||
<table cellspacing="0" cellpadding="0">
|
||||
<tbody>
|
||||
<tr style="height: 56px;">
|
||||
<td id="projectlogo"><img alt="Logo" src="cutlass-logo-small.png"/></td>
|
||||
<td id="projectalign" style="padding-left: 0.5em;">
|
||||
<div id="projectname">CUTLASS
|
||||
</div>
|
||||
<div id="projectbrief">CUDA Templates for Linear Algebra Subroutines and Solvers</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<!-- end header part -->
|
||||
<!-- Generated by Doxygen 1.8.11 -->
|
||||
<script type="text/javascript">
|
||||
var searchBox = new SearchBox("searchBox", "search",false,'Search');
|
||||
</script>
|
||||
<div id="navrow1" class="tabs">
|
||||
<ul class="tablist">
|
||||
<li><a href="index.html"><span>Main Page</span></a></li>
|
||||
<li><a href="modules.html"><span>Modules</span></a></li>
|
||||
<li><a href="namespaces.html"><span>Namespaces</span></a></li>
|
||||
<li><a href="annotated.html"><span>Classes</span></a></li>
|
||||
<li class="current"><a href="files.html"><span>Files</span></a></li>
|
||||
<li>
|
||||
<div id="MSearchBox" class="MSearchBoxInactive">
|
||||
<span class="left">
|
||||
<img id="MSearchSelect" src="search/mag_sel.png"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
alt=""/>
|
||||
<input type="text" id="MSearchField" value="Search" accesskey="S"
|
||||
onfocus="searchBox.OnSearchFieldFocus(true)"
|
||||
onblur="searchBox.OnSearchFieldFocus(false)"
|
||||
onkeyup="searchBox.OnSearchFieldChange(event)"/>
|
||||
</span><span class="right">
|
||||
<a id="MSearchClose" href="javascript:searchBox.CloseResultsWindow()"><img id="MSearchCloseImg" border="0" src="search/close.png" alt=""/></a>
|
||||
</span>
|
||||
</div>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div id="navrow2" class="tabs2">
|
||||
<ul class="tablist">
|
||||
<li><a href="files.html"><span>File List</span></a></li>
|
||||
<li><a href="globals.html"><span>File Members</span></a></li>
|
||||
</ul>
|
||||
</div>
|
||||
<!-- window showing the filter options -->
|
||||
<div id="MSearchSelectWindow"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
onkeydown="return searchBox.OnSearchSelectKey(event)">
|
||||
</div>
|
||||
|
||||
<!-- iframe showing the search results (closed by default) -->
|
||||
<div id="MSearchResultsWindow">
|
||||
<iframe src="javascript:void(0)" frameborder="0"
|
||||
name="MSearchResults" id="MSearchResults">
|
||||
</iframe>
|
||||
</div>
|
||||
|
||||
<div id="nav-path" class="navpath">
|
||||
<ul>
|
||||
<li class="navelem"><a class="el" href="dir_d44c64559bbebec7f509842c48db8b23.html">include</a></li><li class="navelem"><a class="el" href="dir_6baf2bb612a2f0daa69af3101ede80a1.html">cutlass</a></li><li class="navelem"><a class="el" href="dir_048c1df36ab9c2efbb0733edba6291c9.html">arch</a></li> </ul>
|
||||
</div>
|
||||
</div><!-- top -->
|
||||
<div class="header">
|
||||
<div class="summary">
|
||||
<a href="#nested-classes">Classes</a> |
|
||||
<a href="#namespaces">Namespaces</a> </div>
|
||||
<div class="headertitle">
|
||||
<div class="title">arch.h File Reference</div> </div>
|
||||
</div><!--header-->
|
||||
<div class="contents">
|
||||
|
||||
<p>Defines tags for architecture-specific configurations.
|
||||
<a href="#details">More...</a></p>
|
||||
<div class="textblock"><div class="dynheader">
|
||||
This graph shows which files directly or indirectly include this file:</div>
|
||||
<div class="dyncontent">
|
||||
<div class="center"><img src="arch_8h__dep__incl.png" border="0" usemap="#arch_8hdep" alt=""/></div>
|
||||
<map name="arch_8hdep" id="arch_8hdep">
|
||||
</map>
|
||||
</div>
|
||||
</div>
|
||||
<p><a href="arch_8h_source.html">Go to the source code of this file.</a></p>
|
||||
<table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="nested-classes"></a>
|
||||
Classes</h2></td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Sm50.html">cutlass::arch::Sm50</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Sm60.html">cutlass::arch::Sm60</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Sm61.html">cutlass::arch::Sm61</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Sm70.html">cutlass::arch::Sm70</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Sm72.html">cutlass::arch::Sm72</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1arch_1_1Sm75.html">cutlass::arch::Sm75</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table><table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="namespaces"></a>
|
||||
Namespaces</h2></td></tr>
|
||||
<tr class="memitem:namespacecutlass"><td class="memItemLeft" align="right" valign="top">  </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass.html">cutlass</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:namespacecutlass_1_1arch"><td class="memItemLeft" align="right" valign="top">  </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass_1_1arch.html">cutlass::arch</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table>
|
||||
</div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
<hr class="footer"/><address class="footer"><small>
|
||||
Generated by  <a href="http://www.doxygen.org/index.html">
|
||||
<img class="footer" src="doxygen.png" alt="doxygen"/>
|
||||
</a> 1.8.11
|
||||
</small></address>
|
||||
</body>
|
||||
</html>
|
||||
1
docs/arch_8h__dep__incl.md5
Normal file
@ -0,0 +1 @@
|
||||
9ea32ea41ab87776449ab855965480b3
|
||||
117
docs/arch_8h_source.html
Normal file
File diff suppressed because one or more lines are too long
167
docs/array_8h.html
Normal file
@ -0,0 +1,167 @@
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head>
|
||||
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
|
||||
<meta name="generator" content="Doxygen 1.8.11"/>
|
||||
<title>CUTLASS: array.h File Reference</title>
|
||||
<link href="tabs.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="jquery.js"></script>
|
||||
<script type="text/javascript" src="dynsections.js"></script>
|
||||
<link href="search/search.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="search/searchdata.js"></script>
|
||||
<script type="text/javascript" src="search/search.js"></script>
|
||||
<script type="text/javascript">
|
||||
$(document).ready(function() { init_search(); });
|
||||
</script>
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
extensions: ["tex2jax.js"],
|
||||
jax: ["input/TeX","output/HTML-CSS"],
|
||||
});
|
||||
</script><script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
|
||||
<link href="doxygen.css" rel="stylesheet" type="text/css" />
|
||||
</head>
|
||||
<body>
|
||||
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
|
||||
<div id="titlearea">
|
||||
<table cellspacing="0" cellpadding="0">
|
||||
<tbody>
|
||||
<tr style="height: 56px;">
|
||||
<td id="projectlogo"><img alt="Logo" src="cutlass-logo-small.png"/></td>
|
||||
<td id="projectalign" style="padding-left: 0.5em;">
|
||||
<div id="projectname">CUTLASS
|
||||
</div>
|
||||
<div id="projectbrief">CUDA Templates for Linear Algebra Subroutines and Solvers</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<!-- end header part -->
|
||||
<!-- Generated by Doxygen 1.8.11 -->
|
||||
<script type="text/javascript">
|
||||
var searchBox = new SearchBox("searchBox", "search",false,'Search');
|
||||
</script>
|
||||
<div id="navrow1" class="tabs">
|
||||
<ul class="tablist">
|
||||
<li><a href="index.html"><span>Main Page</span></a></li>
|
||||
<li><a href="modules.html"><span>Modules</span></a></li>
|
||||
<li><a href="namespaces.html"><span>Namespaces</span></a></li>
|
||||
<li><a href="annotated.html"><span>Classes</span></a></li>
|
||||
<li class="current"><a href="files.html"><span>Files</span></a></li>
|
||||
<li>
|
||||
<div id="MSearchBox" class="MSearchBoxInactive">
|
||||
<span class="left">
|
||||
<img id="MSearchSelect" src="search/mag_sel.png"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
alt=""/>
|
||||
<input type="text" id="MSearchField" value="Search" accesskey="S"
|
||||
onfocus="searchBox.OnSearchFieldFocus(true)"
|
||||
onblur="searchBox.OnSearchFieldFocus(false)"
|
||||
onkeyup="searchBox.OnSearchFieldChange(event)"/>
|
||||
</span><span class="right">
|
||||
<a id="MSearchClose" href="javascript:searchBox.CloseResultsWindow()"><img id="MSearchCloseImg" border="0" src="search/close.png" alt=""/></a>
|
||||
</span>
|
||||
</div>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div id="navrow2" class="tabs2">
|
||||
<ul class="tablist">
|
||||
<li><a href="files.html"><span>File List</span></a></li>
|
||||
<li><a href="globals.html"><span>File Members</span></a></li>
|
||||
</ul>
|
||||
</div>
|
||||
<!-- window showing the filter options -->
|
||||
<div id="MSearchSelectWindow"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
onkeydown="return searchBox.OnSearchSelectKey(event)">
|
||||
</div>
|
||||
|
||||
<!-- iframe showing the search results (closed by default) -->
|
||||
<div id="MSearchResultsWindow">
|
||||
<iframe src="javascript:void(0)" frameborder="0"
|
||||
name="MSearchResults" id="MSearchResults">
|
||||
</iframe>
|
||||
</div>
|
||||
|
||||
<div id="nav-path" class="navpath">
|
||||
<ul>
|
||||
<li class="navelem"><a class="el" href="dir_d44c64559bbebec7f509842c48db8b23.html">include</a></li><li class="navelem"><a class="el" href="dir_6baf2bb612a2f0daa69af3101ede80a1.html">cutlass</a></li> </ul>
|
||||
</div>
|
||||
</div><!-- top -->
|
||||
<div class="header">
|
||||
<div class="summary">
|
||||
<a href="#nested-classes">Classes</a> |
|
||||
<a href="#namespaces">Namespaces</a> |
|
||||
<a href="#func-members">Functions</a> </div>
|
||||
<div class="headertitle">
|
||||
<div class="title">array.h File Reference</div> </div>
|
||||
</div><!--header-->
|
||||
<div class="contents">
|
||||
|
||||
<p>Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
|
||||
<a href="#details">More...</a></p>
|
||||
<div class="textblock"><code>#include "<a class="el" href="cutlass_8h_source.html">cutlass/cutlass.h</a>"</code><br />
|
||||
<code>#include "<a class="el" href="numeric__types_8h_source.html">cutlass/numeric_types.h</a>"</code><br />
|
||||
<code>#include "<a class="el" href="array__subbyte_8h_source.html">cutlass/array_subbyte.h</a>"</code><br />
|
||||
</div><div class="textblock"><div class="dynheader">
|
||||
Include dependency graph for array.h:</div>
|
||||
<div class="dyncontent">
|
||||
<div class="center"><img src="array_8h__incl.png" border="0" usemap="#array_8h" alt=""/></div>
|
||||
<map name="array_8h" id="array_8h">
|
||||
</map>
|
||||
</div>
|
||||
</div>
|
||||
<p><a href="array_8h_source.html">Go to the source code of this file.</a></p>
|
||||
<table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="nested-classes"></a>
|
||||
Classes</h2></td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1sizeof__bits_3_01Array_3_01T_00_01N_00_01RegisterSized_01_4_01_4.html">cutlass::sizeof_bits< Array< T, N, RegisterSized > ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Statically sized array for any data type. <a href="structcutlass_1_1sizeof__bits_3_01Array_3_01T_00_01N_00_01RegisterSized_01_4_01_4.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1Array_3_01T_00_01N_00_01true_01_4.html">cutlass::Array< T, N, true ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Statically sized array for any data type. <a href="classcutlass_1_1Array_3_01T_00_01N_00_01true_01_4.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1Array_3_01T_00_01N_00_01true_01_4_1_1iterator.html">cutlass::Array< T, N, true >::iterator</a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Bidirectional iterator over elements. <a href="classcutlass_1_1Array_3_01T_00_01N_00_01true_01_4_1_1iterator.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1Array_3_01T_00_01N_00_01true_01_4_1_1const__iterator.html">cutlass::Array< T, N, true >::const_iterator</a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Bidirectional constant iterator over elements. <a href="classcutlass_1_1Array_3_01T_00_01N_00_01true_01_4_1_1const__iterator.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1Array_3_01T_00_01N_00_01true_01_4_1_1reverse__iterator.html">cutlass::Array< T, N, true >::reverse_iterator</a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Bidirectional iterator over elements. <a href="classcutlass_1_1Array_3_01T_00_01N_00_01true_01_4_1_1reverse__iterator.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1Array_3_01T_00_01N_00_01true_01_4_1_1const__reverse__iterator.html">cutlass::Array< T, N, true >::const_reverse_iterator</a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Bidirectional constant iterator over elements. <a href="classcutlass_1_1Array_3_01T_00_01N_00_01true_01_4_1_1const__reverse__iterator.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1AlignedArray.html">cutlass::AlignedArray< T, N, Alignment ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Aligned array type. <a href="classcutlass_1_1AlignedArray.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table><table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="namespaces"></a>
|
||||
Namespaces</h2></td></tr>
|
||||
<tr class="memitem:namespacecutlass"><td class="memItemLeft" align="right" valign="top">  </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass.html">cutlass</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table><table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="func-members"></a>
|
||||
Functions</h2></td></tr>
|
||||
<tr class="memitem:a935aabfdc47cf03f87c67bb22533f97f"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="platform_8h.html#a72f0657181cca64b44eb186b707eb380">constexpr</a> bool </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass.html#a935aabfdc47cf03f87c67bb22533f97f">cutlass::ispow2</a> (unsigned x)</td></tr>
|
||||
<tr class="memdesc:a935aabfdc47cf03f87c67bb22533f97f"><td class="mdescLeft"> </td><td class="mdescRight">Returns true if the argument is a power of 2. <a href="namespacecutlass.html#a935aabfdc47cf03f87c67bb22533f97f">More...</a><br /></td></tr>
|
||||
<tr class="separator:a935aabfdc47cf03f87c67bb22533f97f"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:ac16d8caf23537912eb02123c4bdacd14"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="platform_8h.html#a72f0657181cca64b44eb186b707eb380">constexpr</a> unsigned </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass.html#ac16d8caf23537912eb02123c4bdacd14">cutlass::floor_pow_2</a> (unsigned x)</td></tr>
|
||||
<tr class="memdesc:ac16d8caf23537912eb02123c4bdacd14"><td class="mdescLeft"> </td><td class="mdescRight">Returns the largest power of two not greater than the argument. <a href="namespacecutlass.html#ac16d8caf23537912eb02123c4bdacd14">More...</a><br /></td></tr>
|
||||
<tr class="separator:ac16d8caf23537912eb02123c4bdacd14"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table>
|
||||
</div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
<hr class="footer"/><address class="footer"><small>
|
||||
Generated by  <a href="http://www.doxygen.org/index.html">
|
||||
<img class="footer" src="doxygen.png" alt="doxygen"/>
|
||||
</a> 1.8.11
|
||||
</small></address>
|
||||
</body>
|
||||
</html>
|
||||
1
docs/array_8h__incl.md5
Normal file
@ -0,0 +1 @@
|
||||
90c159bd7ad938ad2d6e263ea8402fe7
|
||||
194
docs/array_8h_source.html
Normal file
File diff suppressed because one or more lines are too long
164
docs/array__subbyte_8h.html
Normal file
@ -0,0 +1,164 @@
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head>
|
||||
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
|
||||
<meta name="generator" content="Doxygen 1.8.11"/>
|
||||
<title>CUTLASS: array_subbyte.h File Reference</title>
|
||||
<link href="tabs.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="jquery.js"></script>
|
||||
<script type="text/javascript" src="dynsections.js"></script>
|
||||
<link href="search/search.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="search/searchdata.js"></script>
|
||||
<script type="text/javascript" src="search/search.js"></script>
|
||||
<script type="text/javascript">
|
||||
$(document).ready(function() { init_search(); });
|
||||
</script>
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
extensions: ["tex2jax.js"],
|
||||
jax: ["input/TeX","output/HTML-CSS"],
|
||||
});
|
||||
</script><script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
|
||||
<link href="doxygen.css" rel="stylesheet" type="text/css" />
|
||||
</head>
|
||||
<body>
|
||||
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
|
||||
<div id="titlearea">
|
||||
<table cellspacing="0" cellpadding="0">
|
||||
<tbody>
|
||||
<tr style="height: 56px;">
|
||||
<td id="projectlogo"><img alt="Logo" src="cutlass-logo-small.png"/></td>
|
||||
<td id="projectalign" style="padding-left: 0.5em;">
|
||||
<div id="projectname">CUTLASS
|
||||
</div>
|
||||
<div id="projectbrief">CUDA Templates for Linear Algebra Subroutines and Solvers</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<!-- end header part -->
|
||||
<!-- Generated by Doxygen 1.8.11 -->
|
||||
<script type="text/javascript">
|
||||
var searchBox = new SearchBox("searchBox", "search",false,'Search');
|
||||
</script>
|
||||
<div id="navrow1" class="tabs">
|
||||
<ul class="tablist">
|
||||
<li><a href="index.html"><span>Main Page</span></a></li>
|
||||
<li><a href="modules.html"><span>Modules</span></a></li>
|
||||
<li><a href="namespaces.html"><span>Namespaces</span></a></li>
|
||||
<li><a href="annotated.html"><span>Classes</span></a></li>
|
||||
<li class="current"><a href="files.html"><span>Files</span></a></li>
|
||||
<li>
|
||||
<div id="MSearchBox" class="MSearchBoxInactive">
|
||||
<span class="left">
|
||||
<img id="MSearchSelect" src="search/mag_sel.png"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
alt=""/>
|
||||
<input type="text" id="MSearchField" value="Search" accesskey="S"
|
||||
onfocus="searchBox.OnSearchFieldFocus(true)"
|
||||
onblur="searchBox.OnSearchFieldFocus(false)"
|
||||
onkeyup="searchBox.OnSearchFieldChange(event)"/>
|
||||
</span><span class="right">
|
||||
<a id="MSearchClose" href="javascript:searchBox.CloseResultsWindow()"><img id="MSearchCloseImg" border="0" src="search/close.png" alt=""/></a>
|
||||
</span>
|
||||
</div>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div id="navrow2" class="tabs2">
|
||||
<ul class="tablist">
|
||||
<li><a href="files.html"><span>File List</span></a></li>
|
||||
<li><a href="globals.html"><span>File Members</span></a></li>
|
||||
</ul>
|
||||
</div>
|
||||
<!-- window showing the filter options -->
|
||||
<div id="MSearchSelectWindow"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
onkeydown="return searchBox.OnSearchSelectKey(event)">
|
||||
</div>
|
||||
|
||||
<!-- iframe showing the search results (closed by default) -->
|
||||
<div id="MSearchResultsWindow">
|
||||
<iframe src="javascript:void(0)" frameborder="0"
|
||||
name="MSearchResults" id="MSearchResults">
|
||||
</iframe>
|
||||
</div>
|
||||
|
||||
<div id="nav-path" class="navpath">
|
||||
<ul>
|
||||
<li class="navelem"><a class="el" href="dir_d44c64559bbebec7f509842c48db8b23.html">include</a></li><li class="navelem"><a class="el" href="dir_6baf2bb612a2f0daa69af3101ede80a1.html">cutlass</a></li> </ul>
|
||||
</div>
|
||||
</div><!-- top -->
|
||||
<div class="header">
|
||||
<div class="summary">
|
||||
<a href="#nested-classes">Classes</a> |
|
||||
<a href="#namespaces">Namespaces</a> </div>
|
||||
<div class="headertitle">
|
||||
<div class="title">array_subbyte.h File Reference</div> </div>
|
||||
</div><!--header-->
|
||||
<div class="contents">
|
||||
|
||||
<p>Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
|
||||
<a href="#details">More...</a></p>
|
||||
<div class="textblock"><code>#include "<a class="el" href="cutlass_8h_source.html">cutlass/cutlass.h</a>"</code><br />
|
||||
<code>#include "<a class="el" href="array_8h_source.html">cutlass/array.h</a>"</code><br />
|
||||
<code>#include "<a class="el" href="platform_8h_source.html">cutlass/platform/platform.h</a>"</code><br />
|
||||
</div><div class="textblock"><div class="dynheader">
|
||||
Include dependency graph for array_subbyte.h:</div>
|
||||
<div class="dyncontent">
|
||||
<div class="center"><img src="array__subbyte_8h__incl.png" border="0" usemap="#array__subbyte_8h" alt=""/></div>
|
||||
<map name="array__subbyte_8h" id="array__subbyte_8h">
|
||||
</map>
|
||||
</div>
|
||||
</div><div class="textblock"><div class="dynheader">
|
||||
This graph shows which files directly or indirectly include this file:</div>
|
||||
<div class="dyncontent">
|
||||
<div class="center"><img src="array__subbyte_8h__dep__incl.png" border="0" usemap="#array__subbyte_8hdep" alt=""/></div>
|
||||
<map name="array__subbyte_8hdep" id="array__subbyte_8hdep">
|
||||
</map>
|
||||
</div>
|
||||
</div>
|
||||
<p><a href="array__subbyte_8h_source.html">Go to the source code of this file.</a></p>
|
||||
<table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="nested-classes"></a>
|
||||
Classes</h2></td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1Array_3_01T_00_01N_00_01false_01_4.html">cutlass::Array< T, N, false ></a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Statically sized array for any data type. <a href="classcutlass_1_1Array_3_01T_00_01N_00_01false_01_4.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1Array_3_01T_00_01N_00_01false_01_4_1_1reference.html">cutlass::Array< T, N, false >::reference</a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Reference object inserts or extracts sub-byte items. <a href="classcutlass_1_1Array_3_01T_00_01N_00_01false_01_4_1_1reference.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1Array_3_01T_00_01N_00_01false_01_4_1_1const__reference.html">cutlass::Array< T, N, false >::const_reference</a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Reference object extracts sub-byte items. <a href="classcutlass_1_1Array_3_01T_00_01N_00_01false_01_4_1_1const__reference.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1Array_3_01T_00_01N_00_01false_01_4_1_1iterator.html">cutlass::Array< T, N, false >::iterator</a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Bidirectional iterator over elements. <a href="classcutlass_1_1Array_3_01T_00_01N_00_01false_01_4_1_1iterator.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1Array_3_01T_00_01N_00_01false_01_4_1_1const__iterator.html">cutlass::Array< T, N, false >::const_iterator</a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Bidirectional constant iterator over elements. <a href="classcutlass_1_1Array_3_01T_00_01N_00_01false_01_4_1_1const__iterator.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1Array_3_01T_00_01N_00_01false_01_4_1_1reverse__iterator.html">cutlass::Array< T, N, false >::reverse_iterator</a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Bidirectional iterator over elements. <a href="classcutlass_1_1Array_3_01T_00_01N_00_01false_01_4_1_1reverse__iterator.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">class  </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1Array_3_01T_00_01N_00_01false_01_4_1_1const__reverse__iterator.html">cutlass::Array< T, N, false >::const_reverse_iterator</a></td></tr>
|
||||
<tr class="memdesc:"><td class="mdescLeft"> </td><td class="mdescRight">Bidirectional constant iterator over elements. <a href="classcutlass_1_1Array_3_01T_00_01N_00_01false_01_4_1_1const__reverse__iterator.html#details">More...</a><br /></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table><table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="namespaces"></a>
|
||||
Namespaces</h2></td></tr>
|
||||
<tr class="memitem:namespacecutlass"><td class="memItemLeft" align="right" valign="top">  </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass.html">cutlass</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table>
|
||||
</div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
<hr class="footer"/><address class="footer"><small>
|
||||
Generated by  <a href="http://www.doxygen.org/index.html">
|
||||
<img class="footer" src="doxygen.png" alt="doxygen"/>
|
||||
</a> 1.8.11
|
||||
</small></address>
|
||||
</body>
|
||||
</html>
|
||||
1
docs/array__subbyte_8h__dep__incl.md5
Normal file
@@ -0,0 +1 @@
|
||||
7c0288c037b6ea169ec7a3aa1015a4d4
|
||||
1
docs/array__subbyte_8h__incl.md5
Normal file
@@ -0,0 +1 @@
|
||||
36310516438810c2a8ba31a7816cd1de
|
||||
181
docs/array__subbyte_8h_source.html
Normal file
File diff suppressed because one or more lines are too long
155
docs/batched__reduction_8h.html
Normal file
@@ -0,0 +1,155 @@
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head>
|
||||
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
|
||||
<meta name="generator" content="Doxygen 1.8.11"/>
|
||||
<title>CUTLASS: batched_reduction.h File Reference</title>
|
||||
<link href="tabs.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="jquery.js"></script>
|
||||
<script type="text/javascript" src="dynsections.js"></script>
|
||||
<link href="search/search.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="search/searchdata.js"></script>
|
||||
<script type="text/javascript" src="search/search.js"></script>
|
||||
<script type="text/javascript">
|
||||
$(document).ready(function() { init_search(); });
|
||||
</script>
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
extensions: ["tex2jax.js"],
|
||||
jax: ["input/TeX","output/HTML-CSS"],
|
||||
});
|
||||
</script><script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
|
||||
<link href="doxygen.css" rel="stylesheet" type="text/css" />
|
||||
</head>
|
||||
<body>
|
||||
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
|
||||
<div id="titlearea">
|
||||
<table cellspacing="0" cellpadding="0">
|
||||
<tbody>
|
||||
<tr style="height: 56px;">
|
||||
<td id="projectlogo"><img alt="Logo" src="cutlass-logo-small.png"/></td>
|
||||
<td id="projectalign" style="padding-left: 0.5em;">
|
||||
<div id="projectname">CUTLASS
|
||||
</div>
|
||||
<div id="projectbrief">CUDA Templates for Linear Algebra Subroutines and Solvers</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<!-- end header part -->
|
||||
<!-- Generated by Doxygen 1.8.11 -->
|
||||
<script type="text/javascript">
|
||||
var searchBox = new SearchBox("searchBox", "search",false,'Search');
|
||||
</script>
|
||||
<div id="navrow1" class="tabs">
|
||||
<ul class="tablist">
|
||||
<li><a href="index.html"><span>Main Page</span></a></li>
|
||||
<li><a href="modules.html"><span>Modules</span></a></li>
|
||||
<li><a href="namespaces.html"><span>Namespaces</span></a></li>
|
||||
<li><a href="annotated.html"><span>Classes</span></a></li>
|
||||
<li class="current"><a href="files.html"><span>Files</span></a></li>
|
||||
<li>
|
||||
<div id="MSearchBox" class="MSearchBoxInactive">
|
||||
<span class="left">
|
||||
<img id="MSearchSelect" src="search/mag_sel.png"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
alt=""/>
|
||||
<input type="text" id="MSearchField" value="Search" accesskey="S"
|
||||
onfocus="searchBox.OnSearchFieldFocus(true)"
|
||||
onblur="searchBox.OnSearchFieldFocus(false)"
|
||||
onkeyup="searchBox.OnSearchFieldChange(event)"/>
|
||||
</span><span class="right">
|
||||
<a id="MSearchClose" href="javascript:searchBox.CloseResultsWindow()"><img id="MSearchCloseImg" border="0" src="search/close.png" alt=""/></a>
|
||||
</span>
|
||||
</div>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div id="navrow2" class="tabs2">
|
||||
<ul class="tablist">
|
||||
<li><a href="files.html"><span>File List</span></a></li>
|
||||
<li><a href="globals.html"><span>File Members</span></a></li>
|
||||
</ul>
|
||||
</div>
|
||||
<!-- window showing the filter options -->
|
||||
<div id="MSearchSelectWindow"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
onkeydown="return searchBox.OnSearchSelectKey(event)">
|
||||
</div>
|
||||
|
||||
<!-- iframe showing the search results (closed by default) -->
|
||||
<div id="MSearchResultsWindow">
|
||||
<iframe src="javascript:void(0)" frameborder="0"
|
||||
name="MSearchResults" id="MSearchResults">
|
||||
</iframe>
|
||||
</div>
|
||||
|
||||
<div id="nav-path" class="navpath">
|
||||
<ul>
|
||||
<li class="navelem"><a class="el" href="dir_d44c64559bbebec7f509842c48db8b23.html">include</a></li><li class="navelem"><a class="el" href="dir_6baf2bb612a2f0daa69af3101ede80a1.html">cutlass</a></li><li class="navelem"><a class="el" href="dir_ac488927e63b76ba9cb3ad9c317bbde9.html">reduction</a></li> </ul>
|
||||
</div>
|
||||
</div><!-- top -->
|
||||
<div class="header">
|
||||
<div class="summary">
|
||||
<a href="#nested-classes">Classes</a> |
|
||||
<a href="#namespaces">Namespaces</a> |
|
||||
<a href="#func-members">Functions</a> </div>
|
||||
<div class="headertitle">
|
||||
<div class="title">batched_reduction.h File Reference</div> </div>
|
||||
</div><!--header-->
|
||||
<div class="contents">
|
||||
|
||||
<p>Implements a software-pipelined efficient batched reduction. D = alpha * Reduction(A) + beta * C.
|
||||
<a href="#details">More...</a></p>
|
||||
<div class="textblock"><code>#include &lt;cuda.h&gt;</code><br />
|
||||
<code>#include "<a class="el" href="coord_8h_source.html">cutlass/coord.h</a>"</code><br />
|
||||
<code>#include "cutlass/util/platform.h"</code><br />
|
||||
<code>#include "cutlass/fragment.h"</code><br />
|
||||
</div><div class="textblock"><div class="dynheader">
|
||||
Include dependency graph for batched_reduction.h:</div>
|
||||
<div class="dyncontent">
|
||||
<div class="center"><img src="batched__reduction_8h__incl.png" border="0" usemap="#batched__reduction_8h" alt=""/></div>
|
||||
<map name="batched__reduction_8h" id="batched__reduction_8h">
|
||||
</map>
|
||||
</div>
|
||||
</div><div class="textblock"><div class="dynheader">
|
||||
This graph shows which files directly or indirectly include this file:</div>
|
||||
<div class="dyncontent">
|
||||
<div class="center"><img src="batched__reduction_8h__dep__incl.png" border="0" usemap="#batched__reduction_8hdep" alt=""/></div>
|
||||
<map name="batched__reduction_8hdep" id="batched__reduction_8hdep">
|
||||
</map>
|
||||
</div>
|
||||
</div>
|
||||
<p><a href="batched__reduction_8h_source.html">Go to the source code of this file.</a></p>
|
||||
<table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="nested-classes"></a>
|
||||
Classes</h2></td></tr>
|
||||
<tr class="memitem:"><td class="memItemLeft" align="right" valign="top">struct  </td><td class="memItemRight" valign="bottom"><a class="el" href="structcutlass_1_1reduction_1_1BatchedReduction.html">cutlass::reduction::BatchedReduction< BatchedReductionTraits_ ></a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table><table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="namespaces"></a>
|
||||
Namespaces</h2></td></tr>
|
||||
<tr class="memitem:namespacecutlass"><td class="memItemLeft" align="right" valign="top">  </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass.html">cutlass</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:namespacecutlass_1_1reduction"><td class="memItemLeft" align="right" valign="top">  </td><td class="memItemRight" valign="bottom"><a class="el" href="namespacecutlass_1_1reduction.html">cutlass::reduction</a></td></tr>
|
||||
<tr class="separator:"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table><table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="func-members"></a>
|
||||
Functions</h2></td></tr>
|
||||
<tr class="memitem:a9665e8f438a7b290d6e2eb640d93045f"><td class="memTemplParams" colspan="2">template&lt;typename batched_reduction_ &gt; </td></tr>
|
||||
<tr class="memitem:a9665e8f438a7b290d6e2eb640d93045f"><td class="memTemplItemLeft" align="right" valign="top">__global__ </td><td class="memTemplItemRight" valign="bottom"><a class="el" href="namespacecutlass_1_1reduction.html#a9665e8f438a7b290d6e2eb640d93045f">cutlass::reduction::__launch_bounds__</a> (batched_reduction_::Traits::kThreads, 1) void batched_reduction_kernel(typename batched_reduction_</td></tr>
|
||||
<tr class="separator:a9665e8f438a7b290d6e2eb640d93045f"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table>
|
||||
</div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
<hr class="footer"/><address class="footer"><small>
|
||||
Generated by  <a href="http://www.doxygen.org/index.html">
|
||||
<img class="footer" src="doxygen.png" alt="doxygen"/>
|
||||
</a> 1.8.11
|
||||
</small></address>
|
||||
</body>
|
||||
</html>
|
||||
1
docs/batched__reduction_8h__dep__incl.md5
Normal file
@@ -0,0 +1 @@
|
||||
2bce650f452329d669d303788cc619c8
|
||||
Some files were not shown because too many files have changed in this diff.