Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cf0301e00f | |||
| b9bb0d1a49 | |||
| e1c4ba501b | |||
| c566e83e6d | |||
| 374882be53 | |||
| 2c496c3e9e | |||
| 9fd55460c6 | |||
| 480732c2e8 | |||
| 68aaee8773 | |||
| acb90e962a | |||
| 96bc3f227f | |||
| 25ff282403 | |||
| 9d5726a568 | |||
| 6f0d271d8d | |||
| 923dfb42ce | |||
| 6f6f269a0a | |||
| 2028ebe120 | |||
| 901287175f |
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
[submodule "tools/external/googletest"]
|
||||
path = tools/external/googletest
|
||||
url = https://github.com/google/googletest.git
|
||||
26
CMake/bin2hex.cmake
Normal file
26
CMake/bin2hex.cmake
Normal file
@ -0,0 +1,26 @@
|
||||
# A small utility function which generates a C-header from an input file
|
||||
function(FILE_TO_C_STRING FILENAME VARIABLE_NAME OUTPUT_STRING ZERO_TERMINATED)
|
||||
FILE(READ "${FILENAME}" HEX_INPUT HEX)
|
||||
if (${ZERO_TERMINATED})
|
||||
string(APPEND HEX_INPUT "00")
|
||||
endif()
|
||||
|
||||
string(REGEX REPLACE "(....)" "\\1\n" HEX_OUTPUT ${HEX_INPUT})
|
||||
string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1," HEX_OUTPUT ${HEX_OUTPUT})
|
||||
|
||||
set(HEX_OUTPUT "static char const ${VARIABLE_NAME}[] = {\n ${HEX_OUTPUT}\n};\n")
|
||||
|
||||
set(${OUTPUT_STRING} "${HEX_OUTPUT}" PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
message("Create header file for ${FILE_IN}")
|
||||
message("Create header file for ${FILE_OUT}")
|
||||
file_to_c_string(${FILE_IN} ${VARIABLE_NAME} OUTPUT_STRING ZERO_TERMINATED)
|
||||
|
||||
set(RESULT "#pragma once\n")
|
||||
string(APPEND RESULT "namespace cutlass {\n")
|
||||
string(APPEND RESULT "namespace nvrtc {\n")
|
||||
string(APPEND RESULT "${OUTPUT_STRING}")
|
||||
string(APPEND RESULT "} // namespace nvrtc\n")
|
||||
string(APPEND RESULT "} // namespace cutlass\n")
|
||||
file(WRITE "${FILE_OUT}" "${RESULT}")
|
||||
182
CMakeLists.txt
Normal file
182
CMakeLists.txt
Normal file
@ -0,0 +1,182 @@
|
||||
# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
cmake_minimum_required(VERSION 3.3.0)
|
||||
|
||||
set(CUTLASS_LANGUAGES CXX)
|
||||
|
||||
# CMake 3.9.0 has native support for CUDA without the need of the CUDA package. Use it!
|
||||
if(WIN32 AND NOT ${CMAKE_VERSION} VERSION_LESS "3.9.0")
|
||||
list(APPEND CUTLASS_LANGUAGES CUDA)
|
||||
set(CUTLASS_NATIVE_CUDA TRUE)
|
||||
|
||||
macro(cutlass_add_executable)
|
||||
add_executable(${ARGN})
|
||||
endmacro()
|
||||
else()
|
||||
# FindCUDA fails to detect VS 2017 due to a changed directory format of the toolkits.
|
||||
# For this configuration we need CMake >= 3.9.0 to use the native CUDA support.
|
||||
if (WIN32 AND MSVC_VERSION GREATER 1800)
|
||||
message(FATAL_ERROR "Please upgrade CMake to version >= 3.9.0 to support Visual Studio 2017 or higher")
|
||||
endif()
|
||||
|
||||
# Fall back to the FindCUDA version to create an executable with CUDA files
|
||||
macro(cutlass_add_executable)
|
||||
cuda_add_executable(${ARGN})
|
||||
endmacro()
|
||||
endif()
|
||||
|
||||
project(CUTLASS ${CUTLASS_LANGUAGES})
|
||||
|
||||
# check if the configuration is supported
|
||||
if( NOT CMAKE_SIZEOF_VOID_P EQUAL 8 )
|
||||
message(FATAL_ERROR "CUTLASS requires a 64-bit compiler!")
|
||||
endif()
|
||||
|
||||
find_package(CUDA)
|
||||
find_package(Doxygen QUIET)
|
||||
|
||||
# By default we want to build in Release mode to ensure that we're getting best performance
|
||||
if (NOT (CMAKE_BUILD_TYPE OR CONFIGURATION_TYPES))
|
||||
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose build level" FORCE)
|
||||
# We do support Debug or Release builds
|
||||
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release")
|
||||
endif()
|
||||
|
||||
if(WIN32)
|
||||
# On Windows we link against the shared (DLL) runtime. Change gtest settings to match this.
|
||||
set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib even when Google Test is built as static lib" FORCE)
|
||||
endif()
|
||||
|
||||
if (WIN32)
|
||||
# Enable more warnings and treat as errors
|
||||
string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX")
|
||||
|
||||
# Disable excess x86 floating point precision that can lead to results being labeled incorrectly
|
||||
string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict")
|
||||
|
||||
# Verbose option
|
||||
if (${CUTLASS_NVCC_VERBOSE})
|
||||
string(APPEND NVCC_FLAGS " -v")
|
||||
endif()
|
||||
endif(WIN32)
|
||||
|
||||
# Configure CUDA options
|
||||
set(CUTLASS_NVCC_ARCHS "50;60;61;70" CACHE STRING "The SM architectures to build code for.")
|
||||
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
|
||||
|
||||
foreach(ARCH ${CUTLASS_NVCC_ARCHS})
|
||||
string(APPEND NVCC_FLAGS " -gencode arch=compute_${ARCH},code=sm_${ARCH}")
|
||||
endforeach()
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_KEEP)
|
||||
string(APPEND NVCC_FLAGS " -keep")
|
||||
endif()
|
||||
|
||||
if (WIN32 AND CUTLASS_NATIVE_CUDA)
|
||||
string(APPEND NVCC_FLAGS_RELEASE " -lineinfo")
|
||||
else()
|
||||
string(APPEND NVCC_FLAGS " -lineinfo")
|
||||
endif()
|
||||
|
||||
if (UNIX)
|
||||
string(APPEND NVCC_FLAGS " -Xcompiler -Wconversion")
|
||||
endif()
|
||||
|
||||
string(APPEND NVCC_FLAGS_DEBUG " -g")
|
||||
string(APPEND NVCC_FLAGS_RELEASE " -O3")
|
||||
|
||||
# define NDEBUG for release mode to disable assertions
|
||||
string(APPEND NVCC_FLAGS_RELEASE " -DNDEBUG")
|
||||
|
||||
if (CUTLASS_NATIVE_CUDA)
|
||||
set(CMAKE_CUDA_FLAGS "${NVCC_FLAGS}")
|
||||
set(CMAKE_CUDA_FLAGS_DEBUG "${NVCC_FLAGS_DEBUG}")
|
||||
set(CMAKE_CUDA_FLAGS_RELEASE "${NVCC_FLAGS_RELEASE}")
|
||||
else()
|
||||
set(CUDA_NVCC_FLAGS ${NVCC_FLAGS})
|
||||
set(CUDA_NVCC_FLAGS_DEBUG ${NVCC_FLAGS_DEBUG})
|
||||
set(CUDA_NVCC_FLAGS_RELEASE ${NVCC_FLAGS_RELEASE})
|
||||
endif()
|
||||
|
||||
#
|
||||
# The following items should eventually be pushed into cutlass/CMakeLists.txt
|
||||
#
|
||||
|
||||
# GLOB for CUTLASS header files. Should we use a static list instead?
|
||||
file(GLOB CUTLASS_GEMM RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/gemm/*.h)
|
||||
file(GLOB CUTLASS_UTIL RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/util/*.h)
|
||||
file(GLOB CUTLASS_DEVICE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/device/*.h)
|
||||
file(GLOB CUTLASS_CORE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/*.h)
|
||||
|
||||
source_group("cutlass\\gemm" FILES ${CUTLASS_GEMM})
|
||||
source_group("cutlass\\util" FILES ${CUTLASS_UTIL})
|
||||
source_group("cutlass\\device" FILES ${CUTLASS_DEVICE})
|
||||
source_group("cutlass" FILES ${CUTLASS_CORE})
|
||||
|
||||
add_library(CUTLASS INTERFACE)
|
||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
|
||||
target_sources(CUTLASS INTERFACE
|
||||
${CUTLASS_GEMM}
|
||||
${CUTLASS_UTIL}
|
||||
${CUTLASS_DEVICE}
|
||||
${CUTLASS_CORE}
|
||||
)
|
||||
|
||||
target_include_directories(CUTLASS INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
# Create a custom target to ensure that the CUTLASS sources are visible in an IDE
|
||||
add_custom_target(cutlass_ide SOURCES
|
||||
${CUTLASS_GEMM}
|
||||
${CUTLASS_UTIL}
|
||||
${CUTLASS_DEVICE}
|
||||
${CUTLASS_CORE}
|
||||
)
|
||||
# Doxygen is available. Generate documentation
|
||||
if (DOXYGEN_FOUND)
|
||||
# DOT is available. Enable graph generation in the documentation
|
||||
if (DOXYGEN_DOT_EXECUTABLE)
|
||||
set(CUTLASS_ENABLE_DOXYGEN_DOT ON CACHE BOOL "Use dot to generate graphs in the doxygen documentation.")
|
||||
else()
|
||||
set(CUTLASS_ENABLE_DOXYGEN_DOT OFF CACHE BOOL "Use dot to generate graphs in the doxygen documentation." FORCE)
|
||||
endif()
|
||||
|
||||
if (CUTLASS_ENABLE_DOXYGEN_DOT)
|
||||
set(HAVE_DOT "YES")
|
||||
else()
|
||||
set(HAVE_DOT "NO")
|
||||
endif()
|
||||
|
||||
# Add custom target for Doxygen.
|
||||
add_custom_target(cutlass_docs ${CMAKE_COMMAND} -E env
|
||||
"DOT_PATH=${DOXYGEN_DOT_EXECUTABLE}"
|
||||
"HAVE_DOT=${HAVE_DOT}"
|
||||
${DOXYGEN_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/Doxyfile
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
VERBATIM
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
#add_subdirectory(examples/gemm)
|
||||
add_subdirectory(tools)
|
||||
15
Doxyfile
15
Doxyfile
@ -58,7 +58,7 @@ PROJECT_LOGO =
|
||||
# entered, it will be relative to the location where doxygen was started. If
|
||||
# left blank the current directory will be used.
|
||||
|
||||
OUTPUT_DIRECTORY = doxygen
|
||||
OUTPUT_DIRECTORY = docs
|
||||
|
||||
# If the CREATE_SUBDIRS tag is set to YES, then doxygen will create 4096 sub-
|
||||
# directories (in 2 levels) under the output directory of each output format and
|
||||
@ -218,7 +218,8 @@ TAB_SIZE = 4
|
||||
# "Side Effects:". You can put \n's in the value part of an alias to insert
|
||||
# newlines.
|
||||
|
||||
ALIASES =
|
||||
#ALIASES += "concept{1}=@ingroup \1\n@par Implemented concepts:\n@ref \1"
|
||||
ALIASES += "concept{1}=@ingroup \1"
|
||||
|
||||
# This tag can be used to specify a number of word-keyword mappings (TCL only).
|
||||
# A mapping has the form "name=value". For example adding "class=itcl::class"
|
||||
@ -396,7 +397,7 @@ LOOKUP_CACHE_SIZE = 0
|
||||
# normally produced when WARNINGS is set to YES.
|
||||
# The default value is: NO.
|
||||
|
||||
EXTRACT_ALL = NO
|
||||
EXTRACT_ALL = YES
|
||||
|
||||
# If the EXTRACT_PRIVATE tag is set to YES all private members of a class will
|
||||
# be included in the documentation.
|
||||
@ -733,7 +734,7 @@ WARN_LOGFILE =
|
||||
# spaces.
|
||||
# Note: If this tag is empty the current directory is searched.
|
||||
|
||||
INPUT = cutlass cutlass/gemm cutlass/util
|
||||
INPUT = cutlass
|
||||
|
||||
# This tag can be used to specify the character encoding of the source files
|
||||
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
|
||||
@ -759,7 +760,7 @@ FILE_PATTERNS =
|
||||
# be searched for input files as well.
|
||||
# The default value is: NO.
|
||||
|
||||
RECURSIVE = NO
|
||||
RECURSIVE = YES
|
||||
|
||||
# The EXCLUDE tag can be used to specify files and/or directories that should be
|
||||
# excluded from the INPUT source files. This way you can easily exclude a
|
||||
@ -2032,7 +2033,7 @@ HIDE_UNDOC_RELATIONS = YES
|
||||
# set to NO
|
||||
# The default value is: NO.
|
||||
|
||||
HAVE_DOT = NO
|
||||
HAVE_DOT = $(HAVE_DOT)
|
||||
|
||||
# The DOT_NUM_THREADS specifies the number of dot invocations doxygen is allowed
|
||||
# to run in parallel. When set to 0 doxygen will base this on the number of
|
||||
@ -2204,7 +2205,7 @@ INTERACTIVE_SVG = NO
|
||||
# found. If left blank, it is assumed the dot tool can be found in the path.
|
||||
# This tag requires that the tag HAVE_DOT is set to YES.
|
||||
|
||||
DOT_PATH =
|
||||
DOT_PATH = $(DOT_PATH)
|
||||
|
||||
# The DOTFILE_DIRS tag can be used to specify one or more directories that
|
||||
# contain dot files that are included in the documentation (see the \dotfile
|
||||
|
||||
285
README.md
285
README.md
@ -1,106 +1,221 @@
|
||||

|
||||

|
||||
|
||||
# Introduction
|
||||
# CUTLASS 1.0
|
||||
|
||||
CUTLASS is a collection of CUDA C++ template abstractions for implementing
|
||||
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
|
||||
It incorporates strategies for hierarchical decomposition and data movement similar
|
||||
to those used to implement cuBLAS. CUTLASS decomposes these "moving parts" into
|
||||
_CUTLASS 1.0.1 - June 2018_
|
||||
|
||||
CUTLASS 1.0 is a collection of CUDA C++ template abstractions for implementing
|
||||
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
|
||||
It incorporates strategies for hierarchical decomposition and data movement similar
|
||||
to those used to implement cuBLAS. CUTLASS decomposes these "moving parts" into
|
||||
reusable, modular software components abstracted by C++ template classes. These
|
||||
thread-wide, warp-wide, block-wide, and device-wide primitives can be specialized
|
||||
and tuned via custom tiling sizes, data types, and other algorithmic policy. The
|
||||
resulting flexibility simplifies their use as building blocks within custom kernels
|
||||
thread-wide, warp-wide, block-wide, and device-wide primitives can be specialized
|
||||
and tuned via custom tiling sizes, data types, and other algorithmic policy. The
|
||||
resulting flexibility simplifies their use as building blocks within custom kernels
|
||||
and applications.
|
||||
|
||||
To support a wide variety of applications, CUTLASS provides extensive support for
|
||||
mixed-precision computations, providing specialized data-movement and
|
||||
multiply-accumulate abstractions for 8-bit integer, half-precision floating
|
||||
point (FP16), single-precision floating point (FP32), and double-precision floating
|
||||
mixed-precision computations, providing specialized data-movement and
|
||||
multiply-accumulate abstractions for 8-bit integer, half-precision floating
|
||||
point (FP16), single-precision floating point (FP32), and double-precision floating
|
||||
point (FP64) types. Furthermore, CUTLASS demonstrates CUDA's WMMA API for targeting
|
||||
the programmable, high-throughput _Tensor Cores_ provided by NVIDIA's Volta architecture
|
||||
the programmable, high-throughput _Tensor Cores_ provided by NVIDIA's Volta architecture
|
||||
and beyond.
|
||||
|
||||
For more exposition, see our Parallel Forall blog post [CUTLASS: Fast Linear Algebra
|
||||
in CUDA C++](https://devblogs.nvidia.com/parallelforall/cutlass-linear-algebra-cuda).
|
||||
CUTLASS 1.0 has changed substantially from our preview release described in
|
||||
the [CUTLASS Parallel For All](https://devblogs.nvidia.com/parallelforall/cutlass-linear-algebra-cuda)
|
||||
post. We have decomposed the structure of the GEMM computation into deeper, structured
|
||||
primitives for loading data, computing predicate masks, streaming data at each level of
|
||||
the GEMM hierarchy, and updating the output matrix.
|
||||
|
||||
CUTLASS 1.0 is described in the [Doxygen documentation](https://nvidia.github.io/cutlass)
|
||||
and our talk at the [GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
|
||||
|
||||
# Performance
|
||||
|
||||
<p align="center"><img src=/media/cutlass-performance-plot.png></p>
|
||||
<p align="center"><img src=/media/images/cutlass-performance-plot.png></p>
|
||||
|
||||
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
|
||||
they exhibit performance comparable to cuBLAS for scalar GEMM
|
||||
computations. The above figure shows CUTLASS performance relative to cuBLAS
|
||||
for large matrix dimensions (M=10240, N=K=4096) running on an NVIDIA Titan V GPU
|
||||
when compiled with CUDA 9.2.
|
||||
|
||||
# Compatibility
|
||||
|
||||
CUTLASS requires CUDA 9 and performs best with [CUDA 9.2 Toolkit](ttps://developer.nvidia.com/cuda-toolkit) or later.
|
||||
|
||||
|**Operating System** | **Compiler** |
|
||||
|-----------------|----------|
|
||||
| Windows 10 | Microsoft Visual Studio 2015|
|
||||
| | Microsoft Visual Studio 2017|
|
||||
| Ubuntu 14.04 | GCC 4.8.2 |
|
||||
| Ubuntu 16.04 | GCC 5.4.0 |
|
||||
|
||||
|
||||
CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
|
||||
any Maxwell-, Pascal-, or Volta-architecture NVIDIA GPU.
|
||||
|
||||
|**GPU**|
|
||||
|---|
|
||||
|NVIDIA GeForce 1080|
|
||||
|NVIDIA TitanXP|
|
||||
|NVIDIA Tesla P100|
|
||||
|NVIDIA Tesla V100|
|
||||
|NVIDIA TitanV|
|
||||
|
||||
|
||||
# Building CUTLASS
|
||||
|
||||
CUTLASS is a header-only template library and does not need to be built to be used by other
|
||||
projects. However, we distribute extensive unit tests and utility programs to demonstrate
|
||||
CUTLASS. These instructions are for building those test programs.
|
||||
|
||||
CUTLASS's unit tests depend on Google Test which exists as a git submodule. You can fetch
|
||||
submodules as follows.
|
||||
|
||||
```
|
||||
$ git submodule update --init --recursive
|
||||
```
|
||||
|
||||
CUTLASS can be build with CMake starting version 3.10. By default CUTLASS will build kernels
|
||||
for CUDA architecture versions 5.0, 6.0, 6.1 and 7.0. To reduce compile time you can specify
|
||||
the architectures to build CUTLASS for by changing the CMake configuration setting
|
||||
`CUTLASS_NVCC_ARCHS`.
|
||||
|
||||
Create a build directory within the CUTLASS project, then run CMake once.
|
||||
|
||||
```
|
||||
$ mkdir build && cd build
|
||||
$ cmake ..
|
||||
```
|
||||
|
||||
Compile the CUTLASS project by running Make. Include the -j argument to compile sources in
|
||||
parallel and speed up the build process.
|
||||
|
||||
```
|
||||
$ make -j12
|
||||
...
|
||||
$
|
||||
```
|
||||
|
||||
Verify CUTLASS has been built correctly by running the unit tests from the build/ directory.
|
||||
|
||||
```
|
||||
$ ./tools/test/unit/cutlass_unit_test
|
||||
...
|
||||
...
|
||||
...
|
||||
[----------] Global test environment tear-down
|
||||
[==========] 481 tests from 24 test cases ran. (5954 ms total)
|
||||
[ PASSED ] 481 tests.
|
||||
```
|
||||
|
||||
All tests should pass, though the exact number of tests may vary over time.
|
||||
|
||||
CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
|
||||
they exhibit performance comparable to cuBLAS for scalar GEMM
|
||||
computations. The above figure shows CUTLASS performance relative to cuBLAS
|
||||
for large matrix dimensions (M=10240, N=K=4096) running on an NVIDIA Tesla V100 GPU
|
||||
when compiled with CUDA 9.0.
|
||||
|
||||
# Project Structure
|
||||
|
||||
CUTLASS is arranged as a header-only library with several example test programs
|
||||
that demonstrate instantiating a GEMM task within a CUDA kernel. Comments inline
|
||||
with the source explain the individual components.
|
||||
that demonstrate instantiating a GEMM task within a CUDA kernel. The Doxygen documentation
|
||||
provides a complete list of files, classes, and template concepts defined in the CUTLASS
|
||||
project. A brief summary is described below.
|
||||
|
||||
The repository is organized in the following arrangement.
|
||||
|
||||
cutlass/ Root of header-only source library for matrix multiply
|
||||
gemm/ Implementation of GEMM __device__ code and supporting components
|
||||
util/ Utility components for CUDA device-side CUDA development
|
||||
|
||||
A test program is provided to illustrate the use of CUTLASS. This is implemented
|
||||
in the following directory.
|
||||
|
||||
cutlass_test Root of test programs depicting CUTLASS kernels
|
||||
util/ Utilities
|
||||
gemm.cu Simple example calling CUTLASS and CUBLAS GEMM kernels
|
||||
Makefile Build script for test programs
|
||||
|
||||
|
||||
# Makefile usage
|
||||
|
||||
There are different sample targets for different GEMM data types and
|
||||
transposititions. Be sure to specify your target architecture.
|
||||
|
||||
make <sgemm|dgemm|hgemm|igemm|wgemm> sm=<60|61|70> \
|
||||
[transpose=<nn|nt|tn|tt>] [verbose=<0|1>] [keep=<0|1>]
|
||||
|
||||
|
||||
# Program usage
|
||||
|
||||
Program usage:
|
||||
|
||||
<s|d|h|i|w>gemm_<nn|nt|tn|tt>
|
||||
[--help]
|
||||
[--schmoo=<#schmoo-samples> || --m=<height> --n=<width> --k=<depth>]
|
||||
[--i=<timing iterations>]
|
||||
[--device=<device-id>]
|
||||
[--alpha=<alpha> --beta=<beta>]
|
||||
|
||||
|
||||
# Open Source License
|
||||
|
||||
CUTLASS is released by NVIDIA Corporation under the "New BSD" open-source license:
|
||||
The CUTLASS library is defined in the cutlass/ directory and consists of CUDA C++ template
|
||||
classes and other definitions for implementing efficient GPU GEMM kernels. A set of core
|
||||
classes and templates define basic primitives that are then applied to compute GEMM via
|
||||
templates in the cutlass/gemm directory.
|
||||
|
||||
```
|
||||
Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
* Neither the name of the NVIDIA CORPORATION nor the
|
||||
names of its contributors may be used to endorse or promote products
|
||||
derived from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
cutlass/
|
||||
gemm/
|
||||
util/
|
||||
<core API components>
|
||||
```
|
||||
|
||||
Several tools and test programs are also distributed with the CUTLASS library. They are
|
||||
contained in the following directories.
|
||||
|
||||
```
|
||||
tools/
|
||||
test/
|
||||
unit/
|
||||
core/
|
||||
gemm/
|
||||
perf/
|
||||
util/
|
||||
<utilities>
|
||||
```
|
||||
|
||||
The `test/unit/` directory consist of unit tests implemented with Google Test that demonstrate
|
||||
basic usage of Core API components and complete tests of the CUTLASS GEMM computations.
|
||||
|
||||
# Performance Profiling
|
||||
|
||||
The `test/perf/` directory contains a command-line utility for launching each of the GEMM kernels.
|
||||
Its usage is shown below.
|
||||
|
||||
Program usage:
|
||||
|
||||
```
|
||||
cutlass_perf_test [options]
|
||||
|
||||
--help
|
||||
--append=<true|false*> If true, appends output to existing CSV file. If false, overwrites.
|
||||
--alpha=<alpha> Value for alpha to be used in GEMM experiments
|
||||
--beta=<beta> Value for beta to be used in GEMM experiments
|
||||
--dist=<distribution> Describes the random distribution of each of the input matrix operands.
|
||||
--execution_mode=<mode> Specifies execution mode: profile, verify, single
|
||||
--output=<filename.csv> Writes summary of profiling to specified .csv file
|
||||
--iterations=<timing iterations> maximum number of iterations to execute when profiling
|
||||
--m=<height>[:max height[:step]] Height of GEMM problem (number of rows of C). May specify a range with optional step size.
|
||||
--n=<width>[:max width[:step]] Width of GEMM problem (number of columns of C). May specify a range with optional step size.
|
||||
--k=<depth>[:max depth[:step]] Size of inner dimension of A and B. May specify a range with optional step size.
|
||||
--kernels=<{s|d|h|i|wmma_}gemm_{nn,nt,tn,tt}> Select GEMM datatype and layout to use for tests
|
||||
--peak=<bool> If true, only reports peak performance per kernel after profiling specified problem space.
|
||||
--save_workspace={*never,incorrect,always} Specifies when to save the GEMM inputs and results to the filesystem.
|
||||
--seed=<seed> Random seed used by the random number generator in initializing input matrices.
|
||||
--tags=<column:tag,...> Inserts leading columns in output table and uniform values for each column.
|
||||
|
||||
|
||||
Example usage:
|
||||
|
||||
# Runs one problem size for all kernels
|
||||
$ ./tools/test/perf/cutlass_perf_test --m=10240 --n=1024 --k=1024
|
||||
|
||||
# Varies GEMM K dimension for SGEMM and IGEMM with column-major multiplicands
|
||||
$ ./tools/test/perf/cutlass_perf_test --m=10240 --n=4096 --k=1024:8192:128 --kernels=sgemm_nn,igemm_nn
|
||||
```
|
||||
|
||||
# About
|
||||
|
||||
CUTLASS is released by NVIDIA Corporation as Open Source software under the
|
||||
3-clause "New" BSD license.
|
||||
|
||||
|
||||
# Copyright
|
||||
|
||||
Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright notice, this list of
|
||||
conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
conditions and the following disclaimer in the documentation and/or other materials
|
||||
provided with the distribution.
|
||||
* Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
to endorse or promote products derived from this software without specific prior written
|
||||
permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
|
||||
47
changelog.md
Normal file
47
changelog.md
Normal file
@ -0,0 +1,47 @@
|
||||
# NVIDIA CUTLASS Changelog
|
||||
|
||||
## [1.0.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.0.1) (2018-06-11)
|
||||
|
||||
* Intra-threadblock reduction added for small threadblock tile sizes
|
||||
* sgemm_64x128x16, sgemm_128x128x16, sgemm_128x64x16, sgemm_128x32x16, sgemm_64x64x16, sgemm_64x32x16
|
||||
* igemm_32x32x128
|
||||
* GEMM _K_ residue handled during prologue prior to mainloop
|
||||
* Replaced Google Test copy with submodule. Use `git submodule init`
|
||||
|
||||
## [1.0.0](https://github.com/NVIDIA/cutlass/commit/2028ebe120aab22bfd0b2baf8902d4c9627eb33f) (2018-05-16)
|
||||
|
||||
* Substantial rewrite to accommodate new architecture
|
||||
* Kernels: SGEMM, DGEMM, IGEMM, HGEMM, WMMA GEMM
|
||||
* Unit and performance tests
|
||||
|
||||
## [0.0.1](https://github.com/NVIDIA/cutlass/commit/d08ba8ac46e2fa3f745e070c390182edb56b2e91) (2017-12-04)
|
||||
|
||||
* Initial release
|
||||
|
||||
|
||||
## Copyright
|
||||
|
||||
Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
```
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
provided that the following conditions are met:
|
||||
* Redistributions of source code must retain the above copyright notice, this list of
|
||||
conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
conditions and the following disclaimer in the documentation and/or other materials
|
||||
provided with the distribution.
|
||||
* Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
to endorse or promote products derived from this software without specific prior written
|
||||
permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
```
|
||||
|
||||
17
clang-format.sh
Executable file
17
clang-format.sh
Executable file
@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
function formatFiles {
|
||||
for f in `find "$1" -type f -name "*.$2"` ; do
|
||||
COMMAND="clang-format -i $f"
|
||||
echo $COMMAND
|
||||
$COMMAND
|
||||
done
|
||||
}
|
||||
|
||||
formatFiles "cutlass" "h"
|
||||
formatFiles "tools/test" "h"
|
||||
formatFiles "tools/test" "cpp"
|
||||
formatFiles "tools/util" "h"
|
||||
|
||||
181
common.mk
181
common.mk
@ -1,181 +0,0 @@
|
||||
#/******************************************************************************
|
||||
# * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
# *
|
||||
# * Redistribution and use in source and binary forms, with or without
|
||||
# * modification, are permitted provided that the following conditions are met:
|
||||
# * * Redistributions of source code must retain the above copyright
|
||||
# * notice, this list of conditions and the following disclaimer.
|
||||
# * * Redistributions in binary form must reproduce the above copyright
|
||||
# * notice, this list of conditions and the following disclaimer in the
|
||||
# * documentation and/or other materials provided with the distribution.
|
||||
# * * Neither the name of the NVIDIA CORPORATION nor the
|
||||
# * names of its contributors may be used to endorse or promote products
|
||||
# * derived from this software without specific prior written permission.
|
||||
# *
|
||||
# * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
# * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
# * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
# * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
# * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
# * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
# * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
# * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
# *
|
||||
#******************************************************************************/
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Commandline Options
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# sm=<XX,...> Compute-capability to compile for, e.g., "sm=200,300,350" (SM2.0 by default).
|
||||
|
||||
COMMA := ,
|
||||
ifdef sm
|
||||
SM_ARCH := $(subst $(COMMA),-,$(sm))
|
||||
else
|
||||
$(error Please specify SM architecture makefile argument: "sm=XX")
|
||||
endif
|
||||
|
||||
ifeq (70, $(findstring 70, $(SM_ARCH)))
|
||||
SM_TARGETS += -gencode=arch=compute_70,code=\"sm_70,compute_70\"
|
||||
CLANG_SM_TARGETS += --cuda-gpu-arch=sm_70
|
||||
endif
|
||||
ifeq (62, $(findstring 62, $(SM_ARCH)))
|
||||
SM_TARGETS += -gencode=arch=compute_62,code=\"sm_62,compute_62\"
|
||||
CLANG_SM_TARGETS += --cuda-gpu-arch=sm_62
|
||||
endif
|
||||
ifeq (61, $(findstring 61, $(SM_ARCH)))
|
||||
SM_TARGETS += -gencode=arch=compute_61,code=\"sm_61,compute_61\"
|
||||
CLANG_SM_TARGETS += --cuda-gpu-arch=sm_61
|
||||
endif
|
||||
ifeq (60, $(findstring 60, $(SM_ARCH)))
|
||||
SM_TARGETS += -gencode=arch=compute_60,code=\"sm_60,compute_60\"
|
||||
CLANG_SM_TARGETS += --cuda-gpu-arch=sm_60
|
||||
endif
|
||||
ifeq (52, $(findstring 52, $(SM_ARCH)))
|
||||
SM_TARGETS += -gencode=arch=compute_52,code=\"sm_52,compute_52\"
|
||||
CLANG_SM_TARGETS += --cuda-gpu-arch=sm_52
|
||||
endif
|
||||
ifeq (37, $(findstring 37, $(SM_ARCH)))
|
||||
SM_TARGETS += -gencode=arch=compute_37,code=\"sm_37,compute_37\"
|
||||
CLANG_SM_TARGETS += --cuda-gpu-arch=sm_37
|
||||
endif
|
||||
ifeq (35, $(findstring 35, $(SM_ARCH)))
|
||||
SM_TARGETS += -gencode=arch=compute_35,code=\"sm_35,compute_35\"
|
||||
CLANG_SM_TARGETS += --cuda-gpu-arch=sm_35
|
||||
endif
|
||||
ifeq (30, $(findstring 30, $(SM_ARCH)))
|
||||
SM_TARGETS += -gencode=arch=compute_30,code=\"sm_30,compute_30\"
|
||||
CLANG_SM_TARGETS += --cuda-gpu-arch=sm_30
|
||||
endif
|
||||
ifeq (21, $(findstring 21, $(SM_ARCH)))
|
||||
SM_TARGETS += -gencode=arch=compute_20,code=\"sm_21,compute_20\"
|
||||
CLANG_SM_TARGETS += --cuda-gpu-arch=sm_21
|
||||
endif
|
||||
ifeq (20, $(findstring 20, $(SM_ARCH)))
|
||||
SM_TARGETS += -gencode=arch=compute_20,code=\"sm_20,compute_20\"
|
||||
CLANG_SM_TARGETS += --cuda-gpu-arch=sm_20
|
||||
endif
|
||||
|
||||
|
||||
# [verbose=<0|1>] Verbose toolchain output from nvcc option
|
||||
ifeq ($(verbose), 1)
|
||||
NVCCFLAGS += -v
|
||||
CLANG_CFLAGS += -v
|
||||
endif
|
||||
|
||||
|
||||
# [keep=<0|1>] Keep intermediate compilation artifacts option
|
||||
ifeq ($(keep), 1)
|
||||
NVCCFLAGS += -keep
|
||||
CLANG_CFLAGS += --save-temps
|
||||
endif
|
||||
|
||||
|
||||
# [debug=<0|1>] Generate debug mode code
|
||||
ifeq ($(debug), 1)
|
||||
NVCCFLAGS += -G
|
||||
CLANG_CFLAGS += --cuda-noopt-device-debug
|
||||
endif
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Compiler and compilation platform
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
BASE_DIR := $(dir $(lastword $(MAKEFILE_LIST)))
|
||||
|
||||
NVCC := "$(shell which nvcc)"
|
||||
ifdef nvccver
|
||||
NVCC_VERSION := $(nvccver)
|
||||
else
|
||||
NVCC_VERSION := $(strip $(shell nvcc --version | grep release | sed 's/.*release //' | sed 's/,.*//'))
|
||||
endif
|
||||
|
||||
# Detect OS
|
||||
OSUPPER := $(shell uname -s 2>/dev/null | tr [:lower:] [:upper:])
|
||||
|
||||
# Default flags: verbose kernel properties (regs, smem, cmem, etc.); runtimes for compilation phases
|
||||
NVCCFLAGS += -O3 -Xptxas -v
|
||||
CLANG_CFLAGS += -O3 -Xcuda-ptxas -v
|
||||
ifeq (WIN_NT, $(findstring WIN_NT, $(OSUPPER)))
|
||||
# For MSVC
|
||||
|
||||
# Enable more warnings and treat as errors
|
||||
NVCCFLAGS += -Xcompiler /W3 -Xcompiler /WX
|
||||
|
||||
# Disable excess x86 floating point precision that can lead to results being labeled incorrectly
|
||||
NVCCFLAGS += -Xcompiler /fp:strict
|
||||
|
||||
# Compiler
|
||||
CC := cl
|
||||
|
||||
# Multithreaded runtime
|
||||
NVCCFLAGS += -Xcompiler /MT
|
||||
|
||||
CUDART_CYG := "$(shell dirname $(NVCC))/../lib/x64/cudart.lib"
|
||||
CUDART := "$(shell cygpath -w $(CUDART_CYG))"
|
||||
|
||||
else
|
||||
# For g++
|
||||
|
||||
# Disable excess x86 floating point precision that can lead to results being labeled incorrectly
|
||||
#NVCCFLAGS += -Xcompiler -ffloat-store
|
||||
|
||||
# Compiler
|
||||
CC := g++
|
||||
|
||||
CUDART := "$(shell dirname $(NVCC))/../lib64/libcudart_static.a"
|
||||
|
||||
endif
|
||||
|
||||
# compiler=clang Enables compilation with clang.
|
||||
|
||||
ifeq ($(compiler), clang)
|
||||
# NVCC_VERSION is used as the proxy for the CUDA version.
|
||||
BIN_SUFFIX := sm$(SM_ARCH)_clang_cuda_$(NVCC_VERSION)
|
||||
# Clangs needs few extra flags to point it to CUDA SDK
|
||||
# and link the binaries with CUDA runtime.
|
||||
CUDA_BASE=$(realpath $(join $(dir $(shell which nvcc)), ..))
|
||||
CLANG_CFLAGS += --cuda-path=$(CUDA_BASE)
|
||||
LIBINC += -L$(CUDA_BASE)/lib64 -Wl,-rpath=$(CUDA_BASE)/lib64
|
||||
LIBS += -lcudart
|
||||
|
||||
# Replace NVCC and its options with clang++.
|
||||
NVCC = clang++
|
||||
NVCCFLAGS = $(CLANG_CFLAGS)
|
||||
SM_TARGETS = $(CLANG_SM_TARGETS)
|
||||
else
|
||||
# Suffix to append to each binary
|
||||
BIN_SUFFIX := sm$(SM_ARCH)_nvcc_$(NVCC_VERSION)
|
||||
endif
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Function for computing dependency Lists
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
rwildcard=$(foreach d,$(wildcard $1*),$(call rwildcard,$d/,$2) $(filter $(subst *,%,$2),$d))
|
||||
102
cutlass/convert.h
Normal file
102
cutlass/convert.h
Normal file
@ -0,0 +1,102 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief Defines conversion operations among Fragments of different base type.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename InputFragment_, typename OutputFragment_>
|
||||
struct Convert {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename InputScalar_, typename OutputScalar_, int kScalars_>
|
||||
struct Convert<Fragment<InputScalar_, kScalars_>, Fragment<OutputScalar_, kScalars_> > {
|
||||
/// The input fragment.
|
||||
typedef Fragment<InputScalar_, kScalars_> InputFragment;
|
||||
/// The output fragment.
|
||||
typedef Fragment<OutputScalar_, kScalars_> OutputFragment;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE Convert() {}
|
||||
|
||||
/// Transform a fragment.
|
||||
CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
|
||||
transform(src, 0, dst);
|
||||
}
|
||||
|
||||
/// Transform a fragment.
|
||||
template <typename Fragment_>
|
||||
CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
|
||||
for (int i = 0; i < kScalars_; ++i) {
|
||||
dst[i] = static_cast<OutputScalar_>(src[i + offset]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Fragment_>
|
||||
struct Copy {
|
||||
/// The input fragment.
|
||||
typedef Fragment_ InputFragment;
|
||||
/// The output fragment.
|
||||
typedef Fragment_ OutputFragment;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE Copy() {}
|
||||
|
||||
/// Transform a fragment.
|
||||
CUTLASS_DEVICE void transform(Fragment_ const& src, Fragment_& dst) { transform(src, 0, dst); }
|
||||
|
||||
/// Transform a fragment.
|
||||
template <typename InputFragment_>
|
||||
CUTLASS_DEVICE void transform(InputFragment_ const& src, int offset, Fragment_& dst) {
|
||||
if (sizeof(typename Fragment_::Element) == 8) {
|
||||
uint64_t const* src_ptr = reinterpret_cast<uint64_t const*>(&src[offset]);
|
||||
uint64_t* dst_ptr = reinterpret_cast<uint64_t*>(&dst[0]);
|
||||
for (int i = 0; i < sizeof(Fragment_) / 8; ++i) {
|
||||
dst_ptr[i] = src_ptr[i];
|
||||
}
|
||||
} else {
|
||||
uint32_t const* src_ptr = reinterpret_cast<uint32_t const*>(&src[offset]);
|
||||
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&dst[0]);
|
||||
for (int i = 0; i < sizeof(Fragment_) / 4; ++i) {
|
||||
dst_ptr[i] = src_ptr[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
287
cutlass/coord.h
Normal file
287
cutlass/coord.h
Normal file
@ -0,0 +1,287 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief A Coord is a coordinate of arbitrary rank into a tensor or matrix
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Describes identity elements
|
||||
struct Identity {
|
||||
/// Enumeration describing identity elements. Value assignments are significant.
|
||||
/// Feel free to add or multiply by these, respectively.
|
||||
enum Kind { Additive = 0, Multiplicative = 1 };
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Statically-sized array specifying Coords within a tensor
|
||||
template <int N_>
|
||||
struct Coord {
|
||||
//
|
||||
// Type and constant definitions
|
||||
//
|
||||
|
||||
static int const N = N_;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Indices
|
||||
int idx[N];
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor initializes uniformly
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord(int value = 0) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
idx[i] = value;
|
||||
}
|
||||
}
|
||||
|
||||
/// Constructs from an array of integers
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord(int _idx[]) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
idx[i] = _idx[i];
|
||||
}
|
||||
}
|
||||
|
||||
/// Element-wise addition
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord operator+(Coord const& b) const {
|
||||
Coord c;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c.idx[i] = idx[i] + b.idx[i];
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
/// Element-wise subtraction
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord operator-(Coord const& b) const {
|
||||
Coord c;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c.idx[i] = idx[i] - b.idx[i];
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
/// Element-wise multiplication
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord operator*(Coord const& b) const {
|
||||
Coord c;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c.idx[i] = idx[i] * b.idx[i];
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
/// Element-wise division
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord operator/(Coord const& b) const {
|
||||
Coord c;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
c.idx[i] = idx[i] / b.idx[i];
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
/// In-place addition
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& operator+=(Coord const& b) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
idx[i] += b.idx[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place subtraction
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& operator-=(Coord const& b) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
idx[i] -= b.idx[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place multiplication
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& operator*=(Coord const& b) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
idx[i] *= b.idx[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place division
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& operator/=(Coord const& b) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
idx[i] /= b.idx[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Member access operator
|
||||
CUTLASS_HOST_DEVICE int& operator[](int dim) { return idx[dim]; }
|
||||
|
||||
/// Member access operator
|
||||
CUTLASS_HOST_DEVICE int const& operator[](int dim) const { return idx[dim]; }
|
||||
|
||||
/// Computes the dot product of two Coord instances
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE T dot(Coord const& b, T sum) const {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
sum += idx[i] * b.idx[i];
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
/// Computes the dot product of two Coord instances
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE T dot(Coord const& b) const {
|
||||
T sum = T(0);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
sum += idx[i] * b.idx[i];
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
/// Gets the index of a given Coord element
|
||||
template <int Dim>
|
||||
CUTLASS_HOST_DEVICE int& at() {
|
||||
return idx[Dim];
|
||||
}
|
||||
|
||||
/// Access via index; may limit unrolling potential
|
||||
CUTLASS_HOST_DEVICE
|
||||
int& at(int dim) { return idx[dim]; }
|
||||
|
||||
/// Gets the index of a given Coord element
|
||||
template <int Dim>
|
||||
CUTLASS_HOST_DEVICE int const& at() const {
|
||||
return idx[Dim];
|
||||
}
|
||||
|
||||
/// Access via index; may limit unrolling potential
|
||||
CUTLASS_HOST_DEVICE
|
||||
int const& at(int dim) const { return idx[dim]; }
|
||||
|
||||
/// Determines if two Coord<> objects are equal
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator==(Coord<N> const& b) const {
|
||||
bool equal = true;
|
||||
for (int i = 0; equal && i < N; ++i) {
|
||||
equal = (idx[i] == b.idx[i]);
|
||||
}
|
||||
return equal;
|
||||
}
|
||||
|
||||
/// Not equal
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator!=(Coord<N> const& b) const { return !(*this == b); }
|
||||
|
||||
/// Clamps a coordinate to a range specified by maximum and minimum values
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& clamp(Coord<N> const& max, Coord<N> const& min = Coord<N>()) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
idx[i] = __NV_STD_MAX(__NV_STD_MIN(idx[i], max.idx[i]), min.idx[i]);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Returns the product of all elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int count() const {
|
||||
int product = idx[0];
|
||||
for (int i = 1; i < N; ++i) {
|
||||
product *= idx[i];
|
||||
}
|
||||
return product;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to make a 2-element coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<1> make_Coord(int _0) {
|
||||
int values[1] = {_0};
|
||||
return Coord<1>(values);
|
||||
}
|
||||
|
||||
/// Helper to make a 2-element coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> make_Coord(int _0, int _1) {
|
||||
int values[2] = {_0, _1};
|
||||
return Coord<2>(values);
|
||||
}
|
||||
|
||||
/// Helper to make a 3-element coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<3> make_Coord(int _0, int _1, int _2) {
|
||||
int values[3] = {_0, _1, _2};
|
||||
return Coord<3>(values);
|
||||
}
|
||||
|
||||
/// Helper to make a 4-element coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> make_Coord(int _0, int _1, int _2, int _3) {
|
||||
int values[4] = {_0, _1, _2, _3};
|
||||
return Coord<4>(values);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Getter
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> get_Coord_hw(Coord<3> const& coord) { return make_Coord(coord[1], coord[2]); }
|
||||
|
||||
/// Getter
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> get_Coord_hw(Coord<4> const& coord) { return make_Coord(coord[1], coord[2]); }
|
||||
|
||||
/// Getter
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<3> get_Coord_hwc(Coord<4> const& coord) { return make_Coord(coord[1], coord[2], coord[3]); }
|
||||
|
||||
/// Getter
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<3> get_Coord_dhw(Coord<4> const& coord) { return make_Coord(coord[0], coord[1], coord[2]); }
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
44
cutlass/core_io.h
Normal file
44
cutlass/core_io.h
Normal file
@ -0,0 +1,44 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
/*! \file
|
||||
\brief Helpers for printing cutlass/core objects
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iosfwd>
|
||||
#include <typeinfo>
|
||||
|
||||
#include <cutlass/coord.h>
|
||||
|
||||
template <int Rank>
|
||||
std::ostream& operator<<(std::ostream& out, cutlass::Coord<Rank> const& coord) {
|
||||
for (int i = 0; i < Rank; ++i) {
|
||||
out << (i ? ", " : "") << coord.idx[i];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
73
cutlass/cutlass.h
Normal file
73
cutlass/cutlass.h
Normal file
@ -0,0 +1,73 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Basic include for CUTLASS macros
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define CUTLASS_MAJOR 1
|
||||
#define CUTLASS_MINOR 0
|
||||
#define CUTLASS_PATCH 1
|
||||
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
|
||||
|
||||
#ifdef __NVCC__
|
||||
#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__
|
||||
#define CUTLASS_DEVICE __forceinline__ __device__
|
||||
#elif defined(__CUDACC_RTC__)
|
||||
#define CUTLASS_HOST_DEVICE __forceinline__ __device__
|
||||
#define CUTLASS_DEVICE __forceinline__ __device__
|
||||
#else
|
||||
#define CUTLASS_HOST_DEVICE
|
||||
// CUTLASS_DEVICE is an error if not compiling device code
|
||||
#endif
|
||||
|
||||
// CUTLASS_PRAGMA_UNROLL inserts a CUTLASS_PRAGMA_UNROLL if supported by the compiler
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#if defined(_MSC_VER)
|
||||
#define CUTLASS_PRAGMA_UNROLL __pragma("unroll")
|
||||
#define CUTLASS_PRAGMA_NO_UNROLL __pragma("unroll 1")
|
||||
#else
|
||||
#define CUTLASS_PRAGMA_UNROLL _Pragma("unroll")
|
||||
#define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1")
|
||||
#endif
|
||||
#else
|
||||
#define CUTLASS_PRAGMA_UNROLL
|
||||
#define CUTLASS_PRAGMA_NO_UNROLL
|
||||
#endif
|
||||
|
||||
#define CUTLASS_ASSERT(x) assert(x)
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/// NVIDIA GPU Warp size
|
||||
static const int kWarpSize = 32;
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
278
cutlass/fragment.h
Normal file
278
cutlass/fragment.h
Normal file
@ -0,0 +1,278 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines Fragment, a statically-sized array for storing parts of matrices within a
|
||||
thread's registers.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <cutlass/shape.h>
|
||||
#include <cutlass/util/cutlass_math.h>
|
||||
#include <cutlass/vector.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup fragment_concept Fragment Concept
|
||||
@{
|
||||
|
||||
\ref fragment_concept is a statically sized array for storing parts of tiles held by individual CUDA
|
||||
threads.
|
||||
|
||||
@par \ref fragment_concept
|
||||
Types satisfying \ref fragment_concept define the following members
|
||||
- <b>Element</b> - type of each access held within the fragment
|
||||
- <b>kElements</b> - number of elements stored by the fragment
|
||||
- <b>clear()</b> - overwrites the fragment storage with zeros
|
||||
- <b>Element & operator[](int i)</b> - by-reference access of the ith element
|
||||
- <b>Element const & operator[](int i) const</b> - const by-reference access of the ith element
|
||||
@}
|
||||
*/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup fragment_iterator_concept Fragment Iterator Concept
|
||||
@{
|
||||
|
||||
\ref fragment_iterator_concept provides structured access to the elements within a fragment with an
|
||||
optional bitcast to the desired access type
|
||||
|
||||
@par \ref fragment_iterator_concept
|
||||
Types satisfying \ref fragment_iterator_concept define the following members
|
||||
- <b>AccessType& operator[](int i)</b> - provides access to the ith element of the fragment
|
||||
- <b>AccessType& at(int d, int h, int w, int c)</b> - applies \ref layout_concept to fragment and
|
||||
provides access to element at (d, h, w, c)
|
||||
|
||||
@}
|
||||
*/
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int kAlignment_>
|
||||
struct StorageType {
|
||||
typedef uint64_t Type;
|
||||
};
|
||||
template <>
|
||||
struct StorageType<4> {
|
||||
typedef uint32_t Type;
|
||||
};
|
||||
template <>
|
||||
struct StorageType<2> {
|
||||
typedef uint16_t Type;
|
||||
};
|
||||
template <>
|
||||
struct StorageType<1> {
|
||||
typedef uint8_t Type;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief A template defining \ref fragment_concept
|
||||
* @concept{fragment_concept}
|
||||
*/
|
||||
template <typename Element_, int kElements_, size_t kAlignment_ = 16>
|
||||
struct Fragment : public AlignedStruct<kAlignment_> {
|
||||
/// Make sure the alignment makes sense wrt the size of elements.
|
||||
static_assert(kAlignment_ == 16 || kAlignment_ >= sizeof(Element_), "Alignment is too small");
|
||||
/// Alignment must be a power of two
|
||||
static_assert(is_pow2<kAlignment_>::value, "Alignment must be a power of two");
|
||||
|
||||
/// This class.
|
||||
typedef Fragment<Element_, kElements_> This_;
|
||||
/// The element.
|
||||
typedef Element_ Element;
|
||||
/// The number of elements.
|
||||
static int const kElements = kElements_;
|
||||
|
||||
/// Clear a fragment.
|
||||
CUTLASS_DEVICE void clear() {
|
||||
// Avoid element-wise access for sub 32b element type
|
||||
if (kAlignment_ >= 8 && (kElements * sizeof(Element)) % 8 == 0) {
|
||||
uint64_t* ptr = reinterpret_cast<uint64_t*>(storage);
|
||||
for (int i = 0; i < (kElements * sizeof(Element)) / 8; ++i) {
|
||||
ptr[i] = uint64_t(0);
|
||||
}
|
||||
} else if (kAlignment_ >= 4 && (kElements * sizeof(Element)) % 4 == 0) {
|
||||
uint32_t* ptr = reinterpret_cast<uint32_t*>(storage);
|
||||
for (int i = 0; i < (kElements * sizeof(Element)) / 4; ++i) {
|
||||
ptr[i] = uint32_t(0);
|
||||
}
|
||||
} else if (kAlignment_ >= 2 && (kElements * sizeof(Element)) % 2 == 0) {
|
||||
uint16_t* ptr = reinterpret_cast<uint16_t*>(storage);
|
||||
for (int i = 0; i < (kElements * sizeof(Element)) / 2; ++i) {
|
||||
ptr[i] = uint16_t(0);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < kElements; ++i) {
|
||||
storage[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE Element& operator[](int i) {
|
||||
assert(i < kElements_);
|
||||
return reinterpret_cast<Element*>(storage)[i];
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE Element const& operator[](int i) const {
|
||||
assert(i < kElements_);
|
||||
return reinterpret_cast<Element const*>(storage)[i];
|
||||
}
|
||||
|
||||
private:
|
||||
/// Storage type to use for Elements
|
||||
typedef typename StorageType<kAlignment_>::Type StorageType;
|
||||
|
||||
/// Number of elements in the storage
|
||||
static int const kStorageCount =
|
||||
(sizeof(Element_) * kElements_ + sizeof(StorageType) - 1) / sizeof(StorageType);
|
||||
/// The storage.
|
||||
StorageType storage[kStorageCount];
|
||||
|
||||
/// Ensure that there's enough storage for all elements
|
||||
static_assert(sizeof(StorageType) <= kAlignment_, "StorageType is too big for given alignment");
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief A template defining \ref fragment_iterator_concept
|
||||
* @concept{fragment_iterator_concept}
|
||||
*/
|
||||
template <typename Fragment_, typename Iterations_, typename AccessType_>
|
||||
struct FragmentIterator {
|
||||
/// This class.
|
||||
typedef FragmentIterator<Fragment_, Iterations_, AccessType_> This_;
|
||||
/// The fragment.
|
||||
typedef Fragment_ Fragment;
|
||||
/// The number of iterations.
|
||||
typedef Iterations_ Iterations;
|
||||
/// The access type.
|
||||
typedef AccessType_ AccessType;
|
||||
|
||||
/// The element.
|
||||
typedef typename Fragment::Element Element;
|
||||
/// The number of elements per access.
|
||||
static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
|
||||
/// The shape of the the fragment.
|
||||
typedef typename ShapeMul<Iterations, Shape<1, 1, 1, kElementsPerAccess> >::Shape FragmentShape;
|
||||
/// The linear strides for iterations.
|
||||
typedef typename ShapeStrides<FragmentShape, kElementsPerAccess>::Shape Strides;
|
||||
|
||||
/// Ctor.
|
||||
template <typename OtherFragment_>
|
||||
CUTLASS_DEVICE FragmentIterator(OtherFragment_& fragment, int offset = 0)
|
||||
: pointer(reinterpret_cast<Element*>(&fragment[offset])) {
|
||||
static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
|
||||
int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
|
||||
return reinterpret_cast<AccessType const&>(pointer[imm]);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType& at(int d, int h, int w, int c = 0) {
|
||||
int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
|
||||
return reinterpret_cast<AccessType&>(pointer[imm]);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType const& operator[](int i) const {
|
||||
return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType& operator[](int i) {
|
||||
return reinterpret_cast<AccessType&>(pointer[i * kElementsPerAccess]);
|
||||
}
|
||||
|
||||
/// Is the iterator valid?
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
|
||||
|
||||
/// The pointer.
|
||||
Element* pointer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Fragment_, typename Iterations_, typename AccessType_>
|
||||
struct FragmentConstIterator {
|
||||
/// This class.
|
||||
typedef FragmentIterator<Fragment_, Iterations_, AccessType_> This_;
|
||||
/// The fragment.
|
||||
typedef Fragment_ Fragment;
|
||||
/// The number of iterations.
|
||||
typedef Iterations_ Iterations;
|
||||
/// The access type.
|
||||
typedef AccessType_ AccessType;
|
||||
|
||||
/// The element.
|
||||
typedef typename Fragment::Element Element;
|
||||
/// The number of elements per access.
|
||||
static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
|
||||
/// The shape of the the fragment.
|
||||
typedef typename ShapeMul<Iterations, Shape<1, 1, 1, kElementsPerAccess> >::Shape FragmentShape;
|
||||
/// The linear strides for iterations.
|
||||
typedef typename ShapeStrides<FragmentShape, kElementsPerAccess>::Shape IterationsStrides;
|
||||
|
||||
/// Ctor.
|
||||
template <typename OtherFragment_>
|
||||
CUTLASS_DEVICE FragmentConstIterator(OtherFragment_& fragment, int offset = 0)
|
||||
: pointer(reinterpret_cast<Element const*>(&fragment[offset])) {
|
||||
static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
|
||||
}
|
||||
/// Create from non-constant FragmentIterator
|
||||
CUTLASS_DEVICE FragmentConstIterator(
|
||||
FragmentIterator<Fragment_, Iterations_, AccessType_> const& rhs_)
|
||||
: pointer(reinterpret_cast<Element const*>(rhs_.offset)) {}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
|
||||
int const imm = ComputeOffsetFromStrides<IterationsStrides>::get(d, h, w, c);
|
||||
return reinterpret_cast<AccessType const&>(pointer[imm]);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType const& operator[](int i) const {
|
||||
return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
|
||||
}
|
||||
|
||||
/// Is the iterator valid?
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
|
||||
|
||||
/// The pointer.
|
||||
Element const* pointer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
135
cutlass/fragment_load_store.h
Normal file
135
cutlass/fragment_load_store.h
Normal file
@ -0,0 +1,135 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines accessors for loading and storing fragments to memory efficiently.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/load_store.h>
|
||||
#include <cutlass/vector.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <IteratorFragment::Kind kIteratorFragment,
|
||||
int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentLoad {};
|
||||
|
||||
template <int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentLoad<IteratorFragment::kWmmaMatrix,
|
||||
kAccessSize,
|
||||
Scalar_,
|
||||
Memory_,
|
||||
FragmentElement_,
|
||||
kStride> {
|
||||
/// The output type.
|
||||
typedef FragmentElement_ AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) {
|
||||
value.load(&pointer[offset], kStride);
|
||||
}
|
||||
};
|
||||
|
||||
template <int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentLoad<IteratorFragment::kScalar,
|
||||
kAccessSize,
|
||||
Scalar_,
|
||||
Memory_,
|
||||
FragmentElement_,
|
||||
kStride> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) {
|
||||
Load<Scalar_, kAccessSize, Memory_>::load(value, pointer, offset);
|
||||
}
|
||||
};
|
||||
|
||||
template <IteratorFragment::Kind kIteratorFragment,
|
||||
int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentStore {};
|
||||
|
||||
template <int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentStore<IteratorFragment::kWmmaMatrix,
|
||||
kAccessSize,
|
||||
Scalar_,
|
||||
Memory_,
|
||||
FragmentElement_,
|
||||
kStride> {
|
||||
/// The input type.
|
||||
typedef FragmentElement_ AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) {
|
||||
value.store(&pointer[offset], kStride);
|
||||
}
|
||||
};
|
||||
|
||||
template <int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentStore<IteratorFragment::kScalar,
|
||||
kAccessSize,
|
||||
Scalar_,
|
||||
Memory_,
|
||||
FragmentElement_,
|
||||
kStride> {
|
||||
/// The input type.
|
||||
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) {
|
||||
Store<Scalar_, kAccessSize, Memory_>::store(value, pointer, offset);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} /// namespace cutlass
|
||||
149
cutlass/fragment_multiply_add.h
Normal file
149
cutlass/fragment_multiply_add.h
Normal file
@ -0,0 +1,149 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines multiply-add operations on fragments within a thread.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_>
|
||||
struct FragmentMultiplyAdd {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<1, 1, 1, 1> InstructionShape;
|
||||
/// The type for A.
|
||||
typedef Scalar_ ScalarA;
|
||||
/// The type for B.
|
||||
typedef Scalar_ ScalarB;
|
||||
/// The type for C and D.
|
||||
typedef Scalar_ ScalarC;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE FragmentMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b.
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply(Scalar_ a, FragmentB_ const& b, FragmentCd_& d) {
|
||||
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
|
||||
for (int j = 0; j < FragmentCd_::kElements; ++j) {
|
||||
d[j] = a * b[j * kReduction + 0];
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d[j] += a * b[j * kReduction + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Multiply : d = a*b + c.
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply_add(Scalar_ a,
|
||||
FragmentB_ const& b,
|
||||
FragmentCd_ const& c,
|
||||
FragmentCd_& d) {
|
||||
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
|
||||
for (int j = 0; j < FragmentCd_::kElements; ++j) {
|
||||
d[j] = a * b[j * kReduction + 0] + c[j];
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d[j] += a * b[j * kReduction + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
|
||||
template <>
|
||||
struct FragmentMultiplyAdd<half> {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<1, 1, 2, 1> InstructionShape;
|
||||
/// The type for A.
|
||||
typedef half ScalarA;
|
||||
/// The type for B.
|
||||
typedef half ScalarB;
|
||||
/// The type for C and D.
|
||||
typedef half ScalarC;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE FragmentMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b.
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply(half a, FragmentB_ const& b, FragmentCd_& d) {
|
||||
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
|
||||
|
||||
// Assemble a half2 from a.
|
||||
__half2 const a_half2 = __half2half2(a);
|
||||
// The input.
|
||||
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
|
||||
// The output.
|
||||
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
|
||||
|
||||
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
|
||||
for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
|
||||
d_half2[j] = __hmul2(a_half2, b_half2[j * kReduction + 0]);
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Multiply : d = a*b + c.
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply_add(half a,
|
||||
FragmentB_ const& b,
|
||||
FragmentCd_ const& c,
|
||||
FragmentCd_& d) {
|
||||
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
|
||||
// Assemble a half2 from a.
|
||||
__half2 const a_half2 = __half2half2(a);
|
||||
// The inputs.
|
||||
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
|
||||
__half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);
|
||||
// The output.
|
||||
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
|
||||
|
||||
int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements);
|
||||
for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
|
||||
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + 0], c_half2[j]);
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,162 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* block-wide tile-loading abstractions
|
||||
*/
|
||||
|
||||
#include "../util/util.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* load_algorithm
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief Enumeration of matrix loading algorithms
|
||||
*/
|
||||
struct load_algorithm
|
||||
{
|
||||
/// \brief Enumerants. See corresponding tag types.
|
||||
enum kind_t
|
||||
{
|
||||
CongruousCopy = 0,
|
||||
CrosswiseCopy = 1,
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Generic tag
|
||||
*/
|
||||
template <kind_t Kind>
|
||||
struct any_tag : nv_std::integral_constant<kind_t, Kind> {};
|
||||
|
||||
/**
|
||||
* \brief Copy from a global matrix that is row-major in relation
|
||||
* to the local row-major tile
|
||||
*/
|
||||
typedef any_tag<CongruousCopy> contiguous_tag_t;
|
||||
|
||||
/**
|
||||
* \brief Copy from a global matrix that is column-major in relation
|
||||
* to the local row-major tile
|
||||
*/
|
||||
typedef any_tag<CrosswiseCopy> crosswise_tag_t;
|
||||
|
||||
};
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* block_loader
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief A three-phase data loading abstraction (prefetch, commit, and
|
||||
* advance) for iterating over ranges of block-wide matrix tiles.
|
||||
*
|
||||
* Each iteration sequence produces a KxL (height-by-width) block-wide tile of
|
||||
* value_t in shared memory. The layout of the shared
|
||||
* block-wide tile is a row-major (L-major) tiling of dp_vector_t items, which are
|
||||
* themselves column-major (K-major) vectors of value_t. Its dimensions are:
|
||||
* K = BlockDpVectorsK * (sizeof(dp_vector_t) / sizeof(value_t)
|
||||
* L = BlockDpVectorsL
|
||||
*
|
||||
* NB: This generic class is not directly constructible. Architecture- and
|
||||
* algorithm-specific template specializations will provide the API
|
||||
* functionality prescribed here.
|
||||
*
|
||||
*/
|
||||
template <
|
||||
int BlockThreads, ///< Number of threads in each thread block (blockDim.x)
|
||||
int BlockDpVectorsK, ///< Extent of block-wide tile in dp_vector_t along the K-axis (height)
|
||||
int BlockDpVectorsL, ///< Extent of block-wide tile in dp_vector_t along the L-axis (width)
|
||||
typename value_t, ///< Input matrix value type
|
||||
int LeadingDimAlignBytes, ///< Byte alignment of input matrix leading dimension
|
||||
bool AllowRaggedTiles, ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
|
||||
typename dp_vector_t, ///< Dot-product vector type along the K-axis
|
||||
load_algorithm::kind_t LoadAlgorithm> ///< Algorithm for loading a shared tile of KxL matrix data
|
||||
struct block_loader
|
||||
{
|
||||
//-------------------------------------------------------------------------
|
||||
// Constructor API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Constructor
|
||||
block_loader(
|
||||
value_t *d_matrix, ///< Pointer to input matrix
|
||||
int matrix_values_l, ///< Extent of the input matrix in value_t along the L-axis
|
||||
int matrix_values_stride_k, ///< Distance in value_t within pitched-linear memory between successive coordinates along the K-axis
|
||||
int matrix_values_stride_l, ///< Distance in value_t within pitched-linear memory between successive coordinates along the L-axis
|
||||
int2 block_begin_item_coords, ///< Thread block's starting value_t coordinates (l, k) within the input matrix
|
||||
int block_end_item_k); ///< Thread block's ending coordinate (k) within the input matrix (one-past)
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Loader API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Request the current block-wide tile
|
||||
*/
|
||||
void request();
|
||||
|
||||
|
||||
/**
|
||||
* Advance the loader to the next block-wide tile in the K-axis
|
||||
*/
|
||||
void next();
|
||||
|
||||
|
||||
/**
|
||||
* Commit the previously-requested block-wide tile to shared memory
|
||||
*
|
||||
* NB: To facilitate padding for avoiding shared memory bank conflicts, we
|
||||
* allow the row stride _BlockDpVectorsL to be arbitrarily bigger than the
|
||||
* tile width BlockDpVectorsL.
|
||||
*/
|
||||
template <int _BlockDpVectorsL>
|
||||
void commit(
|
||||
dp_vector_t (&scratch_tile)[BlockDpVectorsK][_BlockDpVectorsL]);
|
||||
|
||||
};
|
||||
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Tail-include specializations that adhere to the block_loader API
|
||||
******************************************************************************/
|
||||
|
||||
#include "block_loader_crosswise.h"
|
||||
#include "block_loader_congruous_dp1.h"
|
||||
#include "block_loader_congruous_idp4.h"
|
||||
@ -1,406 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* Tile-loading abstraction for thread blocks
|
||||
*/
|
||||
|
||||
#include "../util/util.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* block_loader (CongruousCopy + dp1 specialization)
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief A three-phase data loading abstraction (prefetch, commit, and
|
||||
* advance) for iterating over ranges of block-wide matrix tiles.
|
||||
* (CongruousCopy + dp1 specialization)
|
||||
*
|
||||
* Each iteration sequence produces a KxL (height-by-width) block-wide tile of
|
||||
* value_t in shared memory. The layout of the shared block-wide tile is
|
||||
* a row-major (L-major) tiling of singleton "dp1" dp_vector_t items, where
|
||||
* dp_vector_t == value_t. Its dimensions are:
|
||||
* K = BlockDpVectorsK
|
||||
* L = BlockDpVectorsL
|
||||
*
|
||||
* The data is copied from a corresponding tile of global matrix data whose
|
||||
* layout of value_t is also L-major. This constitutes a CongruousCopy
|
||||
* between the L-major global tile and the L-major shared tile.
|
||||
*
|
||||
* NB: Because they are "dp1" singletons, the K-major orientation of
|
||||
* dp_vector_t in shared memory is irrelevant, and the L-major global and
|
||||
* shared tile layouts are perfectly congruous. As a result, we can increase
|
||||
* the granularity of data transfer via vectorization of loads and stores
|
||||
* without any intermediate {dis|re}assembly.
|
||||
*
|
||||
* NB: Consecutive threads within a block are mapped in L-major
|
||||
* fashion across a first-set of LDG-vectors of dp_vector_t (value_t) within
|
||||
* their global tile. Successive sets of LDG-vectors are then strip-mined
|
||||
* as necessary down the K-axis. These discontiguous LDG-vectors comprise the
|
||||
* thread's "slice" of the block-wide tile.
|
||||
*/
|
||||
template <
|
||||
int BlockThreads, ///< Number of threads in each thread block (blockDim.x)
|
||||
int BlockDpVectorsK, ///< Extent of block-wide tile in dp_vector_t along the K-axis (height)
|
||||
int BlockDpVectorsL, ///< Extent of block-wide tile in dp_vector_t along the L-axis (width)
|
||||
typename value_t, ///< Input matrix value type
|
||||
int LeadingDimAlignBytes, ///< Byte alignment of input matrix leading dimension
|
||||
bool AllowRaggedTiles ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
|
||||
>
|
||||
struct block_loader<
|
||||
BlockThreads,
|
||||
BlockDpVectorsK,
|
||||
BlockDpVectorsL,
|
||||
value_t,
|
||||
LeadingDimAlignBytes,
|
||||
AllowRaggedTiles,
|
||||
value_t, ///< Dot-product vector type along the K-axis (dp1 specialization)
|
||||
load_algorithm::CongruousCopy> ///< Algorithm for loading a shared tile of KxL matrix data (CongruousCopy specialization)
|
||||
{
|
||||
//-------------------------------------------------------------------------
|
||||
// Constants and types
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Dot-product vector type along the K-axis
|
||||
typedef value_t dp_vector_t;
|
||||
|
||||
enum
|
||||
{
|
||||
/// Number of value_t in a dp_vector_t
|
||||
DpVectorItems = divide_assert<sizeof(dp_vector_t), sizeof(value_t)>::value,
|
||||
|
||||
/// Number of dp_vector_t in a block-wide tile
|
||||
BlockDpVectors = BlockDpVectorsK * BlockDpVectorsL,
|
||||
|
||||
/// Number of dp_vector_t in a thread-tile
|
||||
ThreadDpVectors = divide_assert<BlockDpVectors, BlockThreads>::value,
|
||||
};
|
||||
|
||||
/// Data movement type, coarsened by LeadingDimAlignBytes, capped by the
|
||||
/// smaller of either ThreadDpVectors or BlockDpVectorsL
|
||||
typedef io_vector<
|
||||
dp_vector_t,
|
||||
__NV_STD_MIN(ThreadDpVectors, BlockDpVectorsL),
|
||||
LeadingDimAlignBytes>
|
||||
ldg_vector_t;
|
||||
|
||||
enum
|
||||
{
|
||||
/// Number of dp_vector_t per ldg_vector_t
|
||||
LdgVectorDpVectors = ldg_vector_t::VectorItems,
|
||||
|
||||
/// Number of value_t per ldg_vector_t
|
||||
LdgVectorItems = LdgVectorDpVectors * DpVectorItems,
|
||||
|
||||
|
||||
|
||||
/// Total number of ldg_vector_t within each block-wide tile
|
||||
BlockLdgVectors = divide_assert<BlockDpVectors, LdgVectorDpVectors>::value,
|
||||
|
||||
/// Extent of the block-wide tile in ldg_vector_t along L-axis
|
||||
BlockLdgVectorsL = divide_assert<BlockDpVectorsL, LdgVectorDpVectors>::value,
|
||||
|
||||
/// Extent of the block-wide tile in ldg_vector_t along K-axis
|
||||
BlockLdgVectorsK = BlockDpVectorsK,
|
||||
|
||||
|
||||
|
||||
/// Number of ldg_vector_t within each thread-tile
|
||||
ThreadLdgVectors = divide_assert<BlockLdgVectors, BlockThreads>::value,
|
||||
|
||||
/// Extent of the thread tile in ldg_vector_t along L-axis
|
||||
ThreadLdgVectorsL = __NV_STD_MAX(1, (BlockLdgVectorsL / BlockThreads)),
|
||||
|
||||
/// Extent of the thread tile in ldg_vector_t along K-axis
|
||||
ThreadLdgVectorsK = divide_assert<ThreadLdgVectors, ThreadLdgVectorsL>::value,
|
||||
|
||||
|
||||
|
||||
/// Number of ldg_vector_t within each stripmine-tile
|
||||
StripmineLdgVectors = BlockThreads,
|
||||
|
||||
/// Extent of the stripmine tile in ldg_vector_t along L-axis
|
||||
StripmineLdgVectorsL = __NV_STD_MIN(BlockLdgVectorsL, StripmineLdgVectors),
|
||||
|
||||
/// Extent of the stripmine tile in ldg_vector_t along K-axis
|
||||
StripmineLdgVectorsK = divide_assert<StripmineLdgVectors, StripmineLdgVectorsL>::value,
|
||||
|
||||
|
||||
|
||||
/// Alignment in dp_vector_t along L needed for committing prefetch
|
||||
AlignmentDpVectorsL = LdgVectorDpVectors,
|
||||
};
|
||||
|
||||
/// Predicate bit vector
|
||||
typedef uint64_t predicate_mask_t;
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Assert assumptions
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
static_assert(
|
||||
(ThreadLdgVectors <= sizeof(predicate_mask_t) * 8),
|
||||
"Predicate mask type does not contain enough bits for encoding load predicates");
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Members
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Input pointer to matrix in ldg_vector_t
|
||||
ldg_vector_t *d_matrix_ldgvecs;
|
||||
|
||||
/// Extent of the input matrix in ldg_vector_t along the L-axis
|
||||
int matrix_ldgvecs_l;
|
||||
|
||||
/// Thread block's ending ldg_vector_t coordinate (k) within the input matrix (one-past)
|
||||
int block_end_ldgvec_k;
|
||||
|
||||
/// Predicate bits for guarding ldg_vector_t loads within "whole-k" block-wide tiles
|
||||
predicate_mask_t guard;
|
||||
|
||||
/// Predicate bits for guarding ldg_vector_t loads within the final block-wide "residue" tile
|
||||
predicate_mask_t residue_guard;
|
||||
|
||||
/// Iteration span in "whole-k" block-wide tiles
|
||||
int wholek_tiles_remaining;
|
||||
|
||||
/// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the K-axis
|
||||
int matrix_ldgvec_stride_k;
|
||||
|
||||
/// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the L-axis
|
||||
int matrix_ldgvec_stride_l;
|
||||
|
||||
/// ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
|
||||
int2 block_thread_ldgvec_coords;
|
||||
|
||||
/// Thread-wide tile of prefetch data
|
||||
ldg_vector_t thread_tile[ThreadLdgVectorsK][ThreadLdgVectorsL];
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Constructor API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Constructor
|
||||
inline __device__
|
||||
block_loader(
|
||||
value_t *d_matrix_items, ///< Input pointer to matrix in value_t
|
||||
int matrix_items_l, ///< Extent of the input matrix in value_t along the L-axis
|
||||
int matrix_items_stride_k, ///< Distance in value_t within pitched-linear memory between successive coordinates along the K-axis
|
||||
int matrix_items_stride_l, ///< Distance in value_t within pitched-linear memory between successive coordinates along the L-axis
|
||||
int2 matrix_block_item_coords, ///< value_t coordinates (l, k) of first block-wide tile within the input matrix
|
||||
int block_end_item_k) ///< Thread block's ending coordinate (k) within the input matrix (one-past)
|
||||
:
|
||||
block_end_ldgvec_k(block_end_item_k),
|
||||
guard(0),
|
||||
residue_guard(0)
|
||||
{
|
||||
matrix_ldgvecs_l = matrix_items_l / LdgVectorItems;
|
||||
matrix_ldgvec_stride_k = matrix_items_stride_k / LdgVectorItems,
|
||||
matrix_ldgvec_stride_l = matrix_items_stride_l;
|
||||
|
||||
// ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
|
||||
block_thread_ldgvec_coords = make_int2(
|
||||
threadIdx.x % BlockLdgVectorsL, // l-coordinate
|
||||
threadIdx.x / BlockLdgVectorsL); // k-coordinate
|
||||
|
||||
// ldg_vector_t coordinates (l, k) of first block-wide tile within the input matrix
|
||||
int2 matrix_block_ldgvec_coords = make_int2(
|
||||
matrix_block_item_coords.x / LdgVectorItems, // l-coordinate
|
||||
matrix_block_item_coords.y); // k-coordinate
|
||||
|
||||
// Iteration span in ldg_vector_t
|
||||
int span_ldgvec_k = (block_end_item_k - matrix_block_item_coords.y);
|
||||
|
||||
|
||||
|
||||
// ldg_vector_t coordinates (l, k) of first thread-tile tile within the input matrix
|
||||
int2 matrix_thread_ldgvec_coords = make_int2(
|
||||
block_thread_ldgvec_coords.x + matrix_block_ldgvec_coords.x,
|
||||
block_thread_ldgvec_coords.y + matrix_block_ldgvec_coords.y);
|
||||
|
||||
// Iteration range in "whole-k" block-wide tiles
|
||||
wholek_tiles_remaining = span_ldgvec_k / BlockLdgVectorsK;
|
||||
|
||||
// Extent of final residue-tile in ldg_vector_t along K-axis
|
||||
int residue_ldgvecs_k = span_ldgvec_k % BlockLdgVectorsK;
|
||||
|
||||
// Initialize I/O predicates
|
||||
if (AllowRaggedTiles)
|
||||
{
|
||||
// Outer thread-tile ldg_vector_t iteration (K-axis)
|
||||
#pragma unroll
|
||||
for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
|
||||
{
|
||||
int block_ldgvec_k = block_thread_ldgvec_coords.y + (thread_ldgvec_k * StripmineLdgVectorsK);
|
||||
|
||||
// Whether block_ldgvec_coords.y is valid in the final residue tile
|
||||
predicate_mask_t valid_k = (block_ldgvec_k < residue_ldgvecs_k);
|
||||
|
||||
// Inner thread-tile ldg_vector_t iteration (L-axis)
|
||||
#pragma unroll
|
||||
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
|
||||
{
|
||||
int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
|
||||
|
||||
// Whether block_ldgvec_coords.x is valid any block-wide tile
|
||||
predicate_mask_t valid_l = (matrix_block_ldgvec_coords.x + block_ldgvec_l < matrix_ldgvecs_l);
|
||||
|
||||
// Linear index of ldg_vector_t load
|
||||
int ldgvec_idx = thread_ldgvec_l + (thread_ldgvec_k * ThreadLdgVectorsL);
|
||||
|
||||
// Set predicate guard bits
|
||||
guard |= (valid_l << ldgvec_idx);
|
||||
residue_guard |= ((valid_l & valid_k) << ldgvec_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Promote residue-guard to primary-guard if no full tiles remain
|
||||
if (!wholek_tiles_remaining)
|
||||
{
|
||||
guard = residue_guard;
|
||||
}
|
||||
}
|
||||
|
||||
// Update the input pointer to be matrix_thread_ldgvec_coords
|
||||
this->d_matrix_ldgvecs =
|
||||
reinterpret_cast<ldg_vector_t*>(d_matrix_items) +
|
||||
(matrix_thread_ldgvec_coords.y * matrix_ldgvec_stride_k) +
|
||||
(matrix_thread_ldgvec_coords.x * matrix_ldgvec_stride_l);
|
||||
}
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Loader API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Request the current block-wide tile
|
||||
*/
|
||||
inline __device__
|
||||
void request()
|
||||
{
|
||||
// Outer thread-tile ldg_vector_t iteration (K-axis)
|
||||
#pragma unroll
|
||||
for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
|
||||
{
|
||||
// Inner thread-tile ldg_vector_t iteration (L-axis)
|
||||
#pragma unroll
|
||||
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
|
||||
{
|
||||
// Linear index of ldg_vector_t load
|
||||
int ldgvec_idx = (thread_ldgvec_k * ThreadLdgVectorsL) + thread_ldgvec_l;
|
||||
|
||||
// Unpack predicate guard
|
||||
predicate_mask_t valid = ((guard >> ldgvec_idx) & 1);
|
||||
|
||||
if (!AllowRaggedTiles || valid)
|
||||
{
|
||||
// Perform load
|
||||
thread_tile[thread_ldgvec_k][thread_ldgvec_l].load(
|
||||
d_matrix_ldgvecs +
|
||||
(thread_ldgvec_k * StripmineLdgVectorsK * matrix_ldgvec_stride_k) +
|
||||
(thread_ldgvec_l * StripmineLdgVectorsL * matrix_ldgvec_stride_l));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Zero-initialize
|
||||
#pragma unroll
|
||||
for (int dpvec = 0; dpvec < LdgVectorDpVectors; ++dpvec)
|
||||
thread_tile[thread_ldgvec_k][thread_ldgvec_l].buff[dpvec] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Advance the loader to the next block-wide tile in the K-axis
|
||||
*/
|
||||
inline __device__
|
||||
void next()
|
||||
{
|
||||
d_matrix_ldgvecs += (matrix_ldgvec_stride_k * BlockLdgVectorsK);
|
||||
|
||||
if (AllowRaggedTiles)
|
||||
{
|
||||
--wholek_tiles_remaining;
|
||||
|
||||
// Promote residue-guard to primary-guard if no full tiles remain
|
||||
if (!wholek_tiles_remaining)
|
||||
{
|
||||
guard = residue_guard;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Commit the previously-requested block-wide tile to shared memory
|
||||
*
|
||||
* NB: To facilitate padding for avoiding shared memory bank conflicts, we
|
||||
* allow the row stride SmemDpVectorsL to be arbitrarily bigger than the
|
||||
* tile width BlockDpVectorsL.
|
||||
*/
|
||||
template <int SmemDpVectorsL>
|
||||
inline __device__
|
||||
void commit(
|
||||
dp_vector_t (&scratch_tile)[BlockDpVectorsK][SmemDpVectorsL])
|
||||
{
|
||||
static_assert(SmemDpVectorsL >= BlockDpVectorsL, "Row stride must be >= tile width.");
|
||||
|
||||
// Outer thread-tile ldg_vector_t iteration (K-axis)
|
||||
#pragma unroll
|
||||
for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
|
||||
{
|
||||
int block_ldgvec_k = block_thread_ldgvec_coords.y + (thread_ldgvec_k * StripmineLdgVectorsK);
|
||||
|
||||
// Inner thread-tile ldg_vector_t iteration (L-axis)
|
||||
#pragma unroll
|
||||
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
|
||||
{
|
||||
int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
|
||||
|
||||
thread_tile[thread_ldgvec_k][thread_ldgvec_l].store(
|
||||
&scratch_tile[block_ldgvec_k][block_ldgvec_l * LdgVectorDpVectors]);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,544 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* Tile-loading abstraction for thread blocks
|
||||
*/
|
||||
|
||||
#include "../util/util.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* block_loader (CongruousCopy + idp4 specialization)
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief A three-phase data loading abstraction (prefetch, commit, and
|
||||
* advance) for iterating over ranges of block-wide matrix tiles.
|
||||
* (CongruousCopy + idp4 specialization)
|
||||
*
|
||||
* Each iteration sequence produces a KxL (height-by-width) block-wide tile of
|
||||
* value_t in shared memory. The layout of the shared block-wide tile is
|
||||
* a row-major (L-major) tiling of int32_t dp_vector_t, which are themselves
|
||||
* column-major (K-major) vectors of int8_t value_t. Its dimensions are:
|
||||
* K = BlockDpVectorsK * (sizeof(dp_vector_t) / sizeof(value_t)
|
||||
* L = BlockDpVectorsL
|
||||
*
|
||||
* The data is copied from a corresponding tile of global matrix data whose
|
||||
* layout of value_t is also L-major. This constitutes a CongruousCopy between
|
||||
* the L-major global tile and the L-major shared tile.
|
||||
*
|
||||
* NB: The K-major value_t in shared dp_vector_t are imperfectly congruous
|
||||
* with the L-major value_t in global memory. As a result, the granularity
|
||||
* of data transfer is a "dp-square" of (DpVectorItems * DpVectorItems) values
|
||||
* that must be transposed from L-oriented dp_vector_t to K-oriented
|
||||
* dp_vector_t prior to commitment.
|
||||
*
|
||||
* NB: Consecutive threads within a block are mapped in L-major
|
||||
* fashion across a first-set of squares within their global tile. Successive
|
||||
* sets of squares are then strip-mined as necessary down the K-axis. These
|
||||
* discontiguous squares comprise the thread's "slice" of the block-wide tile.
|
||||
*/
|
||||
template <
|
||||
int BlockThreads, ///< Number of threads in each thread block (blockDim.x)
|
||||
int _BlockDpVectorsK, ///< Extent of block-wide tile in dp_vector_t along the K-axis (height)
|
||||
int _BlockDpVectorsL, ///< Extent of block-wide tile in dp_vector_t along the L-axis (width)
|
||||
int LeadingDimAlignBytes, ///< Byte alignment of input matrix leading dimension
|
||||
bool AllowRaggedTiles ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
|
||||
>
|
||||
struct block_loader<
|
||||
BlockThreads,
|
||||
_BlockDpVectorsK,
|
||||
_BlockDpVectorsL,
|
||||
int8_t, ///< Input matrix value type (idp4 specialization)
|
||||
LeadingDimAlignBytes,
|
||||
AllowRaggedTiles,
|
||||
int32_t, ///< Dot-product vector type along the K-axis (idp4 specialization)
|
||||
load_algorithm::CongruousCopy> ///< Algorithm for loading a shared tile of KxL matrix data (CrosswiseCopy specialization)
|
||||
{
|
||||
//-------------------------------------------------------------------------
|
||||
// Constants and types
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Input matrix value type
|
||||
typedef int8_t value_t;
|
||||
|
||||
/// Dot-product vector type along the K-axis
|
||||
typedef int32_t dp_vector_t;
|
||||
|
||||
enum
|
||||
{
|
||||
/// Number of value_t in a dp_vector_t
|
||||
DpVectorItems = divide_assert<sizeof(dp_vector_t), sizeof(value_t)>::value,
|
||||
|
||||
/// Number of dp_vector_t in a block-wide tile
|
||||
BlockDpVectors = _BlockDpVectorsK * _BlockDpVectorsL,
|
||||
|
||||
/// Number of dp_vector_t in a thread-tile
|
||||
ThreadDpVectors = divide_assert<BlockDpVectors, BlockThreads>::value,
|
||||
|
||||
/// Number of dp_vector_t in a dp-square
|
||||
SquareDpVectors = DpVectorItems,
|
||||
|
||||
/// Number of dp-square tiles in a thread-tile
|
||||
ThreadSquares = divide_assert<ThreadDpVectors, SquareDpVectors>::value,
|
||||
|
||||
/// Extent of block-wide tile in transposed dp_vector_t along the K-axis (height)
|
||||
BlockTransDpVectorsK = _BlockDpVectorsK * DpVectorItems,
|
||||
|
||||
/// Extent of block-wide tile in transposed dp_vector_t along the L-axis (height)
|
||||
BlockTransDpVectorsL = divide_assert<_BlockDpVectorsL, DpVectorItems>::value,
|
||||
|
||||
|
||||
};
|
||||
|
||||
/// Load-from-global data movement type, coarsened by LeadingDimAlignBytes, capped by the
|
||||
/// smaller of either ThreadSquares or BlockTransDpVectorsL
|
||||
typedef io_vector<
|
||||
dp_vector_t,
|
||||
__NV_STD_MIN(ThreadSquares, BlockTransDpVectorsL),
|
||||
LeadingDimAlignBytes>
|
||||
ldg_vector_t;
|
||||
|
||||
/// Store-to-shared data movement type equivalent to a dp-square
|
||||
typedef io_vector<
|
||||
dp_vector_t,
|
||||
SquareDpVectors>
|
||||
sts_vector_t;
|
||||
|
||||
enum
|
||||
{
|
||||
/// Number of dp_vector_t per ldg_vector_t
|
||||
LdgVectorDpVectors = ldg_vector_t::VectorItems,
|
||||
|
||||
/// Number of value_t per ldg_vector_t
|
||||
LdgVectorItems = LdgVectorDpVectors * DpVectorItems,
|
||||
|
||||
|
||||
|
||||
/// Total number of ldg_vector_t within each block-wide tile
|
||||
BlockLdgVectors = divide_assert<BlockDpVectors, LdgVectorDpVectors>::value,
|
||||
|
||||
/// Extent of the block-wide tile in ldg_vector_t along L-axis
|
||||
BlockLdgVectorsL = divide_assert<BlockTransDpVectorsL, LdgVectorDpVectors>::value,
|
||||
|
||||
/// Extent of the block-wide tile in ldg_vector_t along K-axis
|
||||
BlockLdgVectorsK = BlockTransDpVectorsK,
|
||||
|
||||
|
||||
|
||||
/// Number of ldg_vector_t within each thread-tile
|
||||
ThreadLdgVectors = divide_assert<BlockLdgVectors, BlockThreads>::value,
|
||||
|
||||
/// Extent of the thread tile in ldg_vector_t along L-axis
|
||||
ThreadLdgVectorsL = __NV_STD_MAX(1, (BlockLdgVectorsL / BlockThreads)),
|
||||
|
||||
/// Extent of the thread tile in ldg_vector_t along K-axis
|
||||
ThreadLdgVectorsK = divide_assert<ThreadLdgVectors, ThreadLdgVectorsL>::value,
|
||||
|
||||
/// Extent of the thread tile in dp-square tiles along K-axis
|
||||
ThreadSquaresK = divide_assert<ThreadLdgVectorsK, SquareDpVectors>::value,
|
||||
|
||||
|
||||
|
||||
/// Number of ldg_vector_t within each stripmine-tile
|
||||
StripmineLdgVectors = BlockThreads * SquareDpVectors,
|
||||
|
||||
/// Extent of the stripmine tile in ldg_vector_t along L-axis
|
||||
StripmineLdgVectorsL = __NV_STD_MIN(BlockLdgVectorsL, BlockThreads),
|
||||
|
||||
/// Extent of the stripmine tile in ldg_vector_t along K-axis
|
||||
StripmineLdgVectorsK = divide_assert<StripmineLdgVectors, StripmineLdgVectorsL>::value,
|
||||
|
||||
/// Extent of the stripmine tile in dp-square tiles along K-axis
|
||||
StripmineSquaresK = divide_assert<StripmineLdgVectorsK, SquareDpVectors>::value,
|
||||
|
||||
|
||||
|
||||
/// Alignment in dp_vector_t along L needed for committing prefetch
|
||||
AlignmentDpVectorsL = LdgVectorDpVectors,
|
||||
};
|
||||
|
||||
/// Predicate mask type
|
||||
typedef uint32_t predicate_mask_t;
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Assert assumptions
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
static_assert((LeadingDimAlignBytes >= 4) && (LeadingDimAlignBytes % 4 == 0),
|
||||
"Alignment for matrix operands to IGEMM must be a multiple of 4 bytes.");
|
||||
|
||||
static_assert(
|
||||
(ThreadLdgVectors <= sizeof(predicate_mask_t) * 8),
|
||||
"Predicate mask type does not contain enough bits for encoding load predicates");
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Members
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Input pointer to matrix in ldg_vector_t
|
||||
ldg_vector_t *d_matrix_ldgvecs;
|
||||
|
||||
/// Extent of the input matrix in ldg_vector_t along the L-axis
|
||||
int matrix_ldgvecs_l;
|
||||
|
||||
/// Thread block's ending ldg_vector_t coordinate (k) within the input matrix (one-past)
|
||||
int block_end_ldgvec_k;
|
||||
|
||||
/// Predicate bits for guarding ldg_vector_t loads within "whole-k" block-wide tiles
|
||||
predicate_mask_t guard;
|
||||
|
||||
/// Predicate bits for guarding ldg_vector_t loads within the final block-wide "residue" tile
|
||||
predicate_mask_t residue_guard;
|
||||
|
||||
/// Iteration span in "whole-k" block-wide tiles
|
||||
int wholek_tiles_remaining;
|
||||
|
||||
/// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the K-axis
|
||||
int matrix_ldgvec_stride_k;
|
||||
|
||||
/// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the L-axis
|
||||
int matrix_ldgvec_stride_l;
|
||||
|
||||
/// ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
|
||||
int2 block_thread_ldgvec_coords;
|
||||
|
||||
/// Thread-wide tile of prefetch data
|
||||
ldg_vector_t thread_tile[ThreadSquaresK][SquareDpVectors][ThreadLdgVectorsL];
|
||||
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Utility methods
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
|
||||
/**
|
||||
* \brief Byte-permute. Pick four arbitrary bytes from two 32-bit registers, and reassemble them into a 32-bit destination register. For SM2.0 or later.
|
||||
*
|
||||
* \par
|
||||
* The bytes in the two source registers \p a and \p b are numbered from 0 to 7:
|
||||
* {\p b, \p a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}. For each of the four bytes
|
||||
* {b3, b2, b1, b0} selected in the return value, a 4-bit selector is defined within
|
||||
* the four lower "nibbles" of \p index: {\p index } = {n7, n6, n5, n4, n3, n2, n1, n0}
|
||||
*
|
||||
* \par Snippet
|
||||
* The code snippet below illustrates byte-permute.
|
||||
* \par
|
||||
* \code
|
||||
* #include <cub/cub.cuh>
|
||||
*
|
||||
* __global__ void ExampleKernel(...)
|
||||
* {
|
||||
* int a = 0x03020100;
|
||||
* int b = 0x07060504;
|
||||
* int index = 0x00007531;
|
||||
*
|
||||
* int selected = prmt(a, b, index); // 0x07050301
|
||||
*
|
||||
* \endcode
|
||||
*
|
||||
*/
|
||||
inline __device__
|
||||
int32_t prmt(int32_t a, int32_t b, unsigned int index)
|
||||
{
|
||||
int ret;
|
||||
asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ret) : "r"(a), "r"(b), "r"(index));
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Convert a "dp-square" from L-major to K-major
|
||||
*/
|
||||
inline __device__
|
||||
void transpose_dp_square(dp_vector_t (&dp_square)[SquareDpVectors])
|
||||
{
|
||||
// Transpose dp_vector_t squares
|
||||
int32_t y = prmt(dp_square[0], dp_square[1], 0x00007362);
|
||||
int32_t w = prmt(dp_square[2], dp_square[3], 0x00007362);
|
||||
int32_t x = prmt(dp_square[0], dp_square[1], 0x00005140);
|
||||
int32_t z = prmt(dp_square[2], dp_square[3], 0x00005140);
|
||||
|
||||
dp_square[0] = prmt(x, z, 0x00005410);
|
||||
dp_square[1] = prmt(x, z, 0x00007632);
|
||||
dp_square[2] = prmt(y, w, 0x00005410);
|
||||
dp_square[3] = prmt(y, w, 0x00007632);
|
||||
}
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Constructor API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Constructor
|
||||
inline __device__
|
||||
block_loader(
|
||||
value_t *d_matrix_items, ///< Input pointer to matrix in value_t
|
||||
int matrix_items_l, ///< Extent of the input matrix in value_t along the L-axis
|
||||
int matrix_items_stride_k, ///< Distance in value_t within pitched-linear memory between successive coordinates along the K-axis
|
||||
int matrix_items_stride_l, ///< Distance in value_t within pitched-linear memory between successive coordinates along the L-axis
|
||||
int2 matrix_block_item_coords, ///< value_t coordinates (l, k) of first block-wide tile within the input matrix
|
||||
int block_end_item_k) ///< Thread block's ending coordinate (k) within the input matrix (one-past)
|
||||
:
|
||||
block_end_ldgvec_k(block_end_item_k),
|
||||
guard(0),
|
||||
residue_guard(0)
|
||||
{
|
||||
matrix_ldgvecs_l = matrix_items_l / LdgVectorItems;
|
||||
matrix_ldgvec_stride_k = matrix_items_stride_k / LdgVectorItems,
|
||||
matrix_ldgvec_stride_l = matrix_items_stride_l;
|
||||
|
||||
// ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
|
||||
block_thread_ldgvec_coords = make_int2(
|
||||
threadIdx.x % BlockLdgVectorsL, // l-coordinate
|
||||
(threadIdx.x / BlockLdgVectorsL) * SquareDpVectors); // k-coordinate
|
||||
|
||||
// ldg_vector_t coordinates (l, k) of first block-wide tile within the input matrix
|
||||
int2 matrix_block_ldgvec_coords = make_int2(
|
||||
matrix_block_item_coords.x / LdgVectorItems, // l-coordinate
|
||||
matrix_block_item_coords.y); // k-coordinate
|
||||
|
||||
// Iteration span in ldg_vector_t
|
||||
int span_ldgvec_k = (block_end_item_k - matrix_block_item_coords.y);
|
||||
|
||||
|
||||
|
||||
// ldg_vector_t coordinates (l, k) of first thread-tile tile within the input matrix
|
||||
int2 matrix_thread_ldgvec_coords = make_int2(
|
||||
block_thread_ldgvec_coords.x + matrix_block_ldgvec_coords.x,
|
||||
block_thread_ldgvec_coords.y + matrix_block_ldgvec_coords.y);
|
||||
|
||||
// Iteration range in "whole-k" block-wide tiles
|
||||
wholek_tiles_remaining = span_ldgvec_k / BlockLdgVectorsK;
|
||||
|
||||
// Extent of final residue-tile in ldg_vector_t along K-axis
|
||||
int residue_ldgvecs_k = span_ldgvec_k % BlockLdgVectorsK;
|
||||
|
||||
// Initialize I/O predicates
|
||||
if (AllowRaggedTiles)
|
||||
{
|
||||
// Iterate through rows of squares in thread tile
|
||||
#pragma unroll
|
||||
for (int thread_square_k = 0; thread_square_k < ThreadSquaresK; ++thread_square_k)
|
||||
{
|
||||
// Iterate through rows of dp_vector_t in each square
|
||||
#pragma unroll
|
||||
for (int square_dpvec = 0; square_dpvec < SquareDpVectors; ++square_dpvec)
|
||||
{
|
||||
// ldg_vector_t K-coordinate in block-wide tile (K-axis strip-mining of ldg_vector_t within block-tile)
|
||||
int block_ldgvec_k =
|
||||
block_thread_ldgvec_coords.y +
|
||||
(thread_square_k * StripmineLdgVectorsK) +
|
||||
square_dpvec;
|
||||
|
||||
// Whether block_ldgvec_coords.y is valid in the final residue tile
|
||||
predicate_mask_t valid_k = (block_ldgvec_k < residue_ldgvecs_k);
|
||||
|
||||
// L-axis strip-mining of block-tile
|
||||
#pragma unroll
|
||||
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
|
||||
{
|
||||
// ldg_vector_t L-coordinate in block-wide tile (L-axis strip-mining of ldg_vector_t within block-tile)
|
||||
int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
|
||||
|
||||
// Whether block_ldgvec_coords.x is valid any block-wide tile
|
||||
predicate_mask_t valid_l = (matrix_block_ldgvec_coords.x + block_ldgvec_l < matrix_ldgvecs_l);
|
||||
|
||||
// Linear index of ldg_vector_t load
|
||||
int ldgvec_idx =
|
||||
(thread_square_k * SquareDpVectors * ThreadLdgVectorsL) +
|
||||
(square_dpvec * ThreadLdgVectorsL) +
|
||||
thread_ldgvec_l;
|
||||
|
||||
// Set predicate guard bits
|
||||
guard |= (valid_l << ldgvec_idx);
|
||||
residue_guard |= ((valid_l & valid_k) << ldgvec_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Promote residue-guard to primary-guard if no full tiles remain
|
||||
if (!wholek_tiles_remaining)
|
||||
{
|
||||
guard = residue_guard;
|
||||
}
|
||||
}
|
||||
|
||||
// Update the input pointer to be matrix_thread_ldgvec_coords
|
||||
this->d_matrix_ldgvecs =
|
||||
reinterpret_cast<ldg_vector_t*>(d_matrix_items) +
|
||||
(matrix_thread_ldgvec_coords.y * matrix_ldgvec_stride_k) +
|
||||
(matrix_thread_ldgvec_coords.x * matrix_ldgvec_stride_l);
|
||||
}
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Loader API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Request the current block-wide tile
|
||||
*/
|
||||
inline __device__
|
||||
void request()
|
||||
{
|
||||
// Each thread iterates through the ldg_vector_t in its thread tile
|
||||
|
||||
// Iterate through rows of squares in thread tile
|
||||
#pragma unroll
|
||||
for (int thread_square_k = 0; thread_square_k < ThreadSquaresK; ++thread_square_k)
|
||||
{
|
||||
// Iterate through rows of dp_vector_t in each square
|
||||
#pragma unroll
|
||||
for (int square_dpvec = 0; square_dpvec < SquareDpVectors; ++square_dpvec)
|
||||
{
|
||||
// Iterate through ldg_vector_t in each row
|
||||
#pragma unroll
|
||||
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
|
||||
{
|
||||
// Linear index of ldg_vector_t load
|
||||
int ldgvec_idx =
|
||||
(thread_square_k * SquareDpVectors * ThreadLdgVectorsL) +
|
||||
(square_dpvec * ThreadLdgVectorsL) +
|
||||
thread_ldgvec_l;
|
||||
|
||||
// Unpack predicate guard
|
||||
predicate_mask_t valid = ((guard >> ldgvec_idx) & 1);
|
||||
|
||||
if (!AllowRaggedTiles || valid)
|
||||
{
|
||||
// Perform load
|
||||
thread_tile[thread_square_k][square_dpvec][thread_ldgvec_l].load(
|
||||
d_matrix_ldgvecs +
|
||||
(((thread_square_k * StripmineLdgVectorsK) + square_dpvec) * matrix_ldgvec_stride_k) +
|
||||
(thread_ldgvec_l * StripmineLdgVectorsL * matrix_ldgvec_stride_l));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Zero-initialize
|
||||
#pragma unroll
|
||||
for (int dpvec = 0; dpvec < LdgVectorDpVectors; ++dpvec)
|
||||
thread_tile[thread_square_k][square_dpvec][thread_ldgvec_l].buff[dpvec] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Advance the loader to the next block-wide tile in the K-axis
|
||||
*/
|
||||
inline __device__
|
||||
void next()
|
||||
{
|
||||
d_matrix_ldgvecs += (matrix_ldgvec_stride_k * BlockLdgVectorsK);
|
||||
|
||||
if (AllowRaggedTiles)
|
||||
{
|
||||
--wholek_tiles_remaining;
|
||||
|
||||
// Promote residue-guard to primary-guard if no full tiles remain
|
||||
if (!wholek_tiles_remaining)
|
||||
{
|
||||
guard = residue_guard;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Commit the previously-requested block-wide tile to shared memory
|
||||
*
|
||||
* NB: To facilitate padding for avoiding shared memory bank conflicts, we
|
||||
* allow the row stride SmemDpVectorsL to be arbitrarily bigger than the
|
||||
* tile width BlockDpVectorsL.
|
||||
*/
|
||||
template <int SmemDpVectorsL>
|
||||
inline __device__
|
||||
void commit(
|
||||
dp_vector_t (&scratch_tile)[_BlockDpVectorsK][SmemDpVectorsL])
|
||||
{
|
||||
static_assert(SmemDpVectorsL >= _BlockDpVectorsL, "Row stride must be >= tile width.");
|
||||
|
||||
// Square K-coordinate of thread tile in block-wide tile
|
||||
int block_thread_square_k = block_thread_ldgvec_coords.y / SquareDpVectors;
|
||||
|
||||
// Iterate through rows of squares in thread tile
|
||||
#pragma unroll
|
||||
for (int thread_square_k = 0; thread_square_k < ThreadSquaresK; ++thread_square_k)
|
||||
{
|
||||
// Square K-coordinate in block-wide tile (K-axis strip-mining of squares within block-tile)
|
||||
int block_square_k = block_thread_square_k + (thread_square_k * StripmineSquaresK);
|
||||
|
||||
// Iterate through ldg_vector_t in each row
|
||||
#pragma unroll
|
||||
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
|
||||
{
|
||||
// ldg_vector_t L-coordinate in block-wide tile (L-axis strip-mining of ldg_vector_t within block-tile)
|
||||
int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
|
||||
|
||||
// Iterate through squares in each ldg_vector_t
|
||||
#pragma unroll
|
||||
for (int ldgvec_dpvec_l = 0; ldgvec_dpvec_l < LdgVectorDpVectors; ++ldgvec_dpvec_l)
|
||||
{
|
||||
// Square L-coordinate in block-wide tile (L-axis raking of square-slices within ldg_vector_t)
|
||||
int block_square_l = (block_ldgvec_l * LdgVectorDpVectors) + ldgvec_dpvec_l;
|
||||
|
||||
// Assemble square of L-major dp_vector_t from stack of slices
|
||||
sts_vector_t square;
|
||||
|
||||
// Iterate through rows of dp_vector_t in each square
|
||||
#pragma unroll
|
||||
for (int square_dpvec = 0; square_dpvec < SquareDpVectors; ++square_dpvec)
|
||||
{
|
||||
square.buff[square_dpvec] = thread_tile[thread_square_k][square_dpvec][thread_ldgvec_l].buff[ldgvec_dpvec_l];
|
||||
}
|
||||
|
||||
// Un-transpose square from L-major to K-major
|
||||
transpose_dp_square(square.buff);
|
||||
|
||||
// Store dp-square
|
||||
square.store(&scratch_tile[block_square_k][block_square_l * SquareDpVectors]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,411 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* Tile-loading abstraction for thread blocks
|
||||
*/
|
||||
|
||||
#include "../util/util.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* block_loader (CrosswiseCopy specialization)
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief A three-phase data loading abstraction (prefetch, commit, and
|
||||
* advance) for iterating over ranges of block-wide matrix tiles.
|
||||
* (CrosswiseCopy specialization)
|
||||
*
|
||||
* Each iteration sequence produces a KxL (height-by-width) block-wide tile of
|
||||
* value_t in shared memory. The layout of the shared block-wide tile is
|
||||
* a row-major (L-major) tiling of dp_vector_t items, which are themselves
|
||||
* column-major (K-major) vectors of value_t. Its dimensions are:
|
||||
* K = BlockDpVectorsK * (sizeof(dp_vector_t) / sizeof(value_t)
|
||||
* L = BlockDpVectorsL
|
||||
*
|
||||
* The data is copied from a corresponding tile of global matrix data whose
|
||||
* layout of value_t is K-major. This constitutes a CrosswiseCopy between
|
||||
* the K-major global tile and the L-major shared tile.
|
||||
*
|
||||
* NB: The orientation of dp_vector_t components in shared memory is congruous
|
||||
* with the global matrix data, so we can use dp_vector_t as the minimum
|
||||
* granularity of data transfer without any intermediate {dis|re}assembly
|
||||
* of its value_t components. However, the global and shared memory layouts
|
||||
* of dp_vector_t items are cross-wise with respect to each other, so any
|
||||
* further LDG-vectorization of dp_vector_t data requires intermediate
|
||||
* disassembly into dp_vector_t components to be stored individually into
|
||||
* the shared tile.
|
||||
*
|
||||
* NB: Consecutive threads within a block are mapped in K-major
|
||||
* fashion down a first set of LDG-vectors of dp_vector_t within their global
|
||||
* tile. Successive sets of LDG-vectors are then strip-mined as necessary
|
||||
* across the L-axis. These discontiguous LDG-vectors comprise the thread's
|
||||
* "slice" of the block-wide tile.
|
||||
*/
|
||||
template <
|
||||
int BlockThreads, ///< Number of threads in each thread block (blockDim.x)
|
||||
int BlockDpVectorsK, ///< Extent of block-wide tile in dp_vector_t along the K-axis (height)
|
||||
int BlockDpVectorsL, ///< Extent of block-wide tile in dp_vector_t along the L-axis (width)
|
||||
typename value_t, ///< Input matrix value type
|
||||
int LeadingDimAlignBytes, ///< Byte alignment of input matrix leading dimension
|
||||
bool AllowRaggedTiles, ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
|
||||
typename dp_vector_t> ///< Dot-product vector type along the K-axis
|
||||
struct block_loader<
|
||||
BlockThreads,
|
||||
BlockDpVectorsK,
|
||||
BlockDpVectorsL,
|
||||
value_t,
|
||||
LeadingDimAlignBytes,
|
||||
AllowRaggedTiles,
|
||||
dp_vector_t,
|
||||
load_algorithm::CrosswiseCopy> ///< Algorithm for loading a shared tile of KxL matrix data (CrosswiseCopy specialization)
|
||||
{
|
||||
//-------------------------------------------------------------------------
|
||||
// Constants and types
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
enum
|
||||
{
|
||||
/// Number of value_t in a dp_vector_t
|
||||
DpVectorItems = divide_assert<sizeof(dp_vector_t), sizeof(value_t)>::value,
|
||||
|
||||
/// Number of dp_vector_t in a block-wide tile
|
||||
BlockDpVectors = BlockDpVectorsK * BlockDpVectorsL,
|
||||
|
||||
/// Number of dp_vector_t in a thread-tile
|
||||
ThreadDpVectors = divide_assert<BlockDpVectors, BlockThreads>::value,
|
||||
};
|
||||
|
||||
/// Data movement type, coarsened by LeadingDimAlignBytes, capped by the
|
||||
/// smaller of either ThreadDpVectors or BlockDpVectorsK
|
||||
typedef io_vector<
|
||||
dp_vector_t,
|
||||
__NV_STD_MIN(ThreadDpVectors, BlockDpVectorsK),
|
||||
LeadingDimAlignBytes>
|
||||
ldg_vector_t;
|
||||
|
||||
enum
|
||||
{
|
||||
/// Number of dp_vector_t per ldg_vector_t
|
||||
LdgVectorDpVectors = ldg_vector_t::VectorItems,
|
||||
|
||||
/// Number of value_t per ldg_vector_t
|
||||
LdgVectorItems = LdgVectorDpVectors * DpVectorItems,
|
||||
|
||||
|
||||
|
||||
/// Total number of ldg_vector_t within each block-wide tile
|
||||
BlockLdgVectors = divide_assert<BlockDpVectors, LdgVectorDpVectors>::value,
|
||||
|
||||
/// Extent of the block-wide tile in ldg_vector_t along K-axis
|
||||
BlockLdgVectorsK = divide_assert<BlockDpVectorsK, LdgVectorDpVectors>::value,
|
||||
|
||||
/// Extent of the block-wide tile in ldg_vector_t along L-axis
|
||||
BlockLdgVectorsL = BlockDpVectorsL,
|
||||
|
||||
|
||||
|
||||
/// Number of ldg_vector_t within each thread-tile
|
||||
ThreadLdgVectors = divide_assert<BlockLdgVectors, BlockThreads>::value,
|
||||
|
||||
/// Extent of the thread tile in ldg_vector_t along K-axis
|
||||
ThreadLdgVectorsK = __NV_STD_MAX(1, (BlockLdgVectorsK / BlockThreads)),
|
||||
|
||||
/// Extent of the thread tile in ldg_vector_t along L-axis
|
||||
ThreadLdgVectorsL = divide_assert<ThreadLdgVectors, ThreadLdgVectorsK>::value,
|
||||
|
||||
|
||||
|
||||
/// Number of ldg_vector_t within each stripmine-tile
|
||||
StripmineLdgVectors = BlockThreads,
|
||||
|
||||
/// Extent of the stripmine tile in ldg_vector_t along K-axis
|
||||
StripmineLdgVectorsK = __NV_STD_MIN(BlockLdgVectorsK, StripmineLdgVectors),
|
||||
|
||||
/// Extent of the stripmine tile in ldg_vector_t along L-axis
|
||||
StripmineLdgVectorsL = divide_assert<StripmineLdgVectors, StripmineLdgVectorsK>::value,
|
||||
|
||||
|
||||
|
||||
/// Alignment in dp_vector_t along L needed for committing prefetch
|
||||
AlignmentDpVectorsL = 1,
|
||||
};
|
||||
|
||||
/// Predicate bit vector
|
||||
typedef uint64_t predicate_mask_t;
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Assert assumptions
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
static_assert(
|
||||
(ThreadLdgVectors <= sizeof(predicate_mask_t) * 8),
|
||||
"Predicate mask type does not contain enough bits for encoding load predicates");
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Members
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Input pointer to matrix in ldg_vector_t
|
||||
ldg_vector_t *d_matrix_ldgvecs;
|
||||
|
||||
/// Extent of the input matrix in ldg_vector_t along the L-axis
|
||||
int matrix_ldgvecs_l;
|
||||
|
||||
/// Thread block's ending ldg_vector_t coordinate (k) within the input matrix (one-past)
|
||||
int block_end_ldgvec_k;
|
||||
|
||||
/// Predicate bits for guarding ldg_vector_t loads within "whole-k" block-wide tiles
|
||||
predicate_mask_t guard;
|
||||
|
||||
/// Predicate bits for guarding ldg_vector_t loads within the final block-wide "residue" tile
|
||||
predicate_mask_t residue_guard;
|
||||
|
||||
/// Iteration span in "whole-k" block-wide tiles
|
||||
int wholek_tiles_remaining;
|
||||
|
||||
/// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the K-axis
|
||||
int matrix_ldgvec_stride_k;
|
||||
|
||||
/// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the L-axis
|
||||
int matrix_ldgvec_stride_l;
|
||||
|
||||
/// ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
|
||||
int2 block_thread_ldgvec_coords;
|
||||
|
||||
/// Thread-wide tile of prefetch data
|
||||
ldg_vector_t thread_tile[ThreadLdgVectorsK][ThreadLdgVectorsL];
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Constructor API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Constructor
|
||||
inline __device__
|
||||
block_loader(
|
||||
value_t *d_matrix_items, ///< Input pointer to matrix in value_t
|
||||
int matrix_items_l, ///< Extent of the input matrix in value_t along the L-axis
|
||||
int matrix_items_stride_k, ///< Distance in value_t within pitched-linear memory between successive coordinates along the K-axis
|
||||
int matrix_items_stride_l, ///< Distance in value_t within pitched-linear memory between successive coordinates along the L-axis
|
||||
int2 matrix_block_item_coords, ///< value_t coordinates (l, k) of first block-wide tile within the input matrix
|
||||
int block_end_item_k) ///< Thread block's ending coordinate (k) within the input matrix (one-past)
|
||||
:
|
||||
block_end_ldgvec_k(block_end_item_k),
|
||||
guard(0),
|
||||
residue_guard(0)
|
||||
{
|
||||
matrix_ldgvecs_l = matrix_items_l;
|
||||
matrix_ldgvec_stride_k = matrix_items_stride_k;
|
||||
matrix_ldgvec_stride_l = (matrix_items_stride_l / LdgVectorItems);
|
||||
|
||||
// ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
|
||||
block_thread_ldgvec_coords = make_int2(
|
||||
(threadIdx.x / BlockLdgVectorsK), // l-coordinate
|
||||
(threadIdx.x % BlockLdgVectorsK)); // k-coordinate
|
||||
|
||||
// ldg_vector_t coordinates (l, k) of first block-wide tile within the input matrix
|
||||
int2 matrix_block_ldgvec_coords = make_int2(
|
||||
matrix_block_item_coords.x, // l-coordinate
|
||||
matrix_block_item_coords.y / LdgVectorItems); // k-coordinate
|
||||
|
||||
// Iteration span in ldg_vector_t
|
||||
int span_ldgvec_k = (block_end_item_k - matrix_block_item_coords.y) / LdgVectorItems;
|
||||
|
||||
|
||||
|
||||
// ldg_vector_t coordinates (l, k) of first thread-tile tile within the input matrix
|
||||
int2 matrix_thread_ldgvec_coords = make_int2(
|
||||
block_thread_ldgvec_coords.x + matrix_block_ldgvec_coords.x,
|
||||
block_thread_ldgvec_coords.y + matrix_block_ldgvec_coords.y);
|
||||
|
||||
// Iteration range in "whole-k" block-wide tiles
|
||||
wholek_tiles_remaining = span_ldgvec_k / BlockLdgVectorsK;
|
||||
|
||||
// Extent of final residue-tile in ldg_vector_t along K-axis
|
||||
int residue_ldgvecs_k = span_ldgvec_k % BlockLdgVectorsK;
|
||||
|
||||
// Initialize I/O predicates
|
||||
if (AllowRaggedTiles)
|
||||
{
|
||||
// Outer thread-tile ldg_vector_t iteration (K-axis)
|
||||
#pragma unroll
|
||||
for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
|
||||
{
|
||||
int block_ldgvec_k = block_thread_ldgvec_coords.y + (thread_ldgvec_k * StripmineLdgVectorsK);
|
||||
|
||||
// Whether block_ldgvec_coords.y is valid in the final residue tile
|
||||
predicate_mask_t valid_k = (block_ldgvec_k < residue_ldgvecs_k);
|
||||
|
||||
// Inner thread-tile ldg_vector_t iteration (L-axis)
|
||||
#pragma unroll
|
||||
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
|
||||
{
|
||||
int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
|
||||
|
||||
// Whether block_ldgvec_coords.x is valid any block-wide tile
|
||||
predicate_mask_t valid_l = (matrix_block_ldgvec_coords.x + block_ldgvec_l < matrix_ldgvecs_l);
|
||||
|
||||
// Linear index of ldg_vector_t load
|
||||
int ldgvec_idx = thread_ldgvec_l + (thread_ldgvec_k * ThreadLdgVectorsL);
|
||||
|
||||
// Set predicate guard bits
|
||||
guard |= (valid_l << ldgvec_idx);
|
||||
residue_guard |= ((valid_l & valid_k) << ldgvec_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Promote residue-guard to primary-guard if no full tiles remain
|
||||
if (!wholek_tiles_remaining)
|
||||
{
|
||||
guard = residue_guard;
|
||||
}
|
||||
}
|
||||
|
||||
// Update the input pointer to be matrix_thread_ldgvec_coords
|
||||
this->d_matrix_ldgvecs =
|
||||
reinterpret_cast<ldg_vector_t*>(d_matrix_items) +
|
||||
(matrix_thread_ldgvec_coords.y * matrix_ldgvec_stride_k) +
|
||||
(matrix_thread_ldgvec_coords.x * matrix_ldgvec_stride_l);
|
||||
}
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Loader API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Request the current block-wide tile
|
||||
*/
|
||||
inline __device__
|
||||
void request()
|
||||
{
|
||||
// Outer thread-tile ldg_vector_t iteration (K-axis)
|
||||
#pragma unroll
|
||||
for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
|
||||
{
|
||||
// Inner thread-tile ldg_vector_t iteration (L-axis)
|
||||
#pragma unroll
|
||||
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
|
||||
{
|
||||
// Linear index of ldg_vector_t load
|
||||
int ldgvec_idx = (thread_ldgvec_k * ThreadLdgVectorsL) + thread_ldgvec_l;
|
||||
|
||||
// Unpack predicate guard
|
||||
predicate_mask_t valid = ((guard >> ldgvec_idx) & 1);
|
||||
|
||||
if (!AllowRaggedTiles || valid)
|
||||
{
|
||||
// Perform load
|
||||
thread_tile[thread_ldgvec_k][thread_ldgvec_l].load(
|
||||
d_matrix_ldgvecs +
|
||||
(thread_ldgvec_k * StripmineLdgVectorsK * matrix_ldgvec_stride_k) +
|
||||
(thread_ldgvec_l * StripmineLdgVectorsL * matrix_ldgvec_stride_l));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Zero-initialize
|
||||
#pragma unroll
|
||||
for (int dpvec = 0; dpvec < LdgVectorDpVectors; ++dpvec)
|
||||
thread_tile[thread_ldgvec_k][thread_ldgvec_l].buff[dpvec] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Advance the loader to the next block-wide tile in the K-axis
|
||||
*/
|
||||
inline __device__
|
||||
void next()
|
||||
{
|
||||
d_matrix_ldgvecs += (matrix_ldgvec_stride_k * BlockLdgVectorsK);
|
||||
|
||||
if (AllowRaggedTiles)
|
||||
{
|
||||
--wholek_tiles_remaining;
|
||||
|
||||
// Promote residue-guard to primary-guard if no full tiles remain
|
||||
if (!wholek_tiles_remaining)
|
||||
{
|
||||
guard = residue_guard;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Commit the previously-requested block-wide tile to shared memory
|
||||
*
|
||||
* NB: To facilitate padding for avoiding shared memory bank conflicts, we
|
||||
* allow the row stride SmemDpVectorsL to be arbitrarily bigger than the
|
||||
* tile width BlockDpVectorsL.
|
||||
*/
|
||||
template <int SmemDpVectorsL>
|
||||
inline __device__
|
||||
void commit(
|
||||
dp_vector_t (&scratch_tile)[BlockDpVectorsK][SmemDpVectorsL])
|
||||
{
|
||||
static_assert(SmemDpVectorsL >= BlockDpVectorsL, "Row stride must be >= tile width.");
|
||||
|
||||
// Outer thread-tile ldg_vector_t iteration (K-axis)
|
||||
#pragma unroll
|
||||
for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
|
||||
{
|
||||
int block_ldgvec_k = block_thread_ldgvec_coords.y + (thread_ldgvec_k * StripmineLdgVectorsK);
|
||||
|
||||
// Inner thread-tile ldg_vector_t iteration (L-axis)
|
||||
#pragma unroll
|
||||
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
|
||||
{
|
||||
int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
|
||||
|
||||
// Write column of dp_vector_t
|
||||
#pragma unroll
|
||||
for (int dpvec = 0; dpvec < LdgVectorDpVectors; ++dpvec)
|
||||
{
|
||||
scratch_tile[(block_ldgvec_k * LdgVectorDpVectors) + dpvec][block_ldgvec_l] =
|
||||
thread_tile[thread_ldgvec_k][thread_ldgvec_l].buff[dpvec];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,322 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* Tile-loading abstraction for thread blocks
|
||||
*/
|
||||
|
||||
#include "../util/util.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
|
||||
/**
|
||||
* block-wide tile loader supporting congruous mapping of data from source and
|
||||
* destination addressable storage. Typically, this will be used to load a
|
||||
* block-wide tile from global memory into shared memory.
|
||||
*
|
||||
* This enables the caller to specify MatrixAlignBytes guarantees of the input pointer
|
||||
* and performs memory operations on vectors. This increases the efficiency of
|
||||
* memory operations and reduces the number of guard predicates needed.
|
||||
*
|
||||
*/
|
||||
template <
|
||||
bool congruous, ///< Indicates whether the "GEMM K" dimension refers to strided matrix dimension
|
||||
int BlockThreads, ///< Number of threads participating in the streaming operation
|
||||
int BlockItemsL, ///< Extent of block-wide tile in value_t along the L-axis (width)
|
||||
int BlockItemsK, ///< Extent of block-wide tile in value_t along the K-axis (height)
|
||||
typename value_t, ///< Input matrix value type
|
||||
int MatrixAlignBytes, ///< Byte alignment of input matrix
|
||||
bool AllowRaggedTiles ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
|
||||
>
|
||||
struct block_loader_wmma
|
||||
{
|
||||
//-------------------------------------------------------------------------
|
||||
// Constants and types
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Predicate bit vector
|
||||
typedef uint64_t predicate_mask_t;
|
||||
|
||||
/// Data movement type, coarsened by MatrixAlignBytes
|
||||
typedef io_vector<
|
||||
value_t,
|
||||
divide_assert<MatrixAlignBytes, sizeof(value_t)>::value,
|
||||
MatrixAlignBytes>
|
||||
ldg_vector_t;
|
||||
|
||||
enum
|
||||
{
|
||||
/// Number of items per ldg_vector_t
|
||||
LdgVectorItems = ldg_vector_t::VectorItems,
|
||||
|
||||
/// Total number of ldg_vector_t within the block-wide tile
|
||||
BlockLdgVectors = divide_assert<(BlockItemsL * BlockItemsK), LdgVectorItems>::value,
|
||||
|
||||
/// Extent of the block-wide tile in ldg_vector_t along K-axis
|
||||
BlockLdgVectorsK = BlockItemsK,
|
||||
|
||||
/// Extent of the block-wide tile in ldg_vector_t along L-axis
|
||||
BlockLdgVectorsL = divide_assert<BlockItemsL, LdgVectorItems>::value,
|
||||
|
||||
/// Number of ldg_vector_t within each thread tile
|
||||
ThreadLdgVectors = divide_assert<BlockLdgVectors, BlockThreads>::value,
|
||||
|
||||
/// Extent of the thread tile in ldg_vector_t along the L-axis
|
||||
ThreadLdgVectorsL = __NV_STD_MAX(1, BlockLdgVectorsL / BlockThreads),
|
||||
|
||||
/// Block-wide strip-mining distance between ldg_vector_t along the K-axis
|
||||
BlockLdgVectorStrideK = __NV_STD_MAX(1, BlockThreads / BlockLdgVectorsL),
|
||||
|
||||
/// Extent of the thread tile in ldg_vector_t along the K-axis
|
||||
ThreadLdgVectorsK = divide_assert<BlockLdgVectorsK, BlockLdgVectorStrideK>::value,
|
||||
};
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Assert assumptions
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Define assertions
|
||||
static_assert(ThreadLdgVectorsL * ThreadLdgVectorsK == ThreadLdgVectors,
|
||||
"Number of vectors must be fully covered by the thread's 2D vector tile.");
|
||||
|
||||
/// Predicate masks must be large enough to guard every vector load
|
||||
static_assert(sizeof(predicate_mask_t) * 8 >= ThreadLdgVectorsL * ThreadLdgVectorsK,
|
||||
"Predicate bit vector must be large enough to guard every vector load.");
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Members
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// pointer to tile in global memory
|
||||
const ldg_vector_t *ptr;
|
||||
|
||||
/// stride of the matrix in the K-axis
|
||||
int matrix_values_stride_k;
|
||||
|
||||
/// Guard predicate
|
||||
predicate_mask_t guard;
|
||||
|
||||
/// Guard for the last request iteration
|
||||
predicate_mask_t residue_guard;
|
||||
|
||||
/// Number of 'whole' request iterations before encountering the residue
|
||||
int request_iterations;
|
||||
|
||||
/// fetch registers
|
||||
ldg_vector_t fetch[ThreadLdgVectors];
|
||||
|
||||
/// Thread's base offset from the start of a block-wide tile
|
||||
int thread_offset_l;
|
||||
|
||||
/// Thread's basae offset from the start of a block-wide tile
|
||||
int thread_offset_k;
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Constructor API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Constructor
|
||||
inline __device__
|
||||
block_loader_wmma(
|
||||
const value_t *d_matrix, ///< Pointer to input matrix
|
||||
int matrix_values_l, ///< Extent of the input matrix in value_t along the L-axis
|
||||
int start_l, ///< Starting location in tile
|
||||
int dim_k, ///< Inner dimension of tile, used for computing guard predicates
|
||||
int _matrix_values_stride_k, ///< Stride of K-axis of atrix
|
||||
int start_k, ///< Tile's starting location
|
||||
int2 block_begin_item_coords) ///< Thread block's starting value_t coordinates (l, k) within the input matrix
|
||||
:
|
||||
ptr(reinterpret_cast<const ldg_vector_t *>(d_matrix)),
|
||||
matrix_values_stride_k(_matrix_values_stride_k / LdgVectorItems),
|
||||
guard(0),
|
||||
residue_guard(0)
|
||||
{
|
||||
// Compute block's starting coordinates in units of vectors
|
||||
int block_base_l = block_begin_item_coords.x / LdgVectorItems;
|
||||
int block_base_k = block_begin_item_coords.y;
|
||||
|
||||
// Compute a thread tiling of the block-wide tile
|
||||
int tid = threadIdx.x;
|
||||
thread_offset_l = tid % BlockLdgVectorsL;
|
||||
thread_offset_k = tid / BlockLdgVectorsL;
|
||||
|
||||
// Add the block and thread offsets to the source pointer
|
||||
ptr += (block_base_l + thread_offset_l) +
|
||||
(block_base_k + thread_offset_k) * matrix_values_stride_k;
|
||||
|
||||
// When AllowRaggedTiles support is enabled, compute a bit vector of guard
|
||||
// predicates
|
||||
if (AllowRaggedTiles)
|
||||
{
|
||||
if (congruous)
|
||||
{
|
||||
request_iterations = (dim_k - start_k) / BlockItemsK;
|
||||
}
|
||||
else
|
||||
{
|
||||
request_iterations = (matrix_values_l - start_l) / BlockItemsL;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k_idx = 0; k_idx < ThreadLdgVectorsK; ++k_idx)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int l_idx = 0; l_idx < ThreadLdgVectorsL; ++l_idx)
|
||||
{
|
||||
int item = l_idx + k_idx * ThreadLdgVectorsL;
|
||||
|
||||
// Global vector L and K indices
|
||||
int vec_l = l_idx * BlockThreads;
|
||||
int vec_k = k_idx * BlockLdgVectorStrideK;
|
||||
|
||||
predicate_mask_t pred;
|
||||
predicate_mask_t residue_pred;
|
||||
|
||||
if (congruous)
|
||||
{
|
||||
pred = (((block_base_l + thread_offset_l + vec_l) * LdgVectorItems < matrix_values_l) ? 1 : 0);
|
||||
residue_pred = ((block_base_k + thread_offset_k + vec_k < (dim_k % BlockItemsK)) ? 1 : 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
pred = ((block_base_k + thread_offset_k + vec_k < dim_k) ? 1 : 0);
|
||||
residue_pred = (((block_base_l + thread_offset_l + vec_l) * LdgVectorItems < (matrix_values_l % BlockItemsL)) ? 1 : 0);
|
||||
}
|
||||
|
||||
// Update the guard and residue_guard word with predicate bits
|
||||
guard |= (pred << item);
|
||||
residue_guard |= (residue_pred << item);
|
||||
}
|
||||
}
|
||||
|
||||
// If there are zero full request iterations, compute the intersection
|
||||
// with the residue guard.
|
||||
if (!request_iterations)
|
||||
{
|
||||
guard &= residue_guard;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Request the current block-wide tile from source memory
|
||||
*/
|
||||
inline __device__
|
||||
void request()
|
||||
{
|
||||
#pragma unroll
|
||||
for (int k_idx = 0; k_idx < ThreadLdgVectorsK; ++k_idx)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int l_idx = 0; l_idx < ThreadLdgVectorsL; ++l_idx)
|
||||
{
|
||||
int load_idx = l_idx + (k_idx * ThreadLdgVectorsL);
|
||||
bool pred = !AllowRaggedTiles || (guard & (predicate_mask_t(1) << load_idx));
|
||||
if (pred)
|
||||
{
|
||||
fetch[load_idx].load(
|
||||
ptr +
|
||||
(k_idx * BlockLdgVectorStrideK * matrix_values_stride_k) + (l_idx * BlockThreads));
|
||||
}
|
||||
else
|
||||
{
|
||||
#pragma unroll
|
||||
for (int elem_idx = 0; elem_idx < LdgVectorItems; ++elem_idx)
|
||||
{
|
||||
fetch[load_idx].buff[elem_idx] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Advance to the next block-wide tile
|
||||
inline __device__
|
||||
void next()
|
||||
{
|
||||
if (congruous)
|
||||
{
|
||||
ptr += BlockItemsK * matrix_values_stride_k;
|
||||
}
|
||||
else
|
||||
{
|
||||
ptr += BlockLdgVectorsL;
|
||||
}
|
||||
|
||||
// Track number of full iterations to intersect with the residue guard predicates.
|
||||
if (AllowRaggedTiles)
|
||||
{
|
||||
--request_iterations;
|
||||
if (!request_iterations)
|
||||
{
|
||||
guard &= residue_guard;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Commit the values to the scratch tile to destination memory.
|
||||
template <int SmemStride>
|
||||
inline __device__
|
||||
void commit(value_t *scratch_tile)
|
||||
{
|
||||
static_assert(SmemStride % LdgVectorItems == 0,
|
||||
"SMEM stride must be divisible by the size of vector loads");
|
||||
|
||||
ldg_vector_t *smem_ptr = reinterpret_cast<ldg_vector_t *>(scratch_tile);
|
||||
smem_ptr += thread_offset_l + thread_offset_k * SmemStride / LdgVectorItems;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_idx = 0; k_idx < ThreadLdgVectorsK; ++k_idx)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int l_idx = 0; l_idx < ThreadLdgVectorsL; ++l_idx)
|
||||
{
|
||||
int load_idx = l_idx + (k_idx * ThreadLdgVectorsL);
|
||||
|
||||
fetch[load_idx].store(smem_ptr +
|
||||
(k_idx * BlockLdgVectorStrideK * SmemStride / LdgVectorItems) +
|
||||
(l_idx * BlockThreads));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
@ -1,677 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* A block-wide task abstraction for computing device-wide GEMM
|
||||
*/
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "../util/util.h"
|
||||
|
||||
#include "grid_raster.h"
|
||||
#include "block_loader.h"
|
||||
#include "k_split_control.h"
|
||||
#include "thread_accumulator.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* block_task_policy
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief Parameterizable tuning policy for \p block_task
|
||||
*
|
||||
* Once parameterized, \p block_task_policy provides the member constant
|
||||
* \p BlockThreads indicating to the required thread block size
|
||||
*/
|
||||
template <
|
||||
int _BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
|
||||
int _BlockItemsX, ///< Width in columns of a block-wide tile in matrix C
|
||||
int _BlockItemsK, ///< Extent of block-wide A|B tiles in value_t along the K-axis
|
||||
int _ThreadItemsY, ///< Height in rows of a thread tile in C
|
||||
int _ThreadItemsX, ///< Width in columns of a thread tile in C
|
||||
bool _UseDoubleScratchTiles, ///< Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
|
||||
grid_raster_strategy::kind_t _RasterStrategy> ///< Strategy for enumerating \p block_task within an input matrix
|
||||
struct block_task_policy
|
||||
{
|
||||
enum
|
||||
{
|
||||
/// Height in rows of a block-wide tile in matrix C
|
||||
BlockItemsY = _BlockItemsY,
|
||||
|
||||
/// Width in columns of a block-wide tile in matrix C
|
||||
BlockItemsX = _BlockItemsX,
|
||||
|
||||
/// Height in rows of a thread tile in C
|
||||
ThreadItemsY = _ThreadItemsY,
|
||||
|
||||
/// Width in columns of a thread tile in C
|
||||
ThreadItemsX = _ThreadItemsX,
|
||||
|
||||
/// Extent of block-wide A|B tiles in value_t along the K-axis
|
||||
BlockItemsK = _BlockItemsK,
|
||||
|
||||
/// Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
|
||||
UseDoubleScratchTiles = _UseDoubleScratchTiles,
|
||||
|
||||
/// Number of threads in each thread block (blockDim.x)
|
||||
BlockThreads = divide_assert<
|
||||
(BlockItemsY * BlockItemsX),
|
||||
(ThreadItemsY * ThreadItemsX)>::value,
|
||||
};
|
||||
|
||||
/// Strategy for enumerating \p block_task within an input matrix
|
||||
static const grid_raster_strategy::kind_t RasterStrategy = _RasterStrategy;
|
||||
};
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* block_task
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief A block-wide task abstraction for computing device-wide GEMM
|
||||
*
|
||||
* Each thread_block is assigned a unique tile of output matrix C to compute by
|
||||
* consuming the corresponding stripes of the input matrices A and B.
|
||||
*/
|
||||
template <
|
||||
typename block_task_policy_t, ///< Parameterization of block_task_policy
|
||||
typename value_t, ///< Multiplicand value type (matrices A and B)
|
||||
typename accum_t, ///< Accumulator value type (matrix C and scalars)
|
||||
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
|
||||
int LdgAlignA, ///< Alignment (in bytes) for A operand
|
||||
matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B
|
||||
int LdgAlignB, ///< Alignment (in bytes) for B operand
|
||||
typename epilogue_op_t, ///< Epilogue operation applied to GEMM
|
||||
int LdgAlignC, ///< Alignment (in bytes) for C operand
|
||||
bool AllowRaggedTiles ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
|
||||
>
|
||||
struct block_task
|
||||
{
|
||||
//-------------------------------------------------------------------------
|
||||
// Constants and types
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
enum
|
||||
{
|
||||
/// Number of threads in each thread block (blockDim.x)
|
||||
BlockThreads = block_task_policy_t::BlockThreads,
|
||||
|
||||
/// Extent of thread tile in value_t along M-axis
|
||||
ThreadItemsY = block_task_policy_t::ThreadItemsY,
|
||||
|
||||
/// Extent of thread tile in value_t along N-axis
|
||||
ThreadItemsX = block_task_policy_t::ThreadItemsX,
|
||||
};
|
||||
|
||||
/// Accumulator type
|
||||
typedef thread_accumulator<
|
||||
ThreadItemsY,
|
||||
ThreadItemsX,
|
||||
value_t,
|
||||
accum_t>
|
||||
thread_accumulator_t;
|
||||
|
||||
/// Dot-product vector type along the K-axis (e.g, uchar4 when using IDP4A)
|
||||
typedef typename thread_accumulator_t::dp_vector_t dp_vector_t;
|
||||
|
||||
enum
|
||||
{
|
||||
/// Whether this is a small, latency-bound tile
|
||||
IsSmallTile = (ThreadItemsY < 4) && (ThreadItemsX < 4),
|
||||
|
||||
/// Number of value_t in dp_vector_t
|
||||
DpVectorItems = divide_assert<sizeof(dp_vector_t), sizeof(value_t)>::value,
|
||||
|
||||
/// Extent of block-wide C-tile in accum_t (and A-tiles in value_t) along M-axis (height)
|
||||
BlockItemsY = block_task_policy_t::BlockItemsY,
|
||||
|
||||
/// Extent of block-wide C-tile in accum_t (and B-tiles in value_t) along N-axis (width)
|
||||
BlockItemsX = block_task_policy_t::BlockItemsX,
|
||||
|
||||
/// Extent of block-wide A|B tiles in value_t along the K-axis
|
||||
BlockItemsK = block_task_policy_t::BlockItemsK,
|
||||
|
||||
/// Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
|
||||
UseDoubleScratchTiles = block_task_policy_t::UseDoubleScratchTiles,
|
||||
|
||||
/// Extent of block-wide A|B tiles in dp_vector_t along the K-axis
|
||||
BlockDpVectorsK = divide_assert<BlockItemsK, DpVectorItems>::value,
|
||||
|
||||
/// Number of dp_vector_t along M-axis that can be read in a single LDS from the shared A-tile (up to 128b if more than one value_t)
|
||||
LdsVectorDpVectorsA = __NV_STD_MIN(
|
||||
ThreadItemsY,
|
||||
__NV_STD_MAX(1, (128 / (__NV_STD_MAX(sizeof(dp_vector_t), sizeof(accum_t)) * 8)))),
|
||||
|
||||
/// Number of dp_vector_t along N-axis that can be read in a single LDS from the shared B-tile (up to 128b if more than one value_t)
|
||||
LdsVectorDpVectorsB = __NV_STD_MIN(
|
||||
ThreadItemsX,
|
||||
__NV_STD_MAX(1, (128 / (__NV_STD_MAX(sizeof(dp_vector_t), sizeof(accum_t)) * 8)))),
|
||||
|
||||
/// Number of strip-mined LDS vector reads from shared A-tile
|
||||
ThreadLdsVectorsA = divide_assert<ThreadItemsY, LdsVectorDpVectorsA>::value,
|
||||
|
||||
/// Number of strip-mined LDS vector reads from shared B-tile
|
||||
ThreadLdsVectorsB = divide_assert<ThreadItemsX, LdsVectorDpVectorsB>::value,
|
||||
|
||||
/// Number of elements in one LDG/STG vector of C-tile
|
||||
ThreadLdgVectorSizeC = __NV_STD_MIN(LdgAlignC, 16) / (sizeof(accum_t)),
|
||||
|
||||
/// Number of threads in warp
|
||||
WarpThreads = 32,
|
||||
|
||||
/// Extent of warp in threads along the M-axis
|
||||
WarpThreadsY = (BlockItemsY > BlockItemsX) ? 8 : 4,
|
||||
|
||||
/// Extent of warp in threads along the N-axis
|
||||
WarpThreadsX = divide_assert<WarpThreads, WarpThreadsY>::value,
|
||||
|
||||
/// Extent of warp-wide tile in items along the M-axis
|
||||
WarpItemsY = WarpThreadsY * ThreadItemsY,
|
||||
|
||||
/// Extent of warp-wide tile in items along the N-axis
|
||||
WarpItemsX = WarpThreadsX * ThreadItemsX,
|
||||
|
||||
/// Extent of block in warps along M-axis
|
||||
BlockWarpsY = divide_assert<BlockItemsY, WarpItemsY>::value,
|
||||
|
||||
/// Extent of block in warps along N-axis
|
||||
BlockWarpsX = divide_assert<BlockItemsX, WarpItemsX>::value,
|
||||
};
|
||||
|
||||
/// Load-from-shared data movement type for A-tile, coarsened by LdsVectorDpVectorsA
|
||||
typedef io_vector<dp_vector_t, LdsVectorDpVectorsA> lds_vector_a_t;
|
||||
|
||||
/// Load-from-shared data movement type for B-tile, coarsened by LdsVectorDpVectorsB
|
||||
typedef io_vector<dp_vector_t, LdsVectorDpVectorsB> lds_vector_b_t;
|
||||
|
||||
/// Thread block rasterization helper type
|
||||
typedef grid_raster<
|
||||
BlockItemsY,
|
||||
BlockItemsX,
|
||||
TransformA,
|
||||
TransformB,
|
||||
block_task_policy_t::RasterStrategy>
|
||||
grid_raster_t;
|
||||
|
||||
|
||||
/// Tile loader type for matrix A
|
||||
typedef block_loader<
|
||||
BlockThreads, // BlockThreads
|
||||
BlockDpVectorsK, // BlockDpVectorsK
|
||||
BlockItemsY, // BlockItemsL
|
||||
value_t, // value_t
|
||||
LdgAlignA, // MatrixAlignBytes
|
||||
AllowRaggedTiles, // AllowRaggedTiles
|
||||
dp_vector_t, // dp_vector_t
|
||||
(TransformA == matrix_transform_t::NonTranspose) ? // LoadAlgorithm
|
||||
load_algorithm::CongruousCopy :
|
||||
load_algorithm::CrosswiseCopy>
|
||||
block_loader_a_t;
|
||||
|
||||
|
||||
/// Tile loader type for matrix B
|
||||
typedef block_loader<
|
||||
BlockThreads, // BlockThreads
|
||||
BlockDpVectorsK, // BlockDpVectorsK
|
||||
BlockItemsX, // BlockItemsL
|
||||
value_t, // value_t
|
||||
LdgAlignB, // MatrixAlignBytes
|
||||
AllowRaggedTiles, // AllowRaggedTiles
|
||||
dp_vector_t, // dp_vector_t
|
||||
(TransformB == matrix_transform_t::NonTranspose) ? // LoadAlgorithm
|
||||
load_algorithm::CrosswiseCopy :
|
||||
load_algorithm::CongruousCopy>
|
||||
block_loader_b_t;
|
||||
|
||||
|
||||
enum
|
||||
{
|
||||
/// Number of value_t to pad the end of each row of the shared A-tile
|
||||
PadItemsA = (TransformA == matrix_transform_t::NonTranspose) ?
|
||||
__NV_STD_MAX(LdsVectorDpVectorsA, block_loader_a_t::AlignmentDpVectorsL) :
|
||||
LdsVectorDpVectorsA,
|
||||
|
||||
/// Number of value_t to pad the end of each row of the shared B-tile
|
||||
PadItemsB = (TransformB == matrix_transform_t::NonTranspose) ?
|
||||
LdsVectorDpVectorsB :
|
||||
__NV_STD_MAX(LdsVectorDpVectorsB, block_loader_b_t::AlignmentDpVectorsL),
|
||||
};
|
||||
|
||||
|
||||
/// Shared memory layout for a prefetch page
|
||||
struct page_storage_t
|
||||
{
|
||||
/// Tile of A
|
||||
dp_vector_t __align__(16) block_a[BlockDpVectorsK][BlockItemsY + PadItemsA];
|
||||
|
||||
/// Tile of B
|
||||
dp_vector_t __align__(16) block_b[BlockDpVectorsK][BlockItemsX + PadItemsB];
|
||||
};
|
||||
|
||||
|
||||
/// Shared memory layout for scratch storage
|
||||
struct scratch_storage_t
|
||||
{
|
||||
/// Prefetch pages
|
||||
page_storage_t pages[UseDoubleScratchTiles ? 2 : 1];
|
||||
|
||||
/// Accumulator shared scratch
|
||||
typename thread_accumulator_t::scratch_storage_t accum_scratch;
|
||||
};
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Assert assumptions
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
// Ensure we have at least two unrolled innermost loop iterations (one to prefetch
|
||||
// the next global tile and then one to prefetch the first strip of it from shared)
|
||||
static_assert ((BlockDpVectorsK >= 2), "BlockDpVectorsK must be >= 2.");
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Members
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Scratch storage reference
|
||||
scratch_storage_t *scratch;
|
||||
|
||||
/// Which page of scratch tiles we're currently reading from
|
||||
int page_idx;
|
||||
|
||||
/// Pointer to matrix C
|
||||
accum_t *d_c;
|
||||
|
||||
/// Epilogue operation applied to update matrix C
|
||||
epilogue_op_t epilogue_op;
|
||||
|
||||
/// Matrix height in rows of trans_op(A) and C
|
||||
int dim_m;
|
||||
|
||||
/// Matrix width in columns of trans_op(B) and C
|
||||
int dim_n;
|
||||
|
||||
/// Control for inter-block k-splitting
|
||||
k_split_control k_split;
|
||||
|
||||
/// Thread block's base value_t coordinates (m, n) in matrix C
|
||||
grid_raster_t grid_raster;
|
||||
|
||||
/// Thread block's current coordinate (k) within A|B matrices
|
||||
int block_item_coords_k;
|
||||
|
||||
/// Thread block's ending coordinate (k) within A|B matrices (one-past)
|
||||
int block_end_item_k;
|
||||
|
||||
/// Warp's coordinates (x, y) in thread block
|
||||
int2 block_warp_coords;
|
||||
|
||||
/// Thread's coordinates (x, y) in warp
|
||||
int2 warp_thread_coords;
|
||||
|
||||
/// Thread's base item offset within strip of A tile
|
||||
int thread_strip_offset_a;
|
||||
|
||||
/// Thread's base item offset within strip of B tile
|
||||
int thread_strip_offset_b;
|
||||
|
||||
/// Thread's active-k/prefetch-k slices from shared A tile
|
||||
lds_vector_a_t local_slices_a[2][ThreadLdsVectorsA];
|
||||
|
||||
/// Thread's active-k/prefetch-k slices from shared B tile
|
||||
lds_vector_b_t local_slices_b[2][ThreadLdsVectorsB];
|
||||
|
||||
/// A tile loader
|
||||
block_loader_a_t loader_a;
|
||||
|
||||
/// B tile loader
|
||||
block_loader_b_t loader_b;
|
||||
|
||||
/// C tile accumulator
|
||||
thread_accumulator_t accumulator;
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Coordinate system helpers
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Compute the warp's coordinates (x, y) in thread block
|
||||
inline __device__
|
||||
int2 warp_coords()
|
||||
{
|
||||
int warp_id = threadIdx.x / WarpThreads;
|
||||
return make_int2(
|
||||
warp_id % BlockWarpsX,
|
||||
warp_id / BlockWarpsX);
|
||||
}
|
||||
|
||||
|
||||
/// Compute the thread's lane-coordinates (x, y) in warp
|
||||
inline __device__
|
||||
int2 thread_coords()
|
||||
{
|
||||
int lane_id = threadIdx.x % WarpThreads;
|
||||
|
||||
// Maxwell+ mapping of threads within a 2D warp for maximal LDS bandwidth
|
||||
return make_int2(
|
||||
lane_id / WarpThreadsY,
|
||||
lane_id % WarpThreadsY);
|
||||
}
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Constructor API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Constructor
|
||||
inline __device__
|
||||
block_task(
|
||||
scratch_storage_t *scratch,
|
||||
value_t *d_a,
|
||||
value_t *d_b,
|
||||
accum_t *d_c,
|
||||
epilogue_op_t epilogue_op,
|
||||
int dim_m,
|
||||
int dim_n,
|
||||
int dim_k,
|
||||
k_split_control k_split)
|
||||
:
|
||||
scratch(scratch),
|
||||
page_idx(0),
|
||||
d_c(d_c),
|
||||
epilogue_op(epilogue_op),
|
||||
dim_m(dim_m),
|
||||
dim_n(dim_n),
|
||||
k_split(k_split),
|
||||
block_item_coords_k(k_split.block_begin_item_k()),
|
||||
block_end_item_k(k_split.block_end_item_k(dim_k)),
|
||||
block_warp_coords(warp_coords()),
|
||||
warp_thread_coords(thread_coords()),
|
||||
thread_strip_offset_a((warp_thread_coords.y * LdsVectorDpVectorsA) + (block_warp_coords.y * WarpItemsY)),
|
||||
thread_strip_offset_b((warp_thread_coords.x * LdsVectorDpVectorsB) + (block_warp_coords.x * WarpItemsX)),
|
||||
|
||||
loader_a(
|
||||
d_a, // d_matrix
|
||||
dim_m, // matrix_values_l
|
||||
(TransformA == matrix_transform_t::NonTranspose) ? dim_m : 1, // matrix_values_stride_k
|
||||
(TransformA == matrix_transform_t::NonTranspose) ? 1 : dim_k, // matrix_values_stride_l
|
||||
make_int2( // block_begin_item_coords
|
||||
grid_raster.block_item_coords.y,
|
||||
block_item_coords_k),
|
||||
block_end_item_k), // block_end_item_k
|
||||
|
||||
loader_b(
|
||||
d_b, // d_matrix
|
||||
dim_n, // matrix_values_l
|
||||
(TransformB == matrix_transform_t::NonTranspose) ? 1 : dim_n, // matrix_values_stride_k
|
||||
(TransformB == matrix_transform_t::NonTranspose) ? dim_k : 1, // matrix_values_stride_l
|
||||
make_int2( // block_begin_item_coords
|
||||
grid_raster.block_item_coords.x,
|
||||
block_item_coords_k),
|
||||
block_end_item_k), // block_end_item_k
|
||||
|
||||
accumulator(scratch->accum_scratch)
|
||||
{}
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Prefetching utility methods
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Request the calling thread's slices of the shared tiles at depth \p tile_offset_k
|
||||
*/
|
||||
inline __device__ void request_local_prefetch(
|
||||
lds_vector_a_t (&slice_a)[ThreadLdsVectorsA], ///< Slice from A
|
||||
lds_vector_b_t (&slice_b)[ThreadLdsVectorsB], ///< Slice from B
|
||||
int tile_offset_k)
|
||||
{
|
||||
// Load B strip
|
||||
for (int i = 0; i < ThreadLdsVectorsB; ++i)
|
||||
{
|
||||
slice_b[i].load(
|
||||
&scratch->pages[page_idx].block_b[tile_offset_k][thread_strip_offset_b + (i * WarpThreadsX * LdsVectorDpVectorsB)]);
|
||||
}
|
||||
|
||||
// Load A strip
|
||||
for (int i = 0; i < ThreadLdsVectorsA; ++i)
|
||||
{
|
||||
slice_a[i].load(
|
||||
&scratch->pages[page_idx].block_a[tile_offset_k][thread_strip_offset_a + (i * WarpThreadsY * LdsVectorDpVectorsA)]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Epilogue
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Performs the GEMM epilogue:
|
||||
* - Applies the scalar multipliers and addends to the accumulators
|
||||
* - Write the result to the output matrix
|
||||
*/
|
||||
__forceinline__ __device__
|
||||
void epilogue()
|
||||
{
|
||||
// Wait for predecessor thread block(s) to produce block-wide tile of
|
||||
// exclsuive partial-sums
|
||||
k_split.wait();
|
||||
|
||||
// Configure epilogue as to whether the thread block is a secondary
|
||||
// accumulator in an inter-block k-splitting scheme
|
||||
if (k_split.is_secondary_accumulator())
|
||||
epilogue_op.set_secondary_accumulator();
|
||||
|
||||
// Whether the addend from C needs loading
|
||||
bool must_init_addend = epilogue_op.must_init_addend();
|
||||
|
||||
#pragma unroll
|
||||
for (int x = 0; x < ThreadItemsX; ++x)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int y = 0; y < ThreadItemsY; y += LdsVectorDpVectorsA)
|
||||
{
|
||||
int thread_strip_b = x / LdsVectorDpVectorsB;
|
||||
int thread_strip_a = y / LdsVectorDpVectorsA;
|
||||
|
||||
int thread_item_coords_tile_x = thread_strip_offset_b + (thread_strip_b * WarpThreadsX * LdsVectorDpVectorsB) + (x % LdsVectorDpVectorsB);
|
||||
int thread_item_coords_tile_y = thread_strip_offset_a + (thread_strip_a * WarpThreadsY * LdsVectorDpVectorsA) + (y % LdsVectorDpVectorsA);
|
||||
|
||||
int c_idx = (grid_raster.block_item_coords.x + thread_item_coords_tile_x) * dim_m +
|
||||
grid_raster.block_item_coords.y + thread_item_coords_tile_y;
|
||||
|
||||
accum_t *my_c = d_c + c_idx;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < LdsVectorDpVectorsA; ++i)
|
||||
{
|
||||
accum_t c_slice = accum_t(0);
|
||||
accum_t *c_ptr = my_c + i;
|
||||
|
||||
if ((grid_raster.block_item_coords.x + thread_item_coords_tile_x) < dim_n &&
|
||||
(grid_raster.block_item_coords.y + thread_item_coords_tile_y + i) < dim_m)
|
||||
{
|
||||
if (must_init_addend)
|
||||
{
|
||||
ldg_cg(c_slice, c_ptr);
|
||||
}
|
||||
|
||||
c_slice = epilogue_op(accumulator.get(x, y + i), c_slice, c_idx + i);
|
||||
|
||||
stg_cg(c_ptr, c_slice);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Signal k-split successor thread_block that we have produced our block-wide
|
||||
// tile of inclusive partial-sums
|
||||
k_split.signal();
|
||||
}
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Tile consumption
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Consume a tile of A and B each
|
||||
*/
|
||||
template <bool DoGlobalPrefetch>
|
||||
__forceinline__ __device__
|
||||
void consume_tile()
|
||||
{
|
||||
// Unroll BlockDpVectorsK iterations of outer-product accumulations
|
||||
#pragma unroll
|
||||
for (int tile_offset_k = 0; tile_offset_k < BlockDpVectorsK; tile_offset_k += 1)
|
||||
{
|
||||
// Last strip commits global prefetch for next tile
|
||||
if ((tile_offset_k == BlockDpVectorsK - 1) && DoGlobalPrefetch)
|
||||
{
|
||||
// If not using two pages of scratch tiles, protect the above prefetch loads from the committing writes below
|
||||
if (!UseDoubleScratchTiles)
|
||||
__syncthreads();
|
||||
|
||||
// If using two pages of scratch tiles, switch to next page before writing
|
||||
if (UseDoubleScratchTiles)
|
||||
{
|
||||
page_idx = (page_idx ? 0 : 1);
|
||||
}
|
||||
|
||||
// Commit global prefetch data to scratch page
|
||||
loader_a.commit(scratch->pages[page_idx].block_a);
|
||||
loader_b.commit(scratch->pages[page_idx].block_b);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Request local prefetch for next strip
|
||||
request_local_prefetch(
|
||||
local_slices_a[(tile_offset_k + 1) % 2],
|
||||
local_slices_b[(tile_offset_k + 1) % 2],
|
||||
(tile_offset_k + 1) % BlockDpVectorsK);
|
||||
|
||||
// Request global prefetch for next tile on first strip
|
||||
if ((tile_offset_k == 0) && DoGlobalPrefetch)
|
||||
{
|
||||
loader_b.request();
|
||||
loader_b.next();
|
||||
loader_a.request();
|
||||
loader_a.next();
|
||||
}
|
||||
|
||||
// Cast strip-mined loads to contiguous array of dp_vector_t
|
||||
typedef dp_vector_t thread_tile_a_t[ThreadLdsVectorsA * LdsVectorDpVectorsA];
|
||||
typedef dp_vector_t thread_tile_b_t[ThreadLdsVectorsB * LdsVectorDpVectorsB];
|
||||
thread_tile_a_t &thread_tile_a = reinterpret_cast<thread_tile_a_t&>(local_slices_a[(tile_offset_k) % 2]);
|
||||
thread_tile_b_t &thread_tile_b = reinterpret_cast<thread_tile_b_t&>(local_slices_b[(tile_offset_k) % 2]);
|
||||
|
||||
// Accumulate this dp-stripe product
|
||||
accumulator.multiply_accumulate(thread_tile_a, thread_tile_b);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// GEMM API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Compute GEMM
|
||||
*/
|
||||
__forceinline__ __device__
|
||||
void run()
|
||||
{
|
||||
// Quit if the thread block is fully out-of-bounds
|
||||
if (grid_raster.is_block_oob(dim_m, dim_n))
|
||||
{
|
||||
asm volatile("exit;");
|
||||
}
|
||||
|
||||
// Request global prefetch of first tile
|
||||
loader_a.request();
|
||||
loader_a.next();
|
||||
loader_b.request();
|
||||
loader_b.next();
|
||||
|
||||
// Commit global prefetch of first tile to shared memory
|
||||
loader_a.commit(scratch->pages[page_idx].block_a);
|
||||
loader_b.commit(scratch->pages[page_idx].block_b);
|
||||
|
||||
// Advance to next A,B tiles in K-axis
|
||||
block_item_coords_k += BlockItemsK;
|
||||
|
||||
// Synchronize shared tiles and prepared accumulator
|
||||
__syncthreads();
|
||||
|
||||
// Initialize thread's slice of accumulators
|
||||
accumulator.init();
|
||||
|
||||
// Request first iteration of local prefetch strips
|
||||
request_local_prefetch(
|
||||
local_slices_a[0],
|
||||
local_slices_b[0],
|
||||
0);
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
// Consume tiles in A and B along the K-axis (all but last tile)
|
||||
#pragma unroll 1
|
||||
while (block_item_coords_k < block_end_item_k)
|
||||
{
|
||||
consume_tile<true>();
|
||||
|
||||
// Advance to next A,B tiles in K-axis
|
||||
block_item_coords_k += BlockItemsK;
|
||||
}
|
||||
|
||||
// Consume last tile
|
||||
consume_tile<false>();
|
||||
|
||||
//
|
||||
// Eplilogue
|
||||
//
|
||||
|
||||
epilogue();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,767 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \file
|
||||
* A block-wide task abstraction for computing device-wide GEMM
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// Compiler guard conditional to avoid compilation errors on versions of CUDA that
|
||||
// do not support the WMMA API.
|
||||
#if defined (WMMA)
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "../util/util.h"
|
||||
|
||||
#include "grid_raster.h"
|
||||
#include "block_loader.h"
|
||||
#include "block_loader_wmma.h"
|
||||
#include "wmma_accumulator.h"
|
||||
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* block_task_wmma_policy
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief Parameterizable tuning policy for block-wide WMMA GEMM tasks
|
||||
*
|
||||
* Once parameterized, \p block_task_policy provides the member constant
|
||||
* \p BlockThreads indicating to the required thread block size
|
||||
*/
|
||||
template <
|
||||
int _BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
|
||||
int _BlockItemsX, ///< Width in columns of a block-wide tile in matrix C
|
||||
int _BlockItemsK, ///< Extent of block-wide A|B tiles in value_t along the K-axis
|
||||
int _WarpItemsY, ///< Height in rows of a Warp tile's accumulators
|
||||
int _WarpItemsX, ///< Width in columns of a Warp tile's accumulators
|
||||
int _WmmaItemsY, ///< Height in rows of a discrete WMMA block's accumulators
|
||||
int _WmmaItemsX, ///< Width in columns of a discrete WMMA block's accumulators
|
||||
int _WmmaItemsK, ///< Depth of each discrete WMMA block
|
||||
bool _UseDoubleScratchTiles, ///< Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
|
||||
grid_raster_strategy::kind_t _RasterStrategy> ///< Strategy for enumerating \p block_task within an input matrix
|
||||
struct block_task_wmma_policy
|
||||
{
|
||||
/// Strategy for enumerating \p block_task within an input matrix
|
||||
static const grid_raster_strategy::kind_t RasterStrategy = _RasterStrategy;
|
||||
|
||||
enum
|
||||
{
|
||||
/// Height in rows of a block-wide tile in matrix C
|
||||
BlockItemsY = _BlockItemsY,
|
||||
|
||||
/// Width in columns of a block-wide tile in matrix C
|
||||
BlockItemsX = _BlockItemsX,
|
||||
|
||||
/// Extent of block-wide A|B tiles in value_t along the K-axis
|
||||
BlockItemsK = _BlockItemsK,
|
||||
|
||||
/// Height in rows of a Warp tile's accumulators
|
||||
WarpItemsX = _WarpItemsX,
|
||||
|
||||
/// Width in columns of a Warp tile's accumulators
|
||||
WarpItemsY = _WarpItemsY,
|
||||
|
||||
/// Width in columns of a discrete WMMA block's accumulators
|
||||
WmmaItemsX = _WmmaItemsX,
|
||||
|
||||
/// Height in rows of a discrete WMMA block's accumulators
|
||||
WmmaItemsY = _WmmaItemsY,
|
||||
|
||||
/// Depth of each discrete WMMA block
|
||||
WmmaItemsK = _WmmaItemsK,
|
||||
|
||||
/// Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
|
||||
UseDoubleScratchTiles = _UseDoubleScratchTiles,
|
||||
|
||||
|
||||
//
|
||||
// Derived quantities
|
||||
//
|
||||
|
||||
/// Machine warp size
|
||||
WarpThreads = 32,
|
||||
|
||||
/// Number of WMMA operations in the height dimension
|
||||
WmmaBlocksY = divide_assert<WarpItemsY, WmmaItemsY>::value,
|
||||
|
||||
/// Number of WMMA operations in the height dimension
|
||||
WmmaBlocksX = divide_assert<WarpItemsX, WmmaItemsX>::value,
|
||||
|
||||
/// Number of warps in each thread block
|
||||
BlockWarps = divide_assert<BlockItemsY * BlockItemsX, WarpItemsX * WarpItemsY>::value,
|
||||
|
||||
/// Number of threads in each thread block (blockDim.x)
|
||||
BlockThreads = BlockWarps * WarpThreads,
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* block_task_wmma
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief A block-wide task abstraction for computing device-wide GEMM
|
||||
*
|
||||
* Each thread_block is assigned a unique tile of output matrix C to compute by
|
||||
* consuming the corresponding stripes of the input matrices A and B.
|
||||
*/
|
||||
template <
|
||||
typename block_task_policy_t, ///< Parameterization of block_task_policy
|
||||
typename value_t, ///< Multiplicand value type (matrices A and B)
|
||||
typename accum_t, ///< Accumulator value type (matrix C and scalars)
|
||||
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
|
||||
int LdgAlignA, ///< Alignment (in bytes) for A operand
|
||||
matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B
|
||||
int LdgAlignB, ///< Alignment (in bytes) for B operand
|
||||
typename epilogue_op_t, ///< Epilogue operation to update matrix C
|
||||
int LdgAlignC, ///< Alignment (in bytes) for C operand
|
||||
bool AllowRaggedTiles ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
|
||||
>
|
||||
struct block_task_wmma
|
||||
{
|
||||
//-------------------------------------------------------------------------
|
||||
// Constants and types
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
enum
|
||||
{
|
||||
/// Number of threads in each thread block (blockDim.x)
|
||||
BlockThreads = block_task_policy_t::BlockThreads,
|
||||
|
||||
/// Extent of block-wide C-tile in accum_t (and A-tiles in value_t) along M-axis (height)
|
||||
BlockItemsY = block_task_policy_t::BlockItemsY,
|
||||
|
||||
/// Extent of block-wide C-tile in accum_t (and B-tiles in value_t) along N-axis (width)
|
||||
BlockItemsX = block_task_policy_t::BlockItemsX,
|
||||
|
||||
/// Extent of block-wide A|B tiles in value_t along the K-axis
|
||||
BlockItemsK = block_task_policy_t::BlockItemsK,
|
||||
|
||||
/// Extent of warp C-tile in accum_t (and A-tiles in value_t) along M-axis (height)
|
||||
WarpItemsY = block_task_policy_t::WarpItemsY,
|
||||
|
||||
/// Extent of warp C-tile in accum_t (and B-tiles in value_t) along N-axis (width)
|
||||
WarpItemsX = block_task_policy_t::WarpItemsX,
|
||||
|
||||
/// Extent of warp C-tile in accum_t (and A-tiles in value_t) along M-axis (height)
|
||||
WmmaItemsY = block_task_policy_t::WmmaItemsY,
|
||||
|
||||
/// Extent of warp C-tile in accum_t (and B-tiles in value_t) along N-axis (width)
|
||||
WmmaItemsX = block_task_policy_t::WmmaItemsX,
|
||||
|
||||
/// Extent of warp-wide A|B-tiles in value_t along K-axis
|
||||
WmmaItemsK = block_task_policy_t::WmmaItemsK,
|
||||
|
||||
/// Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
|
||||
UseDoubleScratchTiles = block_task_policy_t::UseDoubleScratchTiles,
|
||||
|
||||
/// Number of threads in warp
|
||||
WarpThreads = block_task_policy_t::WarpThreads,
|
||||
|
||||
/// Number of warps participating
|
||||
BlockWarps = block_task_policy_t::BlockWarps,
|
||||
|
||||
/// Extent of block in warps along M-axis
|
||||
BlockWarpsY = divide_assert<BlockItemsY, WarpItemsY>::value,
|
||||
|
||||
/// Extent of block in warps along N-axis
|
||||
BlockWarpsX = divide_assert<BlockItemsX, WarpItemsX>::value,
|
||||
|
||||
/// Number of MMA unrolls
|
||||
WmmaUnrollCount = divide_assert<BlockItemsK, WmmaItemsK>::value,
|
||||
|
||||
/// True if the A matrix layout is column major (K is the strided dimension)
|
||||
IsLayoutCongruousA = (TransformA == matrix_transform_t::NonTranspose),
|
||||
|
||||
/// True if the B matrix layout is row mayor (K is the strided dimension)
|
||||
IsLayoutCongruousB = (TransformB == matrix_transform_t::Transpose),
|
||||
|
||||
};
|
||||
|
||||
/// WMMA may support unique types for A and B, so plan ahead for this
|
||||
typedef value_t value_a_t;
|
||||
|
||||
/// WMMA may support unique types for A and B, so plan ahead for this
|
||||
typedef value_t value_b_t;
|
||||
|
||||
/// WMMA accumulator type
|
||||
typedef wmma_accumulator<
|
||||
WarpItemsY,
|
||||
WarpItemsX,
|
||||
WmmaItemsY,
|
||||
WmmaItemsX,
|
||||
WmmaItemsK,
|
||||
value_a_t,
|
||||
value_b_t,
|
||||
accum_t,
|
||||
TransformA,
|
||||
TransformB>
|
||||
accumulator_t;
|
||||
|
||||
/// Thread block rasterization helper type
|
||||
typedef grid_raster<
|
||||
BlockItemsY,
|
||||
BlockItemsX,
|
||||
TransformA,
|
||||
TransformB,
|
||||
block_task_policy_t::RasterStrategy>
|
||||
grid_raster_t;
|
||||
|
||||
/// Tile loader type for matrix A
|
||||
typedef block_loader_wmma<
|
||||
IsLayoutCongruousA,
|
||||
BlockThreads,
|
||||
(IsLayoutCongruousA ? BlockItemsY : BlockItemsK),
|
||||
(IsLayoutCongruousA ? BlockItemsK : BlockItemsY),
|
||||
value_a_t,
|
||||
LdgAlignA,
|
||||
AllowRaggedTiles>
|
||||
block_loader_a_t;
|
||||
|
||||
/// Tile loader type for matrix A
|
||||
typedef block_loader_wmma<
|
||||
IsLayoutCongruousB,
|
||||
BlockThreads,
|
||||
(IsLayoutCongruousB ? BlockItemsX : BlockItemsK),
|
||||
(IsLayoutCongruousB ? BlockItemsK : BlockItemsX),
|
||||
value_b_t,
|
||||
LdgAlignB,
|
||||
AllowRaggedTiles>
|
||||
block_loader_b_t;
|
||||
|
||||
/// Type alias for matrix A fragment type
|
||||
typedef typename accumulator_t::fragment_a_t fragment_a_t;
|
||||
|
||||
/// Type alias for matrix B fragment type
|
||||
typedef typename accumulator_t::fragment_b_t fragment_b_t;
|
||||
|
||||
enum
|
||||
{
|
||||
/// Number of fragments from A matrix
|
||||
WmmaBlocksY = accumulator_t::WmmaBlocksY,
|
||||
|
||||
/// Number of fragments from B matrix
|
||||
WmmaBlocksX = accumulator_t::WmmaBlocksX,
|
||||
|
||||
/// Number of value_t to pad the outer dimension of the shared A-tile
|
||||
PadItemsA = 16,
|
||||
|
||||
/// Number of value_t to pad the outer dimension of the shared B-tile
|
||||
PadItemsB = 16,
|
||||
|
||||
/// Leading dimension of A matrix tile
|
||||
LdmSmemA = (IsLayoutCongruousA ? BlockItemsY: BlockItemsK) + PadItemsA,
|
||||
|
||||
/// Leading dimension of A matrix tile
|
||||
StridedSmemA = (IsLayoutCongruousA ? BlockItemsK : BlockItemsY ),
|
||||
|
||||
/// Leading dimension of B matrix tile
|
||||
LdmSmemB = (IsLayoutCongruousB? BlockItemsX : BlockItemsK) + PadItemsB,
|
||||
|
||||
StridedSmemB = (IsLayoutCongruousB ? BlockItemsK : BlockItemsX),
|
||||
};
|
||||
|
||||
/// Shared memory layout for a prefetch page
|
||||
struct page_storage_t
|
||||
{
|
||||
/// Tile of A
|
||||
value_a_t __align__(16) block_a[StridedSmemA][LdmSmemA];
|
||||
|
||||
/// Tile of B
|
||||
value_b_t __align__(16) block_b[StridedSmemB][LdmSmemB];
|
||||
};
|
||||
|
||||
/// Shared memory layout for scratch storage
|
||||
struct scratch_storage_t
|
||||
{
|
||||
union
|
||||
{
|
||||
/// Prefetch pages
|
||||
uninitialized<page_storage_t> pages[UseDoubleScratchTiles ? 2 : 1];
|
||||
|
||||
/// Scratch storage for warps
|
||||
accum_t epilogue[BlockWarps][WmmaItemsX * WmmaItemsY];
|
||||
};
|
||||
};
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Assert assumptions
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
// Ensure we have at least two unrolled innermost loop iterations (one to prefetch
|
||||
// the next global tile and then one to prefetch the first strip of it from shared)
|
||||
static_assert ((BlockItemsK >= 2), "BlockItemsK must be >= 2.");
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Members
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Scratch storage reference
|
||||
scratch_storage_t *scratch;
|
||||
|
||||
/// Which page of scratch tiles we're currently reading from
|
||||
int page_idx;
|
||||
|
||||
/// Pointer to matrix C
|
||||
accum_t *d_c;
|
||||
|
||||
/// Epilogue operation applied to update matrix C
|
||||
epilogue_op_t epilogue_op;
|
||||
|
||||
/// Matrix height in rows of trans_op(A) and C
|
||||
int dim_m;
|
||||
|
||||
/// Matrix width in columns of trans_op(B) and C
|
||||
int dim_n;
|
||||
|
||||
/// Control for inter-block k-splitting
|
||||
k_split_control k_split;
|
||||
|
||||
/// Thread block's base value_t coordinates (m, n) in matrix C
|
||||
grid_raster_t grid_raster;
|
||||
|
||||
/// Thread block's current coordinate (k) within A|B matrices
|
||||
int block_item_coords_k;
|
||||
|
||||
/// Thread block's ending coordinate (k) within A|B matrices (one-past)
|
||||
int block_end_item_k;
|
||||
|
||||
/// Warp's coordinates (x, y) in thread block
|
||||
int2 block_warp_item_coords;
|
||||
|
||||
/// A tile loader
|
||||
block_loader_a_t loader_a;
|
||||
|
||||
/// B tile loader
|
||||
block_loader_b_t loader_b;
|
||||
|
||||
/// Thread's active-k/prefetch-k slices from shared A tile
|
||||
fragment_a_t local_slices_a[2][WmmaBlocksY];
|
||||
|
||||
/// Thread's active-k/prefetch-k slices from shared B tile
|
||||
fragment_b_t local_slices_b[2][WmmaBlocksX];
|
||||
|
||||
/// Accumulator tile
|
||||
accumulator_t accumulator;
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Coordinate system helpers
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Compute the warp's item-coordinates (x, y) in thread block
|
||||
inline __device__
|
||||
int2 warp_item_coords()
|
||||
{
|
||||
int warp_id = threadIdx.x / WarpThreads;
|
||||
|
||||
return make_int2(
|
||||
(warp_id / BlockWarpsY) * WarpItemsX,
|
||||
(warp_id % BlockWarpsY) * WarpItemsY);
|
||||
}
|
||||
|
||||
/// Compute the thread block's base item-coordinates in matrix A
|
||||
inline __device__
|
||||
int2 a_block_item_coords()
|
||||
{
|
||||
if (TransformA == matrix_transform_t::NonTranspose)
|
||||
{
|
||||
return make_int2(grid_raster.block_item_coords.y, block_item_coords_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_int2(block_item_coords_k, grid_raster.block_item_coords.y);
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the thread block's base item-coordinates in matrix B
|
||||
inline __device__
|
||||
int2 b_block_item_coords()
|
||||
{
|
||||
if (TransformB == matrix_transform_t::Transpose)
|
||||
{
|
||||
return make_int2(grid_raster.block_item_coords.x, block_item_coords_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_int2(block_item_coords_k, grid_raster.block_item_coords.x);
|
||||
}
|
||||
}
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Constructor API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Constructor
|
||||
inline __device__
|
||||
block_task_wmma(
|
||||
scratch_storage_t *scratch,
|
||||
value_t *d_a,
|
||||
value_t *d_b,
|
||||
accum_t *d_c,
|
||||
epilogue_op_t epilogue_op,
|
||||
int dim_m,
|
||||
int dim_n,
|
||||
int dim_k,
|
||||
k_split_control k_split)
|
||||
:
|
||||
scratch(scratch),
|
||||
page_idx(0),
|
||||
d_c(d_c),
|
||||
epilogue_op(epilogue_op),
|
||||
dim_m(dim_m),
|
||||
dim_n(dim_n),
|
||||
k_split(k_split),
|
||||
block_item_coords_k(k_split.block_begin_item_k()),
|
||||
block_end_item_k(k_split.block_end_item_k(dim_k)),
|
||||
block_warp_item_coords(warp_item_coords()),
|
||||
|
||||
loader_a(
|
||||
reinterpret_cast<value_a_t const *>(d_a),
|
||||
(IsLayoutCongruousA ? dim_m : block_end_item_k),
|
||||
(IsLayoutCongruousA ? 0 : block_item_coords_k),
|
||||
(IsLayoutCongruousA ? block_end_item_k : dim_m),
|
||||
(IsLayoutCongruousA ? dim_m : dim_k),
|
||||
(IsLayoutCongruousA ? block_item_coords_k : 0),
|
||||
a_block_item_coords()),
|
||||
|
||||
loader_b(
|
||||
reinterpret_cast<value_b_t const *>(d_b),
|
||||
(IsLayoutCongruousB ? dim_n : block_end_item_k),
|
||||
(IsLayoutCongruousB ? 0 : block_item_coords_k),
|
||||
(IsLayoutCongruousB ? block_end_item_k : dim_n),
|
||||
(IsLayoutCongruousB ? dim_n : dim_k),
|
||||
(IsLayoutCongruousB ? block_item_coords_k : 0),
|
||||
b_block_item_coords())
|
||||
{}
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Prefetching utility methods
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Request the calling thread's slices of the shared tiles at depth \p tile_offset_k
|
||||
*/
|
||||
inline __device__ void request_local_prefetch(
|
||||
fragment_a_t local_slices_a[WmmaBlocksY], ///< Slice from A
|
||||
fragment_b_t local_slices_b[WmmaBlocksX], ///< Slice from B
|
||||
int tile_offset_k)
|
||||
{
|
||||
value_b_t const *smem_A_base = &scratch->pages[page_idx].alias().block_a[0][0];
|
||||
value_b_t const *smem_B_base = &scratch->pages[page_idx].alias().block_b[0][0];
|
||||
|
||||
int constexpr kstride_a = (IsLayoutCongruousA ? LdmSmemA : 1);
|
||||
int constexpr lstride_a = (IsLayoutCongruousA ? 1 : LdmSmemA);
|
||||
|
||||
int constexpr kstride_b = (IsLayoutCongruousB ? LdmSmemB : 1);
|
||||
int constexpr lstride_b = (IsLayoutCongruousB ? 1 : LdmSmemB);
|
||||
|
||||
// Load B strip
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WmmaBlocksX; ++i)
|
||||
{
|
||||
value_b_t const *smem_B_ptr =
|
||||
&smem_B_base[tile_offset_k * kstride_b + (block_warp_item_coords.x + WmmaItemsX * i) * lstride_b];
|
||||
|
||||
nvcuda::wmma::load_matrix_sync(local_slices_b[i], smem_B_ptr, LdmSmemB);
|
||||
}
|
||||
|
||||
// Load A strip
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WmmaBlocksY; ++i)
|
||||
{
|
||||
value_a_t const *smem_A_ptr =
|
||||
&smem_A_base[tile_offset_k * kstride_a + (block_warp_item_coords.y + WmmaItemsY * i) * lstride_a];
|
||||
|
||||
nvcuda::wmma::load_matrix_sync(local_slices_a[i], smem_A_ptr, LdmSmemA);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Epilogue
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Performs the GEMM epilogue:
|
||||
* - Applies the scalar multipliers and addends to the accumulators
|
||||
* - Write the result to the output matrix
|
||||
*/
|
||||
inline __device__ void epilogue()
|
||||
{
|
||||
// Wait for predecessor thread block(s) to produce partial-sums
|
||||
k_split.wait();
|
||||
|
||||
// Configure epilogue as to whether the thread block is a secondary
|
||||
// accumulator in an inter-block k-splitting scheme
|
||||
if (k_split.is_secondary_accumulator())
|
||||
epilogue_op.set_secondary_accumulator();
|
||||
|
||||
// Whether or not the addend from C needs loading
|
||||
bool must_init_addend = epilogue_op.must_init_addend();
|
||||
|
||||
int warp_base_x = grid_raster.block_item_coords.x + block_warp_item_coords.x;
|
||||
int warp_base_y = grid_raster.block_item_coords.y + block_warp_item_coords.y;
|
||||
|
||||
int constexpr SmemStride = WmmaItemsY;
|
||||
|
||||
int warp_id = threadIdx.x / 32;
|
||||
|
||||
// Compute shape of one accumulator read/modify/write operation
|
||||
int constexpr ItemsY = (WmmaItemsY);
|
||||
int constexpr ItemsX = (32 / ItemsY);
|
||||
int constexpr IterationsX = WmmaItemsX / ItemsX;
|
||||
|
||||
// Compute a rasterization of warp lanes across the WMMA tile.
|
||||
int lane_id = (threadIdx.x % 32);
|
||||
int lane_read_x = (lane_id / ItemsY);
|
||||
int lane_read_y = (lane_id % ItemsY);
|
||||
|
||||
accum_t *smem_scratch = scratch->epilogue[warp_id];
|
||||
accum_t const *smem_read_ptr = smem_scratch + lane_read_y + lane_read_x * SmemStride;
|
||||
|
||||
#pragma unroll
|
||||
for (int xb = 0; xb < WmmaBlocksX; ++xb)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int yb = 0; yb < WmmaBlocksY; ++yb)
|
||||
{
|
||||
// Store accumulator tile to SMEM
|
||||
nvcuda::wmma::store_matrix_sync(
|
||||
smem_scratch,
|
||||
accumulator.accumulators[xb][yb],
|
||||
SmemStride,
|
||||
matrix_layout<matrix_transform_t::NonTranspose>::kind);
|
||||
|
||||
// Synchronize threads within the warp
|
||||
__syncthreads();
|
||||
|
||||
// Compute lane coordinates so that each thread efficiently accesses SMEM.
|
||||
int c_x = (warp_base_x + (xb) * WmmaItemsX + lane_read_x);
|
||||
int c_y = (warp_base_y + (yb) * WmmaItemsY + lane_read_y);
|
||||
|
||||
// Compute guard predicate by comparing against problem dimensions.
|
||||
bool pred = c_y < dim_m;
|
||||
|
||||
// Compute output pointer from lane coordinates
|
||||
int c_index = c_x * dim_m + c_y;
|
||||
accum_t *c_ptr = reinterpret_cast<accum_t *>(d_c) + c_x * dim_m + c_y;
|
||||
|
||||
// Iterate over columns of output tile. Load from SMEM, compute epilogue operation,
|
||||
// and stream output to global memory
|
||||
#pragma unroll
|
||||
for (int item_x = 0; item_x < IterationsX; ++item_x)
|
||||
{
|
||||
accum_t accum = smem_read_ptr[item_x * ItemsX * SmemStride];
|
||||
accum_t c_element = 0;
|
||||
|
||||
// Filter against problem dimensions as the warp iterates across the columns of
|
||||
// output.
|
||||
pred = (pred && ((c_x + item_x * ItemsX) < dim_n));
|
||||
|
||||
if (must_init_addend && pred)
|
||||
{
|
||||
// NB: inline PTX to utilize strong operations for inter-block synchronization.
|
||||
// The following is equivalent to:
|
||||
//
|
||||
// c_element = c_ptr[0];
|
||||
asm volatile ("ld.global.cg.f32 %0, [%1];\n" : "=f"(c_element) : "l"(c_ptr));
|
||||
}
|
||||
|
||||
c_element = epilogue_op(accum, c_element, c_index);
|
||||
|
||||
if (pred)
|
||||
{
|
||||
// NB: inline PTX to utilize strong operations for inter-block synchronization.
|
||||
// The following is equivalent to:
|
||||
//
|
||||
// c_ptr[0] = c_element;
|
||||
|
||||
asm volatile ("st.global.cg.f32 [%0], %1;\n" : : "l"(c_ptr), "f"(c_element));
|
||||
}
|
||||
|
||||
// Increment output pointer
|
||||
c_ptr += dim_m * ItemsX;
|
||||
c_index += dim_m * ItemsX;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// Signal k-split successor thread_block
|
||||
k_split.signal();
|
||||
}
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Tile consumption
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Consume a tile of A and B each
|
||||
*/
|
||||
template <bool DoGlobalPrefetch>
|
||||
inline __device__
|
||||
void consume_tile()
|
||||
{
|
||||
// Request global prefetch for next tile on first strip
|
||||
if (DoGlobalPrefetch)
|
||||
{
|
||||
loader_b.request();
|
||||
loader_b.next();
|
||||
loader_a.request();
|
||||
loader_a.next();
|
||||
}
|
||||
|
||||
// Unroll BlockDpVectorsK iterations of outer-product accumulations
|
||||
#pragma unroll
|
||||
for (int iteration = 0; iteration < WmmaUnrollCount; ++iteration)
|
||||
{
|
||||
int tile_offset_k = iteration * WmmaItemsK;
|
||||
|
||||
// Active load-from-shared index
|
||||
int active_lds_idx = __NV_STD_MIN(WmmaUnrollCount - 1, (iteration) % 2);
|
||||
|
||||
// Next load-from-shared index
|
||||
int next_lds_idx = __NV_STD_MIN(WmmaUnrollCount - 1, (iteration + 1) % 2);
|
||||
|
||||
// The last unrolled iteration commits the global fetches
|
||||
if ((iteration == WmmaUnrollCount - 1) && DoGlobalPrefetch)
|
||||
{
|
||||
// If not using two pages of scratch tiles, protect the above prefetch loads from
|
||||
// the committing writes below
|
||||
if (!UseDoubleScratchTiles)
|
||||
{
|
||||
__syncthreads();
|
||||
}
|
||||
else
|
||||
{
|
||||
page_idx = (page_idx ? 0 : 1);
|
||||
}
|
||||
|
||||
// Commit global prefetch data to scratch page
|
||||
loader_a.template commit<LdmSmemA>(&scratch->pages[page_idx].alias().block_a[0][0]);
|
||||
loader_b.template commit<LdmSmemB>(&scratch->pages[page_idx].alias().block_b[0][0]);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Accumulate this dp-stripe product
|
||||
accumulator.multiply_accumulate(
|
||||
local_slices_a[active_lds_idx],
|
||||
local_slices_b[active_lds_idx]);
|
||||
|
||||
// Request local prefetch for next strip
|
||||
request_local_prefetch(
|
||||
local_slices_a[next_lds_idx],
|
||||
local_slices_b[next_lds_idx],
|
||||
(tile_offset_k + WmmaItemsK) % BlockItemsK);
|
||||
}
|
||||
}
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// GEMM API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Compute GEMM
|
||||
*/
|
||||
inline __device__
|
||||
void run()
|
||||
{
|
||||
// Quit if the thread block is fully out-of-bounds
|
||||
if (grid_raster.is_block_oob(dim_m, dim_n))
|
||||
{
|
||||
asm volatile("exit;");
|
||||
}
|
||||
|
||||
// Request global prefetch of first tile
|
||||
loader_a.request();
|
||||
loader_a.next();
|
||||
loader_b.request();
|
||||
loader_b.next();
|
||||
|
||||
// Commit global prefetch of first tile to shared memory
|
||||
loader_a.template commit<LdmSmemA>(&scratch->pages[page_idx].alias().block_a[0][0]);
|
||||
loader_b.template commit<LdmSmemB>(&scratch->pages[page_idx].alias().block_b[0][0]);
|
||||
|
||||
// Advance to next A,B tiles in K-axis
|
||||
block_item_coords_k += BlockItemsK;
|
||||
|
||||
// Synchronize shared tiles and prepared accumulator
|
||||
__syncthreads();
|
||||
|
||||
// Initialize thread's slice of accumulators
|
||||
accumulator.init();
|
||||
|
||||
// Request first iteration of local prefetch strips
|
||||
request_local_prefetch(
|
||||
local_slices_a[0],
|
||||
local_slices_b[0],
|
||||
0);
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
|
||||
// Consume tiles in A and B along the K-axis (all but last tile)
|
||||
#pragma unroll 1
|
||||
while (block_item_coords_k < block_end_item_k)
|
||||
{
|
||||
consume_tile<true>();
|
||||
|
||||
// Advance to next A,B tiles in K-axis
|
||||
block_item_coords_k += BlockItemsK;
|
||||
}
|
||||
|
||||
consume_tile<false>();
|
||||
|
||||
//
|
||||
// Eplilogue
|
||||
//
|
||||
|
||||
// prevent overwriting SMEM until all warps have finished loading data
|
||||
__syncthreads();
|
||||
|
||||
// store accumulator tile to global memory
|
||||
epilogue();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
#endif
|
||||
57
cutlass/gemm/clear_accumulators.h
Normal file
57
cutlass/gemm/clear_accumulators.h
Normal file
@ -0,0 +1,57 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines abstractions for efficiently clearing accumulator tiles.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/vector.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int kLanes_ = 1>
|
||||
struct ClearAccumulators {
|
||||
/// The shared storage.
|
||||
struct SharedStorage {};
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ClearAccumulators() {}
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ClearAccumulators(SharedStorage& shared_storage) {}
|
||||
|
||||
/// Clear the fragment.
|
||||
template <typename Fragment_>
|
||||
CUTLASS_DEVICE void clear(Fragment_& fragment) {
|
||||
fragment.clear();
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
127
cutlass/gemm/dgemm_traits.h
Normal file
127
cutlass/gemm/dgemm_traits.h
Normal file
@ -0,0 +1,127 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural traits of double-precision GEMM.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/gemm_epilogue.h>
|
||||
#include <cutlass/gemm/gemm_epilogue_traits.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/gemm/gemm_shared_tile.h>
|
||||
#include <cutlass/gemm/gemm_traits.h>
|
||||
#include <cutlass/gemm/thread_multiply_add.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_ = 1>
|
||||
struct DgemmConfig
|
||||
: public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
double,
|
||||
/// The scalar type for B.
|
||||
double,
|
||||
/// The scalar type for C.
|
||||
double,
|
||||
/// The scalar type for D.
|
||||
double,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, double, double, double>,
|
||||
/// The number of scalars per LDG for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDS for A.
|
||||
2,
|
||||
/// The number of scalars per LDG for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
2,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
1,
|
||||
/// The number of scalars per STS for D.
|
||||
2,
|
||||
/// The number of scalars per LDS for D.
|
||||
1,
|
||||
/// The number of stages in shared memory.
|
||||
2> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
typename OutputTile_ = Shape<8, 64, 128>,
|
||||
/// The functor to use in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<double>,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_ = Shape<8, 8, 8>,
|
||||
/// The number of doubles loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of doubles loaded in one LDG for B.
|
||||
int kScalarsPerLdgB_ = 1,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The DGEMM config.
|
||||
typename GemmConfig_ =
|
||||
DgemmConfig<OutputTile_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_>,
|
||||
/// The traits class for the epilogue.
|
||||
typename GemmEpilogueTraits_ =
|
||||
SimplifiedGemmEpilogueTraits<GemmConfig_, EpilogueFunctor_, Index_> >
|
||||
struct DgemmTraits : public SimplifiedGemmTraits<
|
||||
// The layout for A.
|
||||
kLayoutA_,
|
||||
// The layout for B.
|
||||
kLayoutB_,
|
||||
// The config.
|
||||
GemmConfig_,
|
||||
// The epilogue.
|
||||
GemmEpilogue<GemmEpilogueTraits_>,
|
||||
// The index.
|
||||
Index_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,542 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* GEMM kernel entrypoint and dispatch stub
|
||||
*/
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "../util/util.h"
|
||||
#include "block_task.h"
|
||||
#include "block_task_wmma.h"
|
||||
#include "grid_raster.h"
|
||||
#include "dispatch_policies.h"
|
||||
#include "k_split_control.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* param_pack
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Parameter-pack structure
|
||||
*
|
||||
* Kernel launch latency is reduced when kernel arguments are wrapped into
|
||||
* a single parameter
|
||||
*/
|
||||
template <
|
||||
typename value_t,
|
||||
typename accum_t,
|
||||
typename epilogue_op_t>
|
||||
struct param_pack
|
||||
{
|
||||
int m; ///< Height in rows of op(A) and C
|
||||
int n; ///< Width in columns of op(B) and C
|
||||
int k; ///< Width in columns of op(A) and height in rows of op(B)
|
||||
k_split_control k_split; ///< Abstraction for controlling inter-block k-splitting
|
||||
value_t *d_a; ///< Pointer to matrix A array values
|
||||
value_t *d_b; ///< Pointer to matrix B array values
|
||||
accum_t *d_c; ///< Pointer to matrix C array values
|
||||
epilogue_op_t epilogue_op;
|
||||
|
||||
param_pack(
|
||||
int m, ///< Height in rows of op(A) and C
|
||||
int n, ///< Width in columns of op(B) and C
|
||||
int k, ///< Width in columns of op(A) and height in rows of op(B)
|
||||
k_split_control k_split, ///< Abstraction for controlling inter-block k-splitting
|
||||
epilogue_op_t op, ///< Epilogue operation to update matrix C
|
||||
value_t *d_a, ///< Pointer to matrix A array values
|
||||
value_t *d_b, ///< Pointer to matrix B array values
|
||||
accum_t *d_c) ///< Pointer to matrix C array values
|
||||
:
|
||||
m(m),
|
||||
n(n),
|
||||
k(k),
|
||||
k_split(k_split),
|
||||
epilogue_op(op),
|
||||
d_a(d_a),
|
||||
d_b(d_b),
|
||||
d_c(d_c)
|
||||
{}
|
||||
|
||||
};
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Conditionally select the appropriate GEMM threadblock task
|
||||
******************************************************************************/
|
||||
|
||||
/// Conditional selection for block task
|
||||
template <
|
||||
math_operation_class_t math_op, ///<
|
||||
typename block_task_policy_t, ///< Parameterization of block_task_policy
|
||||
typename value_t, ///< Multiplicand value type (matrices A and B)
|
||||
typename accum_t, ///< Accumulator value type (matrix C and scalars)
|
||||
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
|
||||
int LdgAlignA, ///< Alignment (in bytes) for A operand
|
||||
matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B
|
||||
int LdgAlignB, ///< Alignment (in bytes) for B operand
|
||||
typename epilogue_op_t, ///< Epilogue operation applied to GEMM
|
||||
int LdgAlignC, ///< Alignment (in bytes) for C operand
|
||||
bool AllowRaggedTiles ///< Whether GEMM supports matrix sizes other than multiple of BlockItems{XY}
|
||||
>
|
||||
struct gemm_block_task;
|
||||
|
||||
/// Scalar math operations
|
||||
template <
|
||||
typename block_task_policy_t, ///< Parameterization of block_task_policy
|
||||
typename value_t, ///< Multiplicand value type (matrices A and B)
|
||||
typename accum_t, ///< Accumulator value type (matrix C and scalars)
|
||||
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
|
||||
int LdgAlignA, ///< Alignment (in bytes) for A operand
|
||||
matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B
|
||||
int LdgAlignB, ///< Alignment (in bytes) for B operand
|
||||
typename epilogue_op_t, ///< Epilogue operation applied to GEMM
|
||||
int LdgAlignC, ///< Alignment (in bytes) for C operand
|
||||
bool AllowRaggedTiles ///< Whether GEMM supports matrix sizes other than multiple of BlockItems{XY}
|
||||
>
|
||||
struct gemm_block_task<
|
||||
math_operation_class_t::scalar,
|
||||
block_task_policy_t,
|
||||
value_t,
|
||||
accum_t,
|
||||
TransformA,
|
||||
LdgAlignA,
|
||||
TransformB,
|
||||
LdgAlignB,
|
||||
epilogue_op_t,
|
||||
LdgAlignC,
|
||||
AllowRaggedTiles
|
||||
>
|
||||
{
|
||||
// Parameterize task type
|
||||
typedef block_task<
|
||||
block_task_policy_t,
|
||||
value_t,
|
||||
accum_t,
|
||||
TransformA,
|
||||
LdgAlignA,
|
||||
TransformB,
|
||||
LdgAlignB,
|
||||
epilogue_op_t,
|
||||
LdgAlignC,
|
||||
AllowRaggedTiles> type;
|
||||
};
|
||||
|
||||
/// Matrix math operations
|
||||
template <
|
||||
typename block_task_policy_t, ///< Parameterization of block_task_policy
|
||||
typename value_t, ///< Multiplicand value type (matrices A and B)
|
||||
typename accum_t, ///< Accumulator value type (matrix C and scalars)
|
||||
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
|
||||
int LdgAlignA, ///< Alignment (in bytes) for A operand
|
||||
matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B
|
||||
int LdgAlignB, ///< Alignment (in bytes) for B operand
|
||||
typename epilogue_op_t, ///< Epilogue operation applied to GEMM
|
||||
int LdgAlignC, ///< Alignment (in bytes) for C operand
|
||||
bool AllowRaggedTiles ///< Whether GEMM supports matrix sizes other than multiple of BlockItems{XY}
|
||||
>
|
||||
struct gemm_block_task<
|
||||
math_operation_class_t::matrix,
|
||||
block_task_policy_t,
|
||||
value_t,
|
||||
accum_t,
|
||||
TransformA,
|
||||
LdgAlignA,
|
||||
TransformB,
|
||||
LdgAlignB,
|
||||
epilogue_op_t,
|
||||
LdgAlignC,
|
||||
AllowRaggedTiles>
|
||||
{
|
||||
|
||||
#if defined(WMMA) // conditional compilation with WMMA headers
|
||||
|
||||
// Parameterize task type
|
||||
typedef block_task_wmma<
|
||||
block_task_policy_t,
|
||||
value_t,
|
||||
accum_t,
|
||||
TransformA,
|
||||
LdgAlignA,
|
||||
TransformB,
|
||||
LdgAlignB,
|
||||
epilogue_op_t,
|
||||
LdgAlignC,
|
||||
AllowRaggedTiles> type;
|
||||
|
||||
#endif
|
||||
};
|
||||
|
||||
/******************************************************************************
|
||||
* GEMM kernel entrypoint
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* GEMM kernel
|
||||
*
|
||||
* NB: Not sure why NVVM is doing stuff with "__launch_bounds__" instead of just
|
||||
* passing it along to PTXAS, but it is currently resulting in less optimal codegen
|
||||
*/
|
||||
template <
|
||||
math_operation_class_t math_op, ///< Indicates which class of math operation to select
|
||||
typename block_task_policy_t, ///< Parameterization of block_task_policy
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
int LdgAlignA, ///< Alignment of A matrix elements in bytes
|
||||
matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B
|
||||
int LdgAlignB, ///< Alignment of B matrix elements in bytes
|
||||
typename value_t, ///< Multiplicand value type (matrices A and B)
|
||||
typename accum_t, ///< Accumulator value type (matrix C and scalars)
|
||||
typename epilogue_op_t, ///< Epilogue operation applied to update matrix C
|
||||
int LdgAlignC, ///< Alignment of C elements in bytes
|
||||
bool AllowRaggedTiles> ///< Boolean to indicate whether AllowRaggedTiles handling is enabled
|
||||
__global__ void kernel(param_pack<value_t, accum_t, epilogue_op_t> pack)
|
||||
{
|
||||
// Parameterize task type
|
||||
typedef typename gemm_block_task<
|
||||
math_op,
|
||||
block_task_policy_t,
|
||||
value_t,
|
||||
accum_t,
|
||||
TransformA,
|
||||
LdgAlignA,
|
||||
TransformB,
|
||||
LdgAlignB,
|
||||
epilogue_op_t,
|
||||
LdgAlignC,
|
||||
AllowRaggedTiles>::type block_task_t;
|
||||
|
||||
// Declare statically-allocated shared storage
|
||||
__shared__ typename block_task_t::scratch_storage_t smem;
|
||||
|
||||
// Construct and run the task
|
||||
block_task_t(
|
||||
&smem,
|
||||
pack.d_a,
|
||||
pack.d_b,
|
||||
pack.d_c,
|
||||
pack.epilogue_op,
|
||||
pack.m,
|
||||
pack.n,
|
||||
pack.k,
|
||||
pack.k_split).run();
|
||||
}
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Launch configuration description returned to the caller
|
||||
******************************************************************************/
|
||||
|
||||
/// Return details about the launch configuration to the caller
|
||||
struct launch_configuration
|
||||
{
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// cudaError_t resulting from grid launch
|
||||
cudaError_t result;
|
||||
|
||||
/// Extent of a thread block's partition along the GEMM K-axis
|
||||
int split_k;
|
||||
|
||||
/// Kernel grid extents in thread blocks
|
||||
dim3 grid;
|
||||
|
||||
/// Thread block extents in threads
|
||||
dim3 block;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor
|
||||
launch_configuration():
|
||||
result(cudaSuccess),
|
||||
split_k(0),
|
||||
grid(0, 0, 0),
|
||||
block(0, 0, 0) {
|
||||
|
||||
}
|
||||
|
||||
/// Conversion from cudaError_t
|
||||
launch_configuration(cudaError_t result):
|
||||
result(result),
|
||||
split_k(1),
|
||||
grid(0, 0, 0),
|
||||
block(0, 0, 0) {
|
||||
|
||||
}
|
||||
|
||||
/// Launch configuration for Cutlass kernels
|
||||
launch_configuration(
|
||||
cudaError_t result,
|
||||
int split_k,
|
||||
dim3 grid,
|
||||
dim3 block
|
||||
):
|
||||
result(result),
|
||||
split_k(split_k),
|
||||
grid(grid),
|
||||
block(block) {
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Dispatch stub
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* GEMM dispatch stub
|
||||
*
|
||||
* This function also serves as the autotuning entrypoint to evaluate different
|
||||
* tuning parameterizations of kernel.
|
||||
*/
|
||||
template <
|
||||
math_operation_class_t math_op, ///< Indicates which class of math operation to select
|
||||
typename block_task_policy_t, ///< Parameterization of block_task_policy
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
int LdgAlignA, ///< Alignment of A matrix elements in bytes
|
||||
matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B
|
||||
int LdgAlignB, ///< Alignment of B matrix elements in bytes
|
||||
typename value_t, ///< Multiplicand value type (matrices A and B)
|
||||
typename accum_t, ///< Accumulator value type (matrix C and scalars)
|
||||
typename epilogue_op_t, ///< Epilogue operation
|
||||
int LdgAlignC, ///< Alignment of C matrix elements in bytes
|
||||
bool AllowRaggedTiles, ///< Boolean to indicate whether AllowRaggedTiles handling is enabled
|
||||
typename kernel_ptr_t> ///< GEMM kernel function pointer type
|
||||
launch_configuration dispatch(
|
||||
kernel_ptr_t kernel_ptr, ///< GEMM kernel function pointer
|
||||
int m, ///< Height in rows of op(A) and C
|
||||
int n, ///< Width in columns of op(B) and C
|
||||
int k, ///< Width in columns of op(A) and height in rows of op(B)
|
||||
epilogue_op_t epilogue_op, ///< Epilogue operation to update matrix C
|
||||
value_t *d_a, ///< Device pointer to matrix A array values
|
||||
value_t *d_b, ///< Device pointer to matrix B array values
|
||||
accum_t *d_c, ///< Device pointer to matrix C array values
|
||||
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
|
||||
bool debug_synchronous = true) ///< Whether or not to synchronize the stream after every kernel launch
|
||||
/// to check for errors. Also causes launch configurations to be printed
|
||||
/// to the console if DEBUG is defined. Default is \p false.
|
||||
{
|
||||
// Thread block rasterization type
|
||||
typedef grid_raster<
|
||||
block_task_policy_t::BlockItemsY,
|
||||
block_task_policy_t::BlockItemsX,
|
||||
TransformA,
|
||||
TransformB,
|
||||
block_task_policy_t::RasterStrategy>
|
||||
grid_raster_t;
|
||||
|
||||
launch_configuration config;
|
||||
|
||||
// Compute block dims
|
||||
config.block = dim3(block_task_policy_t::BlockThreads);
|
||||
|
||||
// Compute shared memory
|
||||
int dynamic_smem_bytes = 0;
|
||||
|
||||
// Compute occupancy
|
||||
int max_sm_occupancy;
|
||||
if (CUDA_PERROR_DEBUG(config.result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_sm_occupancy,
|
||||
kernel_ptr,
|
||||
config.block.x * config.block.y,
|
||||
dynamic_smem_bytes)))
|
||||
{
|
||||
return config;
|
||||
}
|
||||
|
||||
// Compute grid extents
|
||||
config.grid = grid_raster_t::grid_dims(m, n);
|
||||
|
||||
// Get SM count
|
||||
int sm_count;
|
||||
if (CUDA_PERROR_DEBUG(config.result = get_sm_count(sm_count)))
|
||||
return config;
|
||||
|
||||
// Get k-split flag storage (TODO: make a pool)
|
||||
int *d_flags;
|
||||
if (CUDA_PERROR_DEBUG(config.result = cudaGetSymbolAddress((void**) &d_flags, d_flags_split_k)))
|
||||
return config;
|
||||
|
||||
// Construct k-split coordinator
|
||||
k_split_control k_split(
|
||||
d_flags,
|
||||
sm_count,
|
||||
max_sm_occupancy,
|
||||
k,
|
||||
block_task_policy_t::BlockItemsK,
|
||||
config.block,
|
||||
config.grid); // in,out
|
||||
|
||||
config.split_k = k_split.split_k;
|
||||
|
||||
// Log kernel configuration
|
||||
if (debug_synchronous)
|
||||
{
|
||||
// Compute tiling efficiency
|
||||
float block_tiling_efficiency = float(block_task_policy_t::BlockItemsY * block_task_policy_t::BlockItemsX) /
|
||||
float(block_task_policy_t::BlockItemsY + block_task_policy_t::BlockItemsX);
|
||||
|
||||
float tiling_efficiency = block_tiling_efficiency;
|
||||
|
||||
float wave_efficiency = k_split.get_wave_efficiency(
|
||||
sm_count, max_sm_occupancy, config.block, config.grid);
|
||||
|
||||
CUDA_LOG_DEBUG("Final wave_efficiency %.4f, tiling_efficiency %.4f\n",
|
||||
wave_efficiency, tiling_efficiency);
|
||||
|
||||
CUDA_LOG_DEBUG("Invoking kernel<<<(%d, %d, %d), (%d.y,%d.x), %d, %lld>>>(), %d SM occupancy, %d split_k\n",
|
||||
config.grid.x, config.grid.y, config.grid.z,
|
||||
config.block.y, config.block.x,
|
||||
dynamic_smem_bytes,
|
||||
(long long) stream,
|
||||
max_sm_occupancy,
|
||||
k_split.split_k);
|
||||
}
|
||||
|
||||
// Construct parameter-pack
|
||||
param_pack<value_t, accum_t, epilogue_op_t> pack(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
k_split,
|
||||
epilogue_op,
|
||||
d_a,
|
||||
d_b,
|
||||
d_c);
|
||||
|
||||
// Prepare k-split coordinator
|
||||
if (CUDA_PERROR_DEBUG(config.result = k_split.prepare(stream, debug_synchronous)))
|
||||
{
|
||||
return config;
|
||||
}
|
||||
|
||||
// Invoke kernel
|
||||
kernel_ptr<<< config.grid, config.block, dynamic_smem_bytes, stream >>>(pack);
|
||||
|
||||
// Check for failure to launch
|
||||
if (CUDA_PERROR_DEBUG(config.result = cudaPeekAtLastError()))
|
||||
return config;
|
||||
|
||||
// Sync the stream if specified to flush runtime errors
|
||||
if (debug_synchronous && (CUDA_PERROR_DEBUG(config.result = cudaStreamSynchronize(stream))))
|
||||
return config;
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* GEMM
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Computes gemm on device matrices
|
||||
*/
|
||||
template <
|
||||
tiling_strategy::kind_t TilingStrategy, ///< Tile-sizing classification
|
||||
math_operation_class_t math_op, ///< Indicates which class of math operation to select
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
int LdgAlignA, ///< Alignment (in bytes) of A operand
|
||||
matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B
|
||||
int LdgAlignB, ///< Alignment (in bytes) of B operand
|
||||
typename value_t, ///< Multiplicand value type (matrices A and B)
|
||||
typename accum_t, ///< Accumulator value type (matrix C and scalars)
|
||||
typename epilogue_op_t, ///< Epilogue operation to update matrix C
|
||||
int LdgAlignC> ///< Alignment (in bytes) of C operand
|
||||
launch_configuration device_gemm(
|
||||
int m, ///< Height in rows of op(A) and C
|
||||
int n, ///< Width in columns of op(B) and C
|
||||
int k, ///< Width in columns of op(A) and height in rows of op(B)
|
||||
epilogue_op_t epilogue_op, ///< Epilogue operation to update matrix C
|
||||
value_t *d_a, ///< Device pointer to matrix A array values
|
||||
value_t *d_b, ///< Device pointer to matrix B array values
|
||||
accum_t *d_c, ///< Device pointer to matrix C array values
|
||||
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
|
||||
bool debug_synchronous = false) ///< Whether or not to synchronize the stream after every kernel launch to
|
||||
/// check for errors. Also causes launch configurations to be printed to
|
||||
/// the console if DEBUG is defined. Default is \p false.
|
||||
{
|
||||
// Parameterize an task policy type
|
||||
// (TODO: use a policy dispatch mechanism based upon SM version)
|
||||
typedef gemm_policy<value_t, accum_t, TransformA, TransformB, TilingStrategy> block_task_policy_t;
|
||||
|
||||
// AllowRaggedTiles-tile check
|
||||
if ((m % block_task_policy_t::BlockItemsY != 0) ||
|
||||
(n % block_task_policy_t::BlockItemsX != 0) ||
|
||||
(k % block_task_policy_t::BlockItemsK != 0))
|
||||
{
|
||||
// Needs ragged tile-handling
|
||||
static const bool AllowRaggedTiles = true;
|
||||
|
||||
return dispatch<math_op, block_task_policy_t, TransformA, LdgAlignA, TransformB, LdgAlignB, value_t, accum_t, epilogue_op_t, LdgAlignC, AllowRaggedTiles>(
|
||||
kernel<math_op,block_task_policy_t, TransformA, LdgAlignA, TransformB, LdgAlignB, value_t, accum_t, epilogue_op_t, LdgAlignC, AllowRaggedTiles>,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
epilogue_op,
|
||||
d_a,
|
||||
d_b,
|
||||
d_c,
|
||||
stream,
|
||||
debug_synchronous);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Does not need ragged tile-handling
|
||||
static const bool AllowRaggedTiles = false;
|
||||
|
||||
return dispatch<math_op, block_task_policy_t, TransformA, LdgAlignA, TransformB, LdgAlignB, value_t, accum_t, epilogue_op_t, LdgAlignC, AllowRaggedTiles>(
|
||||
kernel<math_op,block_task_policy_t, TransformA, LdgAlignA, TransformB, LdgAlignB, value_t, accum_t, epilogue_op_t, LdgAlignC, AllowRaggedTiles>,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
epilogue_op,
|
||||
d_a,
|
||||
d_b,
|
||||
d_c,
|
||||
stream,
|
||||
debug_synchronous);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,661 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* Architecture-specific GEMM block_task policies
|
||||
*/
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "../util/util.h"
|
||||
#include "block_task.h"
|
||||
#include "grid_raster.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* tiling_strategy
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Enumeration of tile-sizing granularities
|
||||
*/
|
||||
struct tiling_strategy : printable_t
|
||||
{
|
||||
/// \brief Enumerants
|
||||
enum kind_t
|
||||
{
|
||||
Unknown,
|
||||
Small,
|
||||
Medium,
|
||||
Large,
|
||||
Tall,
|
||||
Wide,
|
||||
Huge,
|
||||
};
|
||||
|
||||
/// Enumerant value
|
||||
kind_t kind;
|
||||
|
||||
/// Default constructor
|
||||
tiling_strategy() : kind(Unknown) {}
|
||||
|
||||
/// Copy constructor
|
||||
tiling_strategy(const kind_t &other_kind) : kind(other_kind) {}
|
||||
|
||||
/// Cast to kind_t
|
||||
operator kind_t() const { return kind; }
|
||||
|
||||
/// Returns the instance as a string
|
||||
__host__ __device__ inline
|
||||
char const* to_string() const
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case Small: return "small";
|
||||
case Medium: return "medium";
|
||||
case Large: return "large";
|
||||
case Tall: return "tall";
|
||||
case Wide: return "wide";
|
||||
case Huge: return "huge";
|
||||
case Unknown:
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert the formatted instance into the output stream
|
||||
void print(std::ostream& out) const { out << to_string(); }
|
||||
};
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* GEMM
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for sgemm
|
||||
*/
|
||||
template <
|
||||
typename value_t,
|
||||
typename accum_t,
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B
|
||||
tiling_strategy::kind_t TilingStrategy> ///< Tile-sizing classification
|
||||
struct gemm_policy;
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* SGEMM
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for small sgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Small> :
|
||||
block_task_policy<
|
||||
16, // _BlockItemsY
|
||||
16, // _BlockItemsX
|
||||
16, // _BlockItemsK
|
||||
2, // _ThreadItemsY
|
||||
2, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for medium sgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Medium> :
|
||||
block_task_policy<
|
||||
32, // _BlockItemsY
|
||||
32, // _BlockItemsX
|
||||
8, // _BlockItemsK
|
||||
4, // _ThreadItemsY
|
||||
4, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for large sgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Large> :
|
||||
block_task_policy<
|
||||
64, // _BlockItemsY
|
||||
64, // _BlockItemsX
|
||||
8, // _BlockItemsK
|
||||
8, // _ThreadItemsY
|
||||
8, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for tall sgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Tall> :
|
||||
block_task_policy<
|
||||
128, // _BlockItemsY
|
||||
32, // _BlockItemsX
|
||||
8, // _BlockItemsK
|
||||
8, // _ThreadItemsY
|
||||
4, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for wide sgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Wide> :
|
||||
block_task_policy<
|
||||
32, // _BlockItemsY
|
||||
128, // _BlockItemsX
|
||||
8, // _BlockItemsK
|
||||
4, // _ThreadItemsY
|
||||
8, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for huge sgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Huge> :
|
||||
block_task_policy<
|
||||
128, // _BlockItemsY
|
||||
128, // _BlockItemsX
|
||||
8, // _BlockItemsK
|
||||
8, // _ThreadItemsY
|
||||
8, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* DGEMM
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for small dgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Small> :
|
||||
block_task_policy<
|
||||
16, // _BlockItemsY
|
||||
16, // _BlockItemsX
|
||||
16, // _BlockItemsK
|
||||
2, // _ThreadItemsY
|
||||
2, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for medium dgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Medium> :
|
||||
block_task_policy<
|
||||
32, // _BlockItemsY
|
||||
32, // _BlockItemsX
|
||||
16, // _BlockItemsK
|
||||
4, // _ThreadItemsY
|
||||
4, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for large dgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Large> :
|
||||
block_task_policy<
|
||||
64, // _BlockItemsY
|
||||
64, // _BlockItemsX
|
||||
8, // _BlockItemsK
|
||||
4, // _ThreadItemsY
|
||||
4, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for tall dgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Tall> :
|
||||
block_task_policy<
|
||||
128, // _BlockItemsY
|
||||
32, // _BlockItemsX
|
||||
8, // _BlockItemsK
|
||||
8, // _ThreadItemsY
|
||||
4, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for wide dgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Wide> :
|
||||
block_task_policy<
|
||||
32, // _BlockItemsY
|
||||
128, // _BlockItemsX
|
||||
8, // _BlockItemsK
|
||||
4, // _ThreadItemsY
|
||||
8, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for huge dgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Huge> :
|
||||
block_task_policy<
|
||||
64, // _BlockItemsY
|
||||
128, // _BlockItemsX
|
||||
8, // _BlockItemsK
|
||||
8, // _ThreadItemsY
|
||||
8, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* HGEMM
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for small hgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Small> :
|
||||
block_task_policy<
|
||||
32, // _BlockItemsY
|
||||
32, // _BlockItemsX
|
||||
8, // _BlockItemsK
|
||||
4, // _ThreadItemsY
|
||||
4, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for medium hgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Medium> :
|
||||
block_task_policy<
|
||||
32, // _BlockItemsY
|
||||
32, // _BlockItemsX
|
||||
16, // _BlockItemsK
|
||||
8, // _ThreadItemsY
|
||||
4, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for large hgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Large> :
|
||||
block_task_policy<
|
||||
64, // _BlockItemsY
|
||||
64, // _BlockItemsX
|
||||
8, // _BlockItemsK
|
||||
16, // _ThreadItemsY
|
||||
8, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for tall hgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Tall> :
|
||||
block_task_policy<
|
||||
128, // _BlockItemsY
|
||||
32, // _BlockItemsX
|
||||
8, // _BlockItemsK
|
||||
16, // _ThreadItemsY
|
||||
4, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for wide hgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Wide> :
|
||||
block_task_policy<
|
||||
32, // _BlockItemsY
|
||||
128, // _BlockItemsX
|
||||
8, // _BlockItemsK
|
||||
8, // _ThreadItemsY
|
||||
8, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for huge hgemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Huge> :
|
||||
block_task_policy<
|
||||
128, // _BlockItemsY
|
||||
128, // _BlockItemsX
|
||||
8, // _BlockItemsK
|
||||
16, // _ThreadItemsY
|
||||
8, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* IGEMM
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for small igemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Small> :
|
||||
block_task_policy<
|
||||
16, // _BlockItemsY
|
||||
32, // _BlockItemsX
|
||||
32, // _BlockItemsK
|
||||
4, // _ThreadItemsY
|
||||
4, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for medium igemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Medium> :
|
||||
block_task_policy<
|
||||
32, // _BlockItemsY
|
||||
32, // _BlockItemsX
|
||||
32, // _BlockItemsK
|
||||
4, // _ThreadItemsY
|
||||
4, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for large igemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Large> :
|
||||
block_task_policy<
|
||||
64, // _BlockItemsY
|
||||
64, // _BlockItemsX
|
||||
32, // _BlockItemsK
|
||||
8, // _ThreadItemsY
|
||||
4, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for large igemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Tall> :
|
||||
block_task_policy<
|
||||
128, // _BlockItemsY
|
||||
64, // _BlockItemsX
|
||||
64, // _BlockItemsK
|
||||
8, // _ThreadItemsY
|
||||
4, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for large igemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Wide> :
|
||||
block_task_policy<
|
||||
64, // _BlockItemsY
|
||||
128, // _BlockItemsX
|
||||
64, // _BlockItemsK
|
||||
4, // _ThreadItemsY
|
||||
8, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
/**
|
||||
* GEMM task policy specialization for huge igemm
|
||||
*/
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Huge> :
|
||||
block_task_policy<
|
||||
128, // _BlockItemsY
|
||||
128, // _BlockItemsX
|
||||
32, // _BlockItemsK
|
||||
8, // _ThreadItemsY
|
||||
8, // _ThreadItemsX
|
||||
false, // _UseDoubleScratchTiles
|
||||
grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* WMMA GEMM
|
||||
******************************************************************************/
|
||||
|
||||
// WMMA is a preview feature in CUDA. Conditionally enable wmma_gemm policies.
|
||||
#if defined(WMMA)
|
||||
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<half, float, TransformA, TransformB, tiling_strategy::Small> :
|
||||
gemm::block_task_wmma_policy<
|
||||
16, // _BlockItemsY
|
||||
16, // _BlockItemsX
|
||||
16, // _BlockItemsK
|
||||
16, // _WarpItemsY
|
||||
16, // _WarpItemsX
|
||||
16, // _WmmaItemsY
|
||||
16, // _WmmaItemsX
|
||||
16, // _WmmaItemsK
|
||||
false, // _UseDoubleScratchTiles
|
||||
gemm::grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy<half, float, TransformA, TransformB, tiling_strategy::Medium> :
|
||||
gemm::block_task_wmma_policy<
|
||||
32, // _BlockItemsY
|
||||
32, // _BlockItemsX
|
||||
32, // _BlockItemsK
|
||||
32, // _WarpItemsY
|
||||
32, // _WarpItemsX
|
||||
16, // _WmmaItemsY
|
||||
16, // _WmmaItemsX
|
||||
16, // _WmmaItemsK
|
||||
false, // _UseDoubleScratchTiles
|
||||
gemm::grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy< half, float, TransformA, TransformB, tiling_strategy::Large> :
|
||||
gemm::block_task_wmma_policy<
|
||||
64, // _BlockItemsY
|
||||
64, // _BlockItemsX
|
||||
32, // _BlockItemsK
|
||||
32, // _WarpItemsY
|
||||
64, // _WarpItemsX
|
||||
16, // _WmmaItemsY
|
||||
16, // _WmmaItemsX
|
||||
16, // _WmmaItemsK
|
||||
false, // _UseDoubleScratchTiles
|
||||
gemm::grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy< half, float, TransformA, TransformB, tiling_strategy::Tall> :
|
||||
gemm::block_task_wmma_policy<
|
||||
128, // _BlockItemsY
|
||||
64, // _BlockItemsX
|
||||
64, // _BlockItemsK
|
||||
32, // _WarpItemsY
|
||||
64, // _WarpItemsX
|
||||
16, // _WmmaItemsY
|
||||
16, // _WmmaItemsX
|
||||
16, // _WmmaItemsK
|
||||
false, // _UseDoubleScratchTiles
|
||||
gemm::grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy< half, float, TransformA, TransformB, tiling_strategy::Wide> :
|
||||
gemm::block_task_wmma_policy<
|
||||
64, // _BlockItemsY
|
||||
128, // _BlockItemsX
|
||||
64, // _BlockItemsK
|
||||
32, // _WarpItemsY
|
||||
64, // _WarpItemsX
|
||||
16, // _WmmaItemsY
|
||||
16, // _WmmaItemsX
|
||||
16, // _WmmaItemsK
|
||||
false, // _UseDoubleScratchTiles
|
||||
gemm::grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
template <
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
|
||||
struct gemm_policy< half, float, TransformA, TransformB, tiling_strategy::Huge> :
|
||||
gemm::block_task_wmma_policy<
|
||||
128, // _BlockItemsY
|
||||
128, // _BlockItemsX
|
||||
64, // _BlockItemsK
|
||||
32, // _WarpItemsY
|
||||
64, // _WarpItemsX
|
||||
16, // _WmmaItemsY
|
||||
16, // _WmmaItemsX
|
||||
16, // _WmmaItemsK
|
||||
false, // _UseDoubleScratchTiles
|
||||
gemm::grid_raster_strategy::Default> // _RasterStrategy
|
||||
{};
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,223 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* Abstraction for exposing architecture-specific "dot-product-accumulate"
|
||||
* ISA operations
|
||||
*/
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "../util/util.h"
|
||||
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* dp_accummulate
|
||||
******************************************************************************/
|
||||
|
||||
|
||||
/**
|
||||
* \brief Abstraction for exposing architecture-specific "dot-product-accumulate"
|
||||
* ISA operations
|
||||
*
|
||||
* Given two K-component vectors a and b having type value_t[K] and an addend c
|
||||
* of type accum_t, the "dot-product-accumulate" of type accum_t is computed
|
||||
* as d = x[0]*y[0] + x[1]*y[1] + ... + x[K-1]*y[K-1] + c.
|
||||
*
|
||||
* We use the notation "dpK" to connote a K-component dot-product-accumulate.
|
||||
* For example, "dp1" is a simple multiply-add.
|
||||
*
|
||||
* For given pairing of value_t and accum_t types, the corresponding
|
||||
* dp_accummulate class will:
|
||||
*
|
||||
* - Define the member-type dp_vector_t as the appropriate K-component vector
|
||||
* type needed to leverage architecture-specific "dot-product accumulate"
|
||||
* ISA operations.
|
||||
* - Implement the corresponding dot-product operation between two dp_vector_t
|
||||
* inputs a and b.
|
||||
*
|
||||
*/
|
||||
template <
|
||||
typename value_t, ///< Component value type
|
||||
typename accum_t> ///< Accumulator value type
|
||||
struct dp_accummulate;
|
||||
|
||||
|
||||
|
||||
/// Default "dp1" dot-product-accumulate traits specialization for value_t->accum_t
|
||||
template <
|
||||
typename value_t, ///< Component value type
|
||||
typename accum_t> ///< Accumulator value type
|
||||
struct dp_accummulate
|
||||
{
|
||||
/// Single-component "dp1" dot-product vector type
|
||||
typedef value_t dp_vector_t;
|
||||
|
||||
|
||||
/// Compute "dp1" float->float
|
||||
inline __device__
|
||||
static void mad(
|
||||
float &d,
|
||||
const float &a,
|
||||
const float &b,
|
||||
const float &c)
|
||||
{
|
||||
asm volatile ( "fma.rn.f32 %0, %1, %2, %3;\n"
|
||||
: "=f"(d) : "f"(a), "f"(b), "f"(c));
|
||||
}
|
||||
|
||||
|
||||
/// Compute "dp1" double->double
|
||||
inline __device__
|
||||
static void mad(
|
||||
double &d,
|
||||
const double &a,
|
||||
const double &b,
|
||||
const double &c)
|
||||
{
|
||||
asm volatile ("fma.rn.f64 %0, %1, %2, %3;\n"
|
||||
: "=d"(d) : "d"(a), "d"(b), "d"(c));
|
||||
}
|
||||
|
||||
|
||||
/// Compute "dp1" int16_t->int32_t
|
||||
inline __device__
|
||||
static void mad(
|
||||
int32_t &d,
|
||||
const int16_t &a,
|
||||
const int16_t &b,
|
||||
const int32_t &c)
|
||||
{
|
||||
asm volatile ("mad.wide.s16 %0, %1, %2, %3;\n"
|
||||
: "=r"(d) : "h"(a), "h"(b), "r"(c));
|
||||
}
|
||||
|
||||
|
||||
/// Compute "dp1" uint16_t->uint32_t
|
||||
inline __device__
|
||||
static void mad(
|
||||
uint32_t &d,
|
||||
const uint16_t &a,
|
||||
const uint16_t &b,
|
||||
const uint32_t &c)
|
||||
{
|
||||
asm volatile ("mad.wide.u16 %0, %1, %2, %3;\n"
|
||||
: "=r"(d) : "h"(a), "h"(b), "r"(c));
|
||||
}
|
||||
|
||||
|
||||
/// Compute "dp1" int32_t->int32_t
|
||||
inline __device__
|
||||
static void mad(
|
||||
int32_t &d,
|
||||
const int32_t &a,
|
||||
const int32_t &b,
|
||||
const int32_t &c)
|
||||
{
|
||||
asm volatile ("mad.lo.s32 %0, %1, %2, %3;\n"
|
||||
: "=r"(d) : "r"(a), "r"(b), "r"(c));
|
||||
}
|
||||
|
||||
|
||||
/// Compute "dp1" uint32_t->uint32_t
|
||||
inline __device__
|
||||
static void mad(
|
||||
uint32_t &d,
|
||||
const uint32_t &a,
|
||||
const uint32_t &b,
|
||||
const uint32_t &c)
|
||||
{
|
||||
asm volatile ("mad.lo.u32 %0, %1, %2, %3;\n"
|
||||
: "=r"(d) : "r"(a), "r"(b), "r"(c));
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
#if (CUTLASS_ARCH >= 610) // Specializations only enabled for Pascal SM610+
|
||||
|
||||
|
||||
/// "dp4" dot-product-accumulate traits specialization for int8_t->int32_t
|
||||
template <>
|
||||
struct dp_accummulate<
|
||||
int8_t, ///< Component value type
|
||||
int32_t> ///< Accumulator value type
|
||||
{
|
||||
/// Four-component signed "idp4"
|
||||
typedef int32_t dp_vector_t;
|
||||
|
||||
/// Compute "dp4" int16_t->int32_t
|
||||
inline __device__
|
||||
static void mad(
|
||||
int32_t &d,
|
||||
const int32_t &a,
|
||||
const int32_t &b,
|
||||
const int32_t &c)
|
||||
{
|
||||
asm volatile ( "dp4a.s32.s32 %0, %1, %2, %3;\n"
|
||||
: "=r"(d) : "r"(a), "r"(b), "r"(c));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// "dp4" dot-product-accumulate traits specialization for uint8_t->uint32_t
|
||||
template <>
|
||||
struct dp_accummulate<
|
||||
uint8_t, ///< Component value type
|
||||
uint32_t> ///< Accumulator value type
|
||||
{
|
||||
/// Four-component unsigned "idp4"
|
||||
typedef uint32_t dp_vector_t;
|
||||
|
||||
/// Compute "dp4" uint16_t->uint32_t
|
||||
inline __device__
|
||||
static void mad(
|
||||
uint32_t &d,
|
||||
const uint32_t &a,
|
||||
const uint32_t &b,
|
||||
const uint32_t &c)
|
||||
{
|
||||
asm volatile ( "dp4a.u32.u32 %0, %1, %2, %3;\n"
|
||||
: "=r"(d) : "r"(a), "r"(b), "r"(c));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
#endif // Specializations only enabled for Pascal SM610+
|
||||
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
@ -1,104 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* Epilogue operation to compute final output
|
||||
*/
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
//// Used by GEMM to compute the final result C <= alpha * accumulator + beta * C
|
||||
template <
|
||||
typename accum_t,
|
||||
typename output_t,
|
||||
typename scalar_t
|
||||
>
|
||||
class blas_scaled_epilogue
|
||||
{
|
||||
public:
|
||||
|
||||
scalar_t alpha;
|
||||
scalar_t beta;
|
||||
|
||||
inline __device__ __host__
|
||||
blas_scaled_epilogue(
|
||||
scalar_t alpha,
|
||||
scalar_t beta)
|
||||
:
|
||||
alpha(alpha),
|
||||
beta(beta)
|
||||
{}
|
||||
|
||||
|
||||
/// Epilogue operator
|
||||
inline __device__ __host__
|
||||
output_t operator()(
|
||||
accum_t accumulator,
|
||||
output_t c,
|
||||
size_t idx) const
|
||||
{
|
||||
return output_t(alpha * scalar_t(accumulator) + beta * scalar_t(c));
|
||||
}
|
||||
|
||||
|
||||
/// Epilogue operator
|
||||
inline __device__ __host__
|
||||
output_t operator()(
|
||||
accum_t accumulator,
|
||||
size_t idx) const
|
||||
{
|
||||
return output_t(alpha * scalar_t(accumulator));
|
||||
}
|
||||
|
||||
/**
|
||||
* Configure epilogue as to whether the thread block is a secondary
|
||||
* accumulator in an inter-block k-splitting scheme
|
||||
*/
|
||||
inline __device__
|
||||
void set_secondary_accumulator()
|
||||
{
|
||||
beta = scalar_t(1);
|
||||
}
|
||||
|
||||
|
||||
/// Return whether the beta-scaled addend needs initialization
|
||||
inline __device__
|
||||
bool must_init_addend()
|
||||
{
|
||||
return (beta != scalar_t(0));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
344
cutlass/gemm/gemm.h
Normal file
344
cutlass/gemm/gemm.h
Normal file
@ -0,0 +1,344 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements a software-pipelined efficient GEMM.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/util/platform.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Gemm_>
|
||||
__global__ /*__launch_bounds__(Gemm_::kThreads)*/ void gemm_kernel(typename Gemm_::Params params) {
|
||||
// Declare shared memory.
|
||||
__shared__ typename Gemm_::SharedStorage shared_storage;
|
||||
|
||||
// Construct the GEMM object.
|
||||
Gemm_ gemm(params, shared_storage);
|
||||
// Run GEMM.
|
||||
gemm.multiply_add();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, typename Index_ = int>
|
||||
struct GemmDesc {
|
||||
/// The dimensions of the GEMM.
|
||||
Index_ m, n, k;
|
||||
/// The alpha/beta scaling values.
|
||||
Scalar_ alpha, beta;
|
||||
/// The source matrix A.
|
||||
void const* d_a;
|
||||
/// The stride for A.
|
||||
Index_ lda;
|
||||
/// The source matrix B.
|
||||
void const* d_b;
|
||||
/// The stride for B.
|
||||
Index_ ldb;
|
||||
/// The source matrix C.
|
||||
void const* d_c;
|
||||
/// The stride for C.
|
||||
Index_ ldc;
|
||||
/// The destination matrix D.
|
||||
void* d_d;
|
||||
/// The stride for D.
|
||||
Index_ ldd;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmTraits_>
|
||||
struct Gemm {
|
||||
/// This class.
|
||||
typedef Gemm<GemmTraits_> This_;
|
||||
/// The traits.
|
||||
typedef GemmTraits_ Traits;
|
||||
/// The shared storage.
|
||||
typedef typename Traits::SharedStorage SharedStorage;
|
||||
|
||||
/// The scalar for A.
|
||||
typedef typename Traits::ScalarA ScalarA;
|
||||
/// The scalar for B.
|
||||
typedef typename Traits::ScalarB ScalarB;
|
||||
/// The scalar in the epilogue.
|
||||
typedef typename Traits::Epilogue::Scalar ScalarEpilogue;
|
||||
/// The scalar for C.
|
||||
typedef typename Traits::Epilogue::ScalarC ScalarC;
|
||||
/// The scalar for D.
|
||||
typedef typename Traits::Epilogue::ScalarD ScalarD;
|
||||
/// The index.
|
||||
typedef typename Traits::Index Index;
|
||||
|
||||
/// The number of threads.
|
||||
static int const kThreads = Traits::GemmConfig::kThreads;
|
||||
|
||||
/// The params.
|
||||
struct Params : public Traits::Params {
|
||||
CUTLASS_HOST_DEVICE int initialize(Index m,
|
||||
Index n,
|
||||
Index k,
|
||||
ScalarEpilogue alpha,
|
||||
ScalarA const* d_a,
|
||||
Index lda,
|
||||
ScalarB const* d_b,
|
||||
Index ldb,
|
||||
ScalarEpilogue beta,
|
||||
ScalarC const* d_c,
|
||||
Index ldc,
|
||||
ScalarD* d_d,
|
||||
Index ldd) {
|
||||
GemmDesc<ScalarEpilogue, Index> desc;
|
||||
desc.m = m;
|
||||
desc.n = n;
|
||||
desc.k = k;
|
||||
desc.alpha = alpha;
|
||||
desc.beta = beta;
|
||||
desc.d_a = reinterpret_cast<void const*>(d_a);
|
||||
desc.lda = lda;
|
||||
desc.d_b = reinterpret_cast<void const*>(d_b);
|
||||
desc.ldb = ldb;
|
||||
desc.d_c = reinterpret_cast<void const*>(d_c);
|
||||
desc.ldc = ldc;
|
||||
desc.d_d = reinterpret_cast<void*>(d_d);
|
||||
desc.ldd = ldd;
|
||||
return Traits::Params::initialize(desc);
|
||||
}
|
||||
};
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
/// Launch the kernel.
|
||||
static __host__ cudaError_t launch(Params const& params,
|
||||
cudaStream_t stream = cudaStreamDefault) {
|
||||
// Setup the grid.
|
||||
dim3 grid;
|
||||
grid.x = (params.m + Traits::OutputTile::kW - 1) / Traits::OutputTile::kW;
|
||||
grid.y = (params.n + Traits::OutputTile::kH - 1) / Traits::OutputTile::kH;
|
||||
|
||||
// The number of threads.
|
||||
dim3 block;
|
||||
block.x = kThreads;
|
||||
|
||||
// Launch the kernel.
|
||||
void const* params_ = reinterpret_cast<void const*>(¶ms);
|
||||
|
||||
return cudaLaunchKernel(reinterpret_cast<void*>(&gemm_kernel<This_>),
|
||||
grid,
|
||||
block,
|
||||
const_cast<void**>(¶ms_),
|
||||
0,
|
||||
stream);
|
||||
}
|
||||
|
||||
/// Launch the kernel.
|
||||
static __host__ cudaError_t launch(CUfunction kernel,
|
||||
Params const& params,
|
||||
CUstream stream = CU_STREAM_LEGACY) {
|
||||
// Setup the grid.
|
||||
dim3 grid;
|
||||
grid.x = (params.m + Traits::OutputTile::kW - 1) / Traits::OutputTile::kW;
|
||||
grid.y = (params.n + Traits::OutputTile::kH - 1) / Traits::OutputTile::kH;
|
||||
|
||||
// The number of threads.
|
||||
dim3 block;
|
||||
block.x = kThreads;
|
||||
|
||||
// Launch the kernel.
|
||||
void* params_[] = {const_cast<void*>(reinterpret_cast<void const*>(¶ms))};
|
||||
|
||||
// return cudaLaunchKernel(reinterpret_cast<void*>(&gemm_kernel<This_>), grid, block,
|
||||
// const_cast<void**>(¶ms_), 0, stream);
|
||||
CUresult result = cuLaunchKernel(
|
||||
kernel, grid.x, grid.y, grid.z, block.x, block.y, block.z, 0, stream, params_, 0);
|
||||
|
||||
if (result != CUDA_SUCCESS) {
|
||||
return cudaErrorLaunchFailure;
|
||||
}
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE Gemm(Params const& params_, SharedStorage& shared_storage_)
|
||||
: params(params_), shared_storage(shared_storage_) {}
|
||||
|
||||
/// Consume a single iteration of the loop.
|
||||
template <bool kIsLastIteration>
|
||||
CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_stream,
|
||||
typename Traits::SharedLoadStream& shared_load_stream,
|
||||
typename Traits::MultiplyAdd::Accumulators& accumulators,
|
||||
Index outer_k) {
|
||||
// If that's the last "load iteration" update the predicates.
|
||||
if (!kIsLastIteration) {
|
||||
global_stream.move_to_residue<false>(outer_k);
|
||||
}
|
||||
|
||||
// Load data for the next iteration of the main loop.
|
||||
if (!kIsLastIteration) {
|
||||
global_stream.copy();
|
||||
}
|
||||
|
||||
// The unrolling steps for the main loop.
|
||||
int const kUnrollingSteps =
|
||||
Traits::MultiplyAdd::AccumulatorsPerWarp::kD / Traits::MultiplyAdd::InstructionShape::kD;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int step = 0; step < kUnrollingSteps - 1; ++step) {
|
||||
// Trigger the copy from shared memory for the next A/B values.
|
||||
shared_load_stream.copy(step + 1);
|
||||
// Make sure the values are available for the current iteration to do the multiply-add.
|
||||
shared_load_stream.commit(step);
|
||||
|
||||
// Do the math on the fragments of the current iteration.
|
||||
typename Traits::MultiplyAdd multiply_add;
|
||||
multiply_add.multiply_add(shared_load_stream.fragment_a(step),
|
||||
shared_load_stream.fragment_b(step),
|
||||
accumulators,
|
||||
accumulators);
|
||||
}
|
||||
|
||||
// Make sure the data from shared memory has been entirely consumed.
|
||||
Traits::shared_load_fence(true);
|
||||
|
||||
// Commit the data in shared memory for A/B.
|
||||
if (!kIsLastIteration) {
|
||||
global_stream.commit();
|
||||
}
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
Traits::shared_store_fence(true);
|
||||
|
||||
// Trigger the loads for the next iteration (if needed).
|
||||
if (!kIsLastIteration) {
|
||||
// Move to the next stage for the load (if it makes sense).
|
||||
shared_load_stream.inc_stage();
|
||||
// Trigger the copy from shared memory for the next loop iteration.
|
||||
shared_load_stream.copy(0);
|
||||
}
|
||||
|
||||
// Make sure the values are available for the current iteration to do the multiply-add.
|
||||
shared_load_stream.commit(kUnrollingSteps - 1);
|
||||
|
||||
// Do the math on the fragments of the current iteration.
|
||||
typename Traits::MultiplyAdd multiply_add;
|
||||
multiply_add.multiply_add(shared_load_stream.fragment_a(kUnrollingSteps - 1),
|
||||
shared_load_stream.fragment_b(kUnrollingSteps - 1),
|
||||
accumulators,
|
||||
accumulators);
|
||||
}
|
||||
|
||||
/// Do the GEMM.
|
||||
CUTLASS_DEVICE void multiply_add() {
|
||||
// Swizzle the IDs of the block (to enable better cache behavior).
|
||||
typename Traits::BlockSwizzle block_swizzle;
|
||||
dim3 block = block_swizzle.swizzle();
|
||||
|
||||
// Scale the id.
|
||||
block.x *= Traits::OutputTile::kW;
|
||||
block.y *= Traits::OutputTile::kH;
|
||||
|
||||
// We may want to use shared memory to clear the registers.
|
||||
typedef typename Traits::ClearAccumulators ClearAccumulators;
|
||||
|
||||
// The streams to read A/B from global memory to shared memory.
|
||||
typename Traits::GlobalLoadStream global_stream(params, shared_storage, block);
|
||||
|
||||
// Create the accumulator clear.
|
||||
ClearAccumulators clear(shared_storage.main_loop.clear);
|
||||
|
||||
// By how much we unroll the main loop.
|
||||
Index const kUnroll = static_cast<Index>(Traits::OutputTile::kD);
|
||||
|
||||
// If we do not have enough steps in the main loop, trigger the residue code.
|
||||
global_stream.move_to_residue<true>(params.k);
|
||||
|
||||
// Fetch the fragments for A and B from global memory.
|
||||
global_stream.copy();
|
||||
|
||||
// Copy the elements to shared memory (after transformation if needed).
|
||||
global_stream.commit();
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
Traits::shared_store_fence(false);
|
||||
|
||||
// Rollback to the beginning of the GEMM-K dimension. It may have no impact.
|
||||
global_stream.rollback();
|
||||
|
||||
// The unrolling steps for the main loop.
|
||||
int const kUnrollingSteps =
|
||||
Traits::MultiplyAdd::AccumulatorsPerWarp::kD / Traits::MultiplyAdd::InstructionShape::kD;
|
||||
|
||||
// Make sure we have at least 2 unrolling steps or our pipeling is not going to work.
|
||||
static_assert(kUnrollingSteps >= 2, "The pipelining assumes at least two steps");
|
||||
|
||||
// The stream of data from shared memory to fragments.
|
||||
typename Traits::SharedLoadStream shared_load_stream(params, shared_storage);
|
||||
|
||||
// Trigger the copy from shared memory for the 1st stream.
|
||||
shared_load_stream.copy(0);
|
||||
|
||||
// Allocate the accumulators.
|
||||
typename Traits::MultiplyAdd::Accumulators accumulators;
|
||||
// Clear the accumulators.
|
||||
clear.clear(accumulators);
|
||||
|
||||
// The loop index.
|
||||
Index outer_k = params.k - kUnroll;
|
||||
|
||||
// Enter the main loop and iterate.
|
||||
for (; outer_k > 0; outer_k -= kUnroll) {
|
||||
consume_tile<false>(global_stream, shared_load_stream, accumulators, outer_k);
|
||||
}
|
||||
|
||||
// Residual loop.
|
||||
for (; outer_k > -kUnroll; outer_k -= kUnroll) {
|
||||
consume_tile<true>(global_stream, shared_load_stream, accumulators, outer_k);
|
||||
}
|
||||
|
||||
// Epilogue.
|
||||
typedef typename Traits::Epilogue Epilogue;
|
||||
Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.m, params.n);
|
||||
epilogue.epilogue(cutlass::make_Coord(0, block.y, block.x), accumulators);
|
||||
}
|
||||
|
||||
/// The params.
|
||||
Params const& params;
|
||||
/// The shared storage.
|
||||
SharedStorage& shared_storage;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
231
cutlass/gemm/gemm_epilogue.h
Normal file
231
cutlass/gemm/gemm_epilogue.h
Normal file
@ -0,0 +1,231 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements the epilogue phase of the GEMM kernel that efficiently updates global memory
|
||||
with
|
||||
the computed matrix product.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/fragment.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
CUTLASS_DEVICE bool is_zero(T x) {
|
||||
return x == T(0);
|
||||
}
|
||||
|
||||
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
|
||||
CUTLASS_DEVICE bool is_zero(half x) { return reinterpret_cast<int16_t&>(x) == int16_t(0); }
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmEpilogueTraits_>
|
||||
struct GemmEpilogue {
|
||||
/// The traits class.
|
||||
typedef GemmEpilogueTraits_ Traits;
|
||||
/// The params.
|
||||
typedef typename Traits::Params Params;
|
||||
/// The shared storage.
|
||||
typedef typename Traits::SharedStorage SharedStorage;
|
||||
|
||||
/// The output tile.
|
||||
typedef typename Traits::OutputTile OutputTile;
|
||||
/// The number of iterations.
|
||||
typedef typename Traits::Iterations Iterations;
|
||||
/// The accumulators.
|
||||
typedef typename Traits::Accumulators Accumulators;
|
||||
/// The scalar.
|
||||
typedef typename Traits::Scalar Scalar;
|
||||
/// The functor in charge of the math.
|
||||
typedef typename Traits::Functor Functor;
|
||||
|
||||
/// We do not support 3D or 4D shapes.
|
||||
static_assert(Iterations::kD == 1 && Iterations::kC == 1, "Unsupported 3D/4D shapes");
|
||||
|
||||
/// The iterator for C in global memory.
|
||||
typedef typename Traits::GlobalLoadIteratorC GlobalLoadIteratorC;
|
||||
/// The transformer for C.
|
||||
typedef typename Traits::GlobalTransformerC GlobalTransformerC;
|
||||
/// The transformer for D.
|
||||
typedef typename Traits::GlobalTransformerD GlobalTransformerD;
|
||||
/// The iterator for D in global memory.
|
||||
typedef typename Traits::GlobalStoreIteratorD GlobalStoreIteratorD;
|
||||
/// The iterator to store D in shared memory.
|
||||
typedef typename Traits::SharedStoreIteratorD SharedStoreIteratorD;
|
||||
/// The shared store transformer for D.
|
||||
typedef typename Traits::SharedStoreTransformerD SharedStoreTransformerD;
|
||||
/// The iterator to load D in shared memory.
|
||||
typedef typename Traits::SharedLoadIteratorD SharedLoadIteratorD;
|
||||
/// The shared load transformer for D.
|
||||
typedef Copy<typename SharedLoadIteratorD::Fragment> SharedLoadTransformerD;
|
||||
|
||||
/// The index.
|
||||
typedef typename Traits::Index Index;
|
||||
|
||||
/// The scalar for C.
|
||||
typedef typename GlobalLoadIteratorC::Scalar ScalarC;
|
||||
/// The scalar for D.
|
||||
typedef typename GlobalStoreIteratorD::Scalar ScalarD;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GemmEpilogue(Params const& params_,
|
||||
SharedStorage& shared_storage_,
|
||||
Index m_,
|
||||
Index n_)
|
||||
: params(params_), shared_storage(shared_storage_), m(m_), n(n_) {}
|
||||
|
||||
/// Execute the epilogue.
|
||||
CUTLASS_DEVICE void epilogue(Coord<3> const& block, Accumulators& accumulators) {
|
||||
if (is_zero(params.functor.beta)) {
|
||||
epilogue_with_or_without_beta<true>(block, accumulators);
|
||||
} else {
|
||||
epilogue_with_or_without_beta<false>(block, accumulators);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool kBetaIsZero_>
|
||||
CUTLASS_DEVICE void epilogue_with_or_without_beta(Coord<3> const& block,
|
||||
Accumulators& accumulators) {
|
||||
|
||||
// The problem size.
|
||||
Coord<3> const bounds = cutlass::make_Coord(0, n, m);
|
||||
|
||||
// The functor.
|
||||
Functor functor(params.functor);
|
||||
// The C fragment.
|
||||
typename GlobalLoadIteratorC::Fragment fragment_c;
|
||||
// The transformed C fragment.
|
||||
typename GlobalTransformerC::OutputFragment transformed_c;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
// Compute pointer and predicate offsets for C and D global iterators.
|
||||
int const pointer_offset =
|
||||
((params.iterator_d.inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
|
||||
params.iterator_d.inc_advance) *
|
||||
Iterations::kW +
|
||||
params.stride_h) *
|
||||
h;
|
||||
int const predicate_offset =
|
||||
((params.iterator_d.predicate_inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
|
||||
params.iterator_d.predicate_inc_advance) *
|
||||
Iterations::kW +
|
||||
Traits::Delta::kH) *
|
||||
h;
|
||||
|
||||
// The iterator to load the elements of the C matrix.
|
||||
GlobalLoadIteratorC global_load_iterator(
|
||||
params.iterator_c, bounds, block, pointer_offset, predicate_offset);
|
||||
// The transformer for C.
|
||||
GlobalTransformerC transformer_c;
|
||||
// The transformer for D.
|
||||
GlobalTransformerD transformer_d;
|
||||
// The iterator to store into the D matrix.
|
||||
GlobalStoreIteratorD global_store_iterator(
|
||||
params.iterator_d, bounds, block, pointer_offset, predicate_offset);
|
||||
|
||||
// The transformer to transform before storing to shared memory.
|
||||
SharedStoreTransformerD shared_store_transformer;
|
||||
typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d;
|
||||
|
||||
// The iterator to store to shared memory.
|
||||
SharedStoreIteratorD shared_store_iterator(params.shared_store_iterator_d,
|
||||
shared_storage.shared_stream.store);
|
||||
|
||||
// The iterator to load from shared memory. TODO: Use a stream.
|
||||
SharedLoadIteratorD shared_load_iterator(params.shared_load_iterator_d,
|
||||
shared_storage.shared_stream.load);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
// Load the C matrix into fragment.
|
||||
if (!kBetaIsZero_) {
|
||||
iterator_load(global_load_iterator, fragment_c);
|
||||
}
|
||||
|
||||
// Make sure we can write to shared memory.
|
||||
shared_load_fence();
|
||||
|
||||
// Copy the accumulators to shared memory.
|
||||
int const offset = (h * Iterations::kW + w) * SharedStoreIteratorD::Fragment::kElements;
|
||||
|
||||
shared_store_transformer.transform(accumulators, offset, shared_store_transformed_d);
|
||||
shared_iterator_store(shared_store_iterator, shared_store_transformed_d);
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
shared_store_fence();
|
||||
|
||||
// Copy the accumulators back to registers from shared memory.
|
||||
typename SharedLoadIteratorD::Fragment fetched_d;
|
||||
shared_iterator_load(shared_load_iterator, fetched_d);
|
||||
|
||||
// Do the math.
|
||||
typename GlobalTransformerD::InputFragment fragment_d;
|
||||
|
||||
if (kBetaIsZero_) {
|
||||
functor.evaluate(fetched_d, fragment_d);
|
||||
} else {
|
||||
// Transform C fragment.
|
||||
transformer_c.transform(fragment_c, transformed_c);
|
||||
// Do the math.
|
||||
functor.evaluate(fetched_d, transformed_c, fragment_d);
|
||||
}
|
||||
|
||||
// Transform D fragment.
|
||||
typename GlobalTransformerD::OutputFragment transformed_d;
|
||||
transformer_d.transform(fragment_d, transformed_d);
|
||||
|
||||
// Copy the results to global memory.
|
||||
iterator_store(global_store_iterator, transformed_d);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The memory fence for shared loads.
|
||||
CUTLASS_DEVICE void shared_load_fence() { __syncthreads(); }
|
||||
|
||||
/// The memory fence for shared stores.
|
||||
CUTLASS_DEVICE void shared_store_fence() { __syncthreads(); }
|
||||
|
||||
/// The params.
|
||||
Params const& params;
|
||||
/// The shared storage.
|
||||
SharedStorage& shared_storage;
|
||||
/// The dimensions of the GEMM.
|
||||
Index m, n;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
331
cutlass/gemm/gemm_epilogue_traits.h
Normal file
331
cutlass/gemm/gemm_epilogue_traits.h
Normal file
@ -0,0 +1,331 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural properties of the GEMM epilogue.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/gemm/gemm_global_stream.h>
|
||||
#include <cutlass/gemm/gemm_shared_stream.h>
|
||||
#include <cutlass/gemm/linear_scaling.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include <cutlass/tile_iterator.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The output tile.
|
||||
typename OutputTile_,
|
||||
/// The accumulators.
|
||||
typename Accumulators_,
|
||||
/// The iterator to load C from global memory.
|
||||
typename GlobalLoadIteratorC_,
|
||||
/// The transformer for C.
|
||||
typename GlobalTransformerC_,
|
||||
/// The transformer for D.
|
||||
typename GlobalTransformerD_,
|
||||
/// The iterator to store D to global memory.
|
||||
typename GlobalStoreIteratorD_,
|
||||
/// The iterator to store D to shared memory.
|
||||
typename SharedStoreIteratorD_,
|
||||
/// The shared store transformer for D.
|
||||
typename SharedStoreTransformerD_,
|
||||
/// The iterator to load D from shared memory.
|
||||
typename SharedLoadIteratorD_,
|
||||
/// The number of iterations in the epilogue.
|
||||
typename Iterations_,
|
||||
/// The iterations strides.
|
||||
typename Delta_,
|
||||
/// The functor to be used in the epilogue.
|
||||
typename Functor_,
|
||||
/// The index.
|
||||
typename Index_ = int>
|
||||
struct GemmEpilogueTraits {
|
||||
//
|
||||
/// The output tile.
|
||||
typedef OutputTile_ OutputTile;
|
||||
/// The number of iterations.
|
||||
/// The accumulators.
|
||||
typedef Accumulators_ Accumulators;
|
||||
/// The iterator for C in global memory.
|
||||
typedef GlobalLoadIteratorC_ GlobalLoadIteratorC;
|
||||
/// The transformer for C.
|
||||
typedef GlobalTransformerC_ GlobalTransformerC;
|
||||
/// The transformer for D.
|
||||
typedef GlobalTransformerD_ GlobalTransformerD;
|
||||
/// The iterator for D in global memory.
|
||||
typedef GlobalStoreIteratorD_ GlobalStoreIteratorD;
|
||||
/// The iterator to store D in shared memory.
|
||||
typedef SharedStoreIteratorD_ SharedStoreIteratorD;
|
||||
/// The shared store transformer for D.
|
||||
typedef SharedStoreTransformerD_ SharedStoreTransformerD;
|
||||
/// The iterator to store D in shared memory.
|
||||
typedef SharedLoadIteratorD_ SharedLoadIteratorD;
|
||||
/// typedef typename GemmConfig::EpilogueIterations Iterations;
|
||||
typedef Iterations_ Iterations;
|
||||
/// The iterations strides.
|
||||
typedef Delta_ Delta;
|
||||
|
||||
/// The functor in charge of the math.
|
||||
typedef Functor_ Functor;
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
|
||||
/// We do not support 3D or 4D shapes.
|
||||
static_assert(Iterations::kD == 1 && Iterations::kC == 1, "Unsupported 3D/4D shapes");
|
||||
|
||||
/// The scalar.
|
||||
typedef typename Functor::Scalar Scalar;
|
||||
/// The scalar for C.
|
||||
typedef typename GlobalLoadIteratorC::Scalar ScalarC;
|
||||
/// The scalar for D.
|
||||
typedef typename GlobalStoreIteratorD::Scalar ScalarD;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
/// The strides for H and W in the different iterations of the epilogue.
|
||||
Index stride_h, stride_w;
|
||||
/// The params for the C iterator.
|
||||
typename GlobalLoadIteratorC::Params iterator_c;
|
||||
/// The params for the D global iterator.
|
||||
typename GlobalStoreIteratorD::Params iterator_d;
|
||||
/// The params for the D shared store iterator.
|
||||
typename SharedStoreIteratorD::Params shared_store_iterator_d;
|
||||
/// The params for the D shared load iterator.
|
||||
typename SharedLoadIteratorD::Params shared_load_iterator_d;
|
||||
/// The functor params.
|
||||
typename Functor::Params functor;
|
||||
|
||||
/// Setup the params.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
|
||||
// The parameters for the functor.
|
||||
int error_code = functor.initialize(desc);
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// At the end of the H iteration, we jump over a number of columns.
|
||||
this->stride_h = desc.ldd * Delta::kH;
|
||||
// Nothing to do here.
|
||||
this->stride_w = 0;
|
||||
|
||||
// Setup the params for the global memory iterator for C.
|
||||
error_code = iterator_c.initialize(
|
||||
reinterpret_cast<ScalarC const*>(desc.d_c), desc.ldc, desc.n, stride_w, Delta::kW);
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// Setup the params for the global memory iterator for D.
|
||||
return iterator_d.initialize(
|
||||
reinterpret_cast<ScalarD*>(desc.d_d), desc.ldd, desc.n, stride_w, Delta::kW);
|
||||
}
|
||||
};
|
||||
|
||||
/// The shared memory storage to exchange data.
|
||||
union StreamSharedStorage {
|
||||
// The storage for the store iterator.
|
||||
typename SharedStoreIteratorD::SharedStorage store;
|
||||
// The storage for the store iterator.
|
||||
typename SharedLoadIteratorD::SharedStorage load;
|
||||
};
|
||||
|
||||
/// The shared memory to swizzle the data in the epilogue.
|
||||
struct SharedStorage {
|
||||
// The storage for the shared stream D.
|
||||
StreamSharedStorage shared_stream;
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_, typename EpilogueFunctor_, typename Index_ = int>
|
||||
struct GemmEpilogueTraitsHelper {
|
||||
/// The scalar.
|
||||
typedef typename EpilogueFunctor_::Scalar Scalar;
|
||||
/// The output tile.
|
||||
typedef typename GemmConfig_::OutputTile OutputTile;
|
||||
|
||||
/// The number of iterations in the epilogue.
|
||||
typedef Shape<1,
|
||||
GemmConfig_::MultiplyAdd::AccumulatorsPerThread::kH /
|
||||
GemmConfig_::kAccumulatorsPerLdsB,
|
||||
GemmConfig_::kAccumulatorsPerLdsB>
|
||||
Iterations;
|
||||
// The iteration strides in the H/W dimension.
|
||||
typedef Shape<0,
|
||||
GemmConfig_::kAccumulatorsPerLdsB*(
|
||||
GemmConfig_::Warps::kH* GemmConfig_::MultiplyAdd::ThreadsPerWarp::kH - 1),
|
||||
0>
|
||||
Delta;
|
||||
/// The functor to do the math in the epilogue.
|
||||
typedef EpilogueFunctor_ Functor;
|
||||
|
||||
/// The traits class to build the iterator to store to shared memory for D.
|
||||
typedef GemmSharedStoreTileDTraits<
|
||||
// The pointer is float.
|
||||
typename Functor::Scalar,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The number of scalars per STS.
|
||||
GemmConfig_::kScalarsPerStsD,
|
||||
// The skew -- 128 / sizeof(ScalarD) / kScalarsPerStsD is the number of threads involved in
|
||||
// a single STS. We divide by 2 as our objective is to add a skew to the odd threads to
|
||||
// avoid bank conflicts between odd and even threads.
|
||||
128 / sizeof(typename GemmConfig_::ScalarD) / GemmConfig_::kScalarsPerStsD / 2 *
|
||||
GemmConfig_::kScalarsPerStsD>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The iterator to store D to shared memory.
|
||||
typedef TileStoreIterator<SharedStoreTileTraits,
|
||||
typename SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorD;
|
||||
|
||||
/// The shared store transformer for D.
|
||||
typedef Copy<typename SharedStoreIteratorD::Fragment> SharedStoreTransformerD;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for D.
|
||||
typedef GemmSharedLoadTileDTraits<
|
||||
// The pointer is float.
|
||||
typename Functor::Scalar,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The number of columns of the output tile written by iteration.
|
||||
GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
|
||||
// The number of scalars per LDS.
|
||||
GemmConfig_::kScalarsPerLdsD,
|
||||
// The skew.
|
||||
SharedStoreTileTraits::kSkew>
|
||||
SharedLoadTileTraits;
|
||||
|
||||
/// The iterator to load D from shared memory.
|
||||
typedef TileLoadIterator<SharedLoadTileTraits,
|
||||
typename SharedLoadTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorD;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for C^N.
|
||||
typedef GemmGlobalTileCdTraits<
|
||||
// The pointer is float const.
|
||||
typename GemmConfig_::ScalarC const,
|
||||
// The tile has size (N / Iterations)xM in GEMM's terminology.
|
||||
Shape<1,
|
||||
GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
|
||||
GemmConfig_::OutputTile::kW>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// How many elements do we jump over at each iteration?
|
||||
Iterations::kW,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgC>
|
||||
GlobalLoadTileTraits;
|
||||
|
||||
/// The iterator to load C.
|
||||
typedef GemmGlobalIteratorCd<GlobalLoadTileTraits, Index_> GlobalLoadIteratorC;
|
||||
/// The transformer for C.
|
||||
typedef Copy<typename GlobalLoadIteratorC::Fragment> GlobalTransformerC;
|
||||
|
||||
/// The traits class to build the iterator to store data to global memory for D^N.
|
||||
typedef GemmGlobalTileCdTraits<
|
||||
// The pointer is float.
|
||||
typename GemmConfig_::ScalarD,
|
||||
// The tile has size (N / Iterations)xM in GEMM's terminology.
|
||||
Shape<1,
|
||||
GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
|
||||
GemmConfig_::OutputTile::kW>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// How many elements do we jump over at each iteration?
|
||||
Iterations::kW,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerStgD>
|
||||
GlobalStoreTileTraits;
|
||||
|
||||
/// The iterator to store D.
|
||||
typedef GemmGlobalIteratorCd<GlobalStoreTileTraits, Index_> GlobalStoreIteratorD;
|
||||
/// The transformer for D.
|
||||
typedef Copy<typename GlobalStoreIteratorD::Fragment> GlobalTransformerD;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The GEMM config.
|
||||
typename GemmConfig_,
|
||||
/// The epilogue functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The helper to create the traits class.
|
||||
typename Helper_ = GemmEpilogueTraitsHelper<GemmConfig_, EpilogueFunctor_, Index_> >
|
||||
struct SimplifiedGemmEpilogueTraits : public GemmEpilogueTraits<
|
||||
// The output tile.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The accumulators.
|
||||
typename GemmConfig_::Accumulators,
|
||||
// The global iterator for C.
|
||||
typename Helper_::GlobalLoadIteratorC,
|
||||
// The transformer for C.
|
||||
typename Helper_::GlobalTransformerC,
|
||||
// The transformer for D.
|
||||
typename Helper_::GlobalTransformerD,
|
||||
// The global iterator for D.
|
||||
typename Helper_::GlobalStoreIteratorD,
|
||||
// The iterator to store D to shared memory.
|
||||
typename Helper_::SharedStoreIteratorD,
|
||||
// The shared store transformer for D.
|
||||
typename Helper_::SharedStoreTransformerD,
|
||||
// The iterator to load D from shared memory.
|
||||
typename Helper_::SharedLoadIteratorD,
|
||||
// The number of iterations.
|
||||
typename Helper_::Iterations,
|
||||
// The strides between iterations.
|
||||
typename Helper_::Delta,
|
||||
// The functor to be used in the epilogue.
|
||||
EpilogueFunctor_,
|
||||
// The index.
|
||||
Index_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
182
cutlass/gemm/gemm_global_stream.h
Normal file
182
cutlass/gemm/gemm_global_stream.h
Normal file
@ -0,0 +1,182 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements efficient loading of the thread block-level tile from global memory and
|
||||
storing
|
||||
to shared memory.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/iterator_access.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The load iterator.
|
||||
typename LoadIterator_,
|
||||
/// The store iterator to copy to shared memory.
|
||||
typename StoreIterator_,
|
||||
/// The transformer to be applied after the data has been copied from global memory.
|
||||
typename Transformer_>
|
||||
|
||||
struct GlobalLoadStreamBase {
|
||||
/// The load iterator.
|
||||
typedef LoadIterator_ LoadIterator;
|
||||
/// The transformer.
|
||||
typedef Transformer_ Transformer;
|
||||
/// The store iterator to write to shared memory.
|
||||
typedef StoreIterator_ StoreIterator;
|
||||
|
||||
/// The fragment that is copied from shared memory.
|
||||
typedef typename LoadIterator::Fragment FetchedFragment;
|
||||
/// The fragment that is obtained after the transformation by the transformer.
|
||||
typedef typename Transformer::OutputFragment TransformedFragment;
|
||||
/// Make sure the fragments match.
|
||||
static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
|
||||
"");
|
||||
/// The output fragment.
|
||||
typedef TransformedFragment Fragment;
|
||||
/// Make sure the transformed fragment is the same as the store fragment.
|
||||
static_assert((platform::is_same<TransformedFragment, typename StoreIterator::Fragment>::value),
|
||||
"");
|
||||
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = LoadIterator::kLayout;
|
||||
/// The scalar type of the iterator.
|
||||
typedef typename LoadIterator::Scalar Scalar;
|
||||
/// The pointer.
|
||||
typedef typename LoadIterator::Pointer Pointer;
|
||||
/// The index.
|
||||
typedef typename LoadIterator::Index Index;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
// The load iterator.
|
||||
typename LoadIterator::Params load_iterator;
|
||||
// The store iterator.
|
||||
typename StoreIterator::Params store_iterator;
|
||||
|
||||
/// Setup the params.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc, Pointer pointer, Index ld) {
|
||||
int error_code = load_iterator.initialize(desc, pointer, ld);
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
return store_iterator.initialize();
|
||||
}
|
||||
};
|
||||
|
||||
/// The amount of storage in shared memory needed to store the tile.
|
||||
typedef typename StoreIterator::SharedStorage SharedStoreStorage;
|
||||
|
||||
/// The storage in shared memory needed by that stream.
|
||||
union SharedStorage {
|
||||
// The load iterator.
|
||||
typename LoadIterator::SharedStorage load_iterator;
|
||||
// The store iterator.
|
||||
SharedStoreStorage store_iterator;
|
||||
};
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GlobalLoadStreamBase(Params const& params,
|
||||
SharedStorage& shared_storage,
|
||||
Coord<3> const bounds,
|
||||
Coord<3> const& block)
|
||||
: load_iterator(params.load_iterator, bounds, block),
|
||||
transformer(),
|
||||
store_iterator(params.store_iterator, shared_storage.store_iterator)
|
||||
|
||||
{
|
||||
fetched_fragment.clear();
|
||||
}
|
||||
|
||||
/// Load the data from shared memory to the fetch fragment.
|
||||
CUTLASS_DEVICE void copy() { iterator_load(load_iterator, fetched_fragment); }
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit() {
|
||||
transformer.transform(fetched_fragment, transformed_fragment);
|
||||
iterator_store(store_iterator, transformed_fragment);
|
||||
store_iterator.inc_stage();
|
||||
}
|
||||
|
||||
/// Move to the beginning of the residue code. That's a new code path in CUTLASS 1.0.1.
|
||||
CUTLASS_DEVICE void move_to_residue(Index k) { load_iterator.move_to_residue(k); }
|
||||
|
||||
/// Execute the residue code.
|
||||
CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
|
||||
load_iterator.residue(k);
|
||||
if (!skip_clear) {
|
||||
fetched_fragment.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to the beginning of the GEMM-k dimension.
|
||||
CUTLASS_DEVICE void rollback() { load_iterator.rollback(); }
|
||||
|
||||
/// The iterator.
|
||||
LoadIterator load_iterator;
|
||||
/// The fragment to fetch from shared memory.
|
||||
FetchedFragment fetched_fragment;
|
||||
/// The transformer.
|
||||
Transformer transformer;
|
||||
/// The fragment to convert the data after it has been fetched from shared memory.
|
||||
TransformedFragment transformed_fragment;
|
||||
/// The store iterator.
|
||||
StoreIterator store_iterator;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The load iterator.
|
||||
typename LoadIterator_,
|
||||
/// The store iterator to copy to shared memory.
|
||||
typename StoreIterator_,
|
||||
/// The transformer to be applied after the data has been copied from global memory.
|
||||
typename Transformer_ = Copy<typename LoadIterator_::Fragment> >
|
||||
|
||||
struct GlobalLoadStream : public GlobalLoadStreamBase<LoadIterator_, StoreIterator_, Transformer_> {
|
||||
/// The base class.
|
||||
typedef GlobalLoadStreamBase<LoadIterator_, StoreIterator_, Transformer_> Base;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GlobalLoadStream(typename Base::Params const& params,
|
||||
typename Base::SharedStorage& shared_storage,
|
||||
Coord<3> const& bounds,
|
||||
Coord<3> const& block)
|
||||
: Base(params, shared_storage, bounds, block) {}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
541
cutlass/gemm/gemm_global_tile.h
Normal file
541
cutlass/gemm/gemm_global_tile.h
Normal file
@ -0,0 +1,541 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines iterators for efficiently loading and storing to global memory.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/util/platform.h>
|
||||
|
||||
#include <cutlass/gemm/gemm_operand.h>
|
||||
#include <cutlass/matrix_traits.h>
|
||||
#include <cutlass/predicate_vector.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include <cutlass/tile_iterator.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// The following functor reshapes a tile of threads to match a tile of data. The idea is that when
|
||||
// the user wants to build the iterator traits, he/she may want to specify the tile independently
|
||||
// from the number of scalars loaded/stored per instruction. For example, in the row-major version
|
||||
// with a tile of size 128x8 - the user may want to that the iterator works with 32x8 threads if
|
||||
// each thread loads 1 scalar per LDG. If the user changes to 4 scalars per LDG, then the tile of
|
||||
// threads has to change. The code below detects that and correct the code automatically - it is
|
||||
// a helper when the user does not specify the right configuration.
|
||||
|
||||
template <typename Tile_, typename Threads_, bool = (Tile_::kW < Threads_::kW)>
|
||||
struct ReshapeThreads {
|
||||
typedef Threads_ Threads;
|
||||
};
|
||||
|
||||
template <typename Tile_, typename Threads_>
|
||||
struct ReshapeThreads<Tile_, Threads_, true> {
|
||||
typedef Shape<Threads_::kD, Threads_::kH * Threads_::kW / Tile_::kW, Tile_::kW, 1> Threads;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GemmOperand::Kind kOperand_,
|
||||
MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename Tile_,
|
||||
typename Threads_,
|
||||
int kAccessSize_>
|
||||
struct GemmGlobalTileTraits {
|
||||
/// Identity of the operand
|
||||
static GemmOperand::Kind const kOperand = kOperand_;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = kLayout_;
|
||||
/// The scalar.
|
||||
typedef typename platform::remove_const<Scalar_>::type Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar_* Pointer;
|
||||
/// The number of scalars per LDG/STG.
|
||||
static int const kAccessSize = kAccessSize_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGlobal;
|
||||
|
||||
/// The tile shape
|
||||
typedef typename ReshapeTile<Tile_, kAccessSize_>::Tile Tile;
|
||||
/// The threads shape
|
||||
typedef typename ReshapeThreads<Tile, Threads_>::Threads Threads;
|
||||
/// The relative offset between two elements in the H/W dimension in adjacent threads.
|
||||
typedef Shape<1, 1, Tile::kC> ThreadsDelta;
|
||||
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, Threads::kH, Threads::kW * kAccessSize> Delta;
|
||||
/// Strides for immediate offset computation
|
||||
typedef Shape<0, 0, Threads::kW * ThreadsDelta::kW, kAccessSize> ImmediateOffsetStrides;
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<1, Tile::kH / Threads::kH, Tile::kW / Threads::kW, Tile::kC / kAccessSize>
|
||||
Iterations;
|
||||
|
||||
typedef GemmMultiplicandTraits<Tile, kOperand, kLayout> MultiplicandTraits;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
|
||||
int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
|
||||
|
||||
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, typename Tile_, typename Threads_, int kStrideH_, int kAccessSize_>
|
||||
struct GemmGlobalTileCdTraits : public GemmGlobalTileTraits<GemmOperand::kC,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Scalar_,
|
||||
Tile_,
|
||||
Threads_,
|
||||
kAccessSize_> {
|
||||
/// The base class.
|
||||
typedef GemmGlobalTileTraits<GemmOperand::kC,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Scalar_,
|
||||
Tile_,
|
||||
Threads_,
|
||||
kAccessSize_>
|
||||
Base;
|
||||
|
||||
/// The stride in the H dimension.
|
||||
static int const kStrideH = kStrideH_;
|
||||
/// Override the strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Base::Delta::kW, Base::Delta::kC> Delta;
|
||||
|
||||
typedef typename Base::Iterations Iterations;
|
||||
|
||||
typedef typename Base::Threads Threads;
|
||||
|
||||
typedef typename Base::ThreadsDelta ThreadsDelta;
|
||||
|
||||
typedef typename Base::ImmediateOffsetStrides ImmediateOffsetStrides;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int thread_offset_h = threadIdx.x / Threads::kW * kStrideH * Iterations::kH;
|
||||
int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
|
||||
|
||||
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TileTraits_, typename Index_ = int>
|
||||
struct GemmGlobalIteratorAb
|
||||
: public TileLoadIterator<TileTraits_,
|
||||
typename TileTraits_::Scalar,
|
||||
TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
|
||||
: IteratorAdvance::kW,
|
||||
MemorySpace::kGlobal,
|
||||
Index_> {
|
||||
/// This class.
|
||||
typedef GemmGlobalIteratorAb<TileTraits_, Index_> This_; /// The base class.
|
||||
|
||||
typedef TileLoadIterator<TileTraits_,
|
||||
typename TileTraits_::Scalar,
|
||||
TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
|
||||
: IteratorAdvance::kW,
|
||||
MemorySpace::kGlobal,
|
||||
Index_>
|
||||
Base;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
|
||||
/// Fragment type loaded by the iterator
|
||||
typedef typename Base::Fragment Fragment;
|
||||
/// The scalar.
|
||||
typedef typename TileTraits_::Scalar Scalar;
|
||||
/// The threads.
|
||||
typedef typename TileTraits_::Threads Threads;
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
/// The thread offset
|
||||
typedef typename TileTraits_::ThreadOffset ThreadOffset;
|
||||
/// Specifies in which dimension post-increment accesses advance.
|
||||
static IteratorAdvance::Kind const kAdvance = Base::kAdvance;
|
||||
|
||||
typedef cutlass::PredicateVector<ShapeCount<typename Base::Iterations>::kCount> PredicateVector;
|
||||
|
||||
/// Iterator parameters type
|
||||
typedef typename Base::Params BaseParams;
|
||||
|
||||
struct Params : public BaseParams {
|
||||
/// Initializes params to load a strip-mined tile, given pointer and stride_h.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc, Scalar const* ptr, Index stride_h) {
|
||||
Index inc_d = 0;
|
||||
Index inc_advance = 0;
|
||||
// Move by some columns for each iteration in the H dimension.
|
||||
Index inc_h = Base::Delta::kH * stride_h;
|
||||
|
||||
// Move by some more columns in the number of iterations if the D dimension is > 1.
|
||||
if (Base::Delta::kD > 0) {
|
||||
inc_d = Base::Delta::kD * stride_h - (Base::Iterations::kH - 1) * inc_h;
|
||||
}
|
||||
|
||||
// Move to the beginning of the next iteration.
|
||||
if (kAdvance == IteratorAdvance::kH && Base::Delta::kD > 0) {
|
||||
inc_advance = inc_d;
|
||||
} else if (kAdvance == IteratorAdvance::kH) {
|
||||
inc_advance = inc_h;
|
||||
} else if (Base::Delta::kD > 0) {
|
||||
inc_advance = (Base::Iterations::kW + 0) * ShapeCount<typename Base::Delta>::kWc -
|
||||
(Base::Iterations::kH - 1) * inc_h -
|
||||
(Base::Iterations::kD - 1) * Base::Delta::kD * stride_h;
|
||||
} else {
|
||||
inc_advance = (Base::Iterations::kW + 0) * ShapeCount<typename Base::Delta>::kWc -
|
||||
(Base::Iterations::kH - 1) * inc_h;
|
||||
}
|
||||
|
||||
// The dimensions of the tile.
|
||||
int const kH = TileTraits_::Tile::kH;
|
||||
int const kW = TileTraits_::Tile::kW * TileTraits_::kAccessSize;
|
||||
|
||||
// Move to the residue.
|
||||
Index const kBlock = kAdvance == IteratorAdvance::kH ? kH : kW;
|
||||
// The jump in the gemm-k dimension.
|
||||
Index const stride = kAdvance == IteratorAdvance::kH ? stride_h : 1;
|
||||
|
||||
// Compute the offset to the residue and how to "come" back.
|
||||
Index const kResidue = desc.k % kBlock;
|
||||
if (kResidue > 0) {
|
||||
move_to_residue_offset = (desc.k - kResidue) * stride;
|
||||
} else {
|
||||
move_to_residue_offset = (desc.k - kBlock) * stride;
|
||||
}
|
||||
|
||||
Base::Params::initialize(ptr, 0, stride_h, 1, inc_d, inc_h, 0, inc_advance);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// The extra offset to control moving to the residue.
|
||||
Index move_to_residue_offset;
|
||||
};
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GemmGlobalIteratorAb(Params const& _params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(_params) {
|
||||
thread_offset = thread_offset_func();
|
||||
// The column.
|
||||
Index block_h = thread_offset[1];
|
||||
// The contiguous dimension.
|
||||
Index block_w = thread_offset[2];
|
||||
|
||||
// Add the blocks indices.
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
block_h += block[1];
|
||||
block_w += block[2];
|
||||
|
||||
} else {
|
||||
block_h += block[2];
|
||||
block_w += block[1];
|
||||
}
|
||||
|
||||
// Setup the pointer.
|
||||
params.pointer += (block_h * params.stride_h + block_w);
|
||||
|
||||
// Initialize predicates
|
||||
initialize_predicates(bounds, make_Coord(0, block_h, block_w));
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Load<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::load(value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Increment the pointer in the H dimension.
|
||||
CUTLASS_DEVICE void inc_h() { params.pointer += params.inc_h; }
|
||||
/// Increment the pointer in the D dimension.
|
||||
CUTLASS_DEVICE void inc_d() { params.pointer += params.inc_d; }
|
||||
/// Increment the pointer to move to the next iteration.
|
||||
CUTLASS_DEVICE void inc_advance() { params.pointer += params.inc_advance; }
|
||||
|
||||
/// Initialize the predicates.
|
||||
CUTLASS_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& block) {
|
||||
// Setup the masks to control loads.
|
||||
predicates.fill(0);
|
||||
|
||||
int bounds_h, bounds_w;
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
bounds_w = bounds[2] - block[2];
|
||||
bounds_h = bounds[1];
|
||||
|
||||
} else {
|
||||
bounds_w = bounds[1];
|
||||
bounds_h = bounds[2] - block[1];
|
||||
}
|
||||
|
||||
// Fill in the bits of the predicate vector.
|
||||
for (int d = 0; d < Base::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Base::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Base::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < Base::Iterations::kC; ++c) {
|
||||
bool flag = w * Base::Delta::kW < bounds_w;
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
flag = flag && (h * Base::Delta::kH + d * Base::Delta::kD) < bounds_h;
|
||||
} else {
|
||||
flag = flag && (h * Base::Delta::kH) < bounds_h;
|
||||
}
|
||||
int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
|
||||
predicates.set(bit, flag);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Move to residue portion.
|
||||
CUTLASS_DEVICE void move_to_residue(Index k) {
|
||||
// Store the pointer and the predicates.
|
||||
stored_pointer = params.pointer;
|
||||
stored_predicates = predicates;
|
||||
|
||||
// Move the pointer to the residue.
|
||||
params.pointer += params.move_to_residue_offset;
|
||||
|
||||
// The dimensions of the tile.
|
||||
int const kH = TileTraits_::Tile::kH;
|
||||
int const kW = TileTraits_::Tile::kW * TileTraits_::kAccessSize;
|
||||
|
||||
// The unrolling factor.
|
||||
int const kUnroll = kAdvance == IteratorAdvance::kH ? kH : kW;
|
||||
|
||||
// Clear the predicates for the residue. TODO: We can do something smarter.
|
||||
int const kResidue = (int)(k % (Index)kUnroll);
|
||||
if (kResidue > 0) {
|
||||
residue(kResidue);
|
||||
}
|
||||
}
|
||||
|
||||
/// That's the residue! Update the predicates.
|
||||
CUTLASS_DEVICE void residue(Index k) {
|
||||
// The coordinates of the thread.
|
||||
Index block_h = thread_offset[1];
|
||||
// The contiguous dimension.
|
||||
Index block_w = thread_offset[2];
|
||||
|
||||
// Update the predicate vector.
|
||||
for (int d = 0; d < Base::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Base::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Base::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < Base::Iterations::kC; ++c) {
|
||||
Index offset = 0;
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
offset += block_h + h * Base::Delta::kH + d * Base::Delta::kD;
|
||||
} else {
|
||||
offset += block_w + w * Base::Delta::kW;
|
||||
}
|
||||
|
||||
int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
|
||||
if (offset >= k) {
|
||||
predicates.set(bit, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile and initialize predicates.
|
||||
CUTLASS_DEVICE void rollback() {
|
||||
params.pointer = stored_pointer;
|
||||
predicates = stored_predicates;
|
||||
}
|
||||
|
||||
/// Is the iterator valid?
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
|
||||
int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
|
||||
return predicates[bit];
|
||||
}
|
||||
|
||||
/// Offset of an individual lane from the start of the tile
|
||||
Coord<4> thread_offset;
|
||||
/// The parameters
|
||||
Params params;
|
||||
/// The pointer.
|
||||
typename Base::Scalar const* stored_pointer;
|
||||
/// The predicates.
|
||||
PredicateVector predicates, stored_predicates;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TileTraits_, typename Index_ = int>
|
||||
struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
typename TileTraits_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kGlobal,
|
||||
Index_> {
|
||||
/// This class.
|
||||
typedef GemmGlobalIteratorCd<TileTraits_, Index_> This_;
|
||||
/// The base class.
|
||||
typedef TileIteratorBase<TileTraits_,
|
||||
typename TileTraits_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kGlobal,
|
||||
Index_>
|
||||
Base;
|
||||
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
|
||||
|
||||
/// The scalar.
|
||||
typedef typename TileTraits_::Scalar Scalar;
|
||||
/// The pointer.
|
||||
typedef typename TileTraits_::Pointer Pointer;
|
||||
/// The threads.
|
||||
typedef typename TileTraits_::Threads Threads;
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
/// The thread offset
|
||||
typedef typename TileTraits_::ThreadOffset ThreadOffset;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
/// The pointer.
|
||||
Pointer pointer;
|
||||
/// The stride in the H dimension to setup the thread in the block.
|
||||
Index stride_h;
|
||||
/// The strides to increment the pointer.
|
||||
Index inc_advance, inc_h;
|
||||
/// The strides to increment the predicate offset
|
||||
Index predicate_inc_advance, predicate_inc_h;
|
||||
/// The column offset to compute the predicate for the columns.
|
||||
Index predicate_offset;
|
||||
|
||||
/// Setup the params.
|
||||
CUTLASS_HOST_DEVICE int initialize(
|
||||
Pointer pointer, Index ld, Index bound, Index epilogue_stride_w, Index epilogue_delta_w) {
|
||||
// The pointer.
|
||||
this->pointer = pointer;
|
||||
// Each column of the matrix.
|
||||
stride_h = TileTraits_::ThreadsDelta::kH * ld;
|
||||
// Each thread output 1 column per iteration. The stride between columns is given by the
|
||||
// number of scalars that are loaded per LDS for B.
|
||||
inc_h = ld * TileTraits_::kStrideH;
|
||||
inc_advance =
|
||||
(ld - ld * TileTraits_::kStrideH * (Base::Iterations::kH - 1)) + epilogue_stride_w;
|
||||
|
||||
predicate_offset = bound;
|
||||
predicate_inc_h = TileTraits_::kStrideH;
|
||||
predicate_inc_advance =
|
||||
-((TileTraits_::kStrideH * (Base::Iterations::kH - 1) - 1) + epilogue_delta_w);
|
||||
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
Params params;
|
||||
/// Offset of an individual lane from the start of the tile
|
||||
Coord<4> thread_offset;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GemmGlobalIteratorCd() {}
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GemmGlobalIteratorCd(Params const& params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block,
|
||||
int offset = 0,
|
||||
int pred_offset = 0,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(params) {
|
||||
thread_offset = thread_offset_func();
|
||||
// Each warp works on a different column of the tile.
|
||||
int const h = thread_offset[1] + block[1];
|
||||
// Each lane writes a different element.
|
||||
int const w = thread_offset[2] + block[2];
|
||||
// Setup the pointer.
|
||||
this->params.pointer += ((h * params.stride_h + w) + offset);
|
||||
|
||||
// Prepare the vector of predicates.
|
||||
for (int i = 0; i < Base::Iterations::kW; ++i) {
|
||||
predicates.set(i, w + i * Base::Delta::kW < bounds[2]);
|
||||
}
|
||||
this->params.predicate_offset -= (h + pred_offset);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Load<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::load(value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Increment the pointer in the C dimension.
|
||||
CUTLASS_DEVICE void inc_c() {}
|
||||
/// Increment the pointer in the W dimension.
|
||||
CUTLASS_DEVICE void inc_w() {}
|
||||
/// Increment the pointer in the H dimension.
|
||||
CUTLASS_DEVICE void inc_h() {
|
||||
params.pointer += params.inc_h;
|
||||
params.predicate_offset -= params.predicate_inc_h;
|
||||
}
|
||||
/// Increment the pointer in the D dimension.
|
||||
CUTLASS_DEVICE void inc_d() {}
|
||||
/// Increment the pointer to move to the next iteration.
|
||||
CUTLASS_DEVICE void inc_advance() {
|
||||
params.pointer += params.inc_advance;
|
||||
this->params.predicate_offset -= params.predicate_inc_advance;
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void set(typename Base::AccessType const& value, int d, int h, int w, int c) {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Store<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::store(
|
||||
value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Test the validity of the iterator.
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
|
||||
return predicates.at(w) && params.predicate_offset > 0;
|
||||
}
|
||||
|
||||
/// The predicates for the row.
|
||||
cutlass::PredicateVector<Base::Iterations::kW> predicates;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
141
cutlass/gemm/gemm_operand.h
Normal file
141
cutlass/gemm/gemm_operand.h
Normal file
@ -0,0 +1,141 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear
|
||||
memory.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/matrix_traits.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include <cutlass/util/platform.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to describe attributes of GEMM matrix operands
|
||||
template <GemmOperand::Kind kOperand_, MatrixLayout::Kind kLayout_>
|
||||
struct GemmOperandTraitsAb {
|
||||
static const bool Congruous =
|
||||
(kOperand_ == GemmOperand::kA ^ kLayout_ == MatrixLayout::kRowMajor);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmOperand::Kind kOperand_, typename Tile_>
|
||||
struct GetExtent;
|
||||
|
||||
template <typename Tile_>
|
||||
struct GetExtent<GemmOperand::kA, Tile_> {
|
||||
static const int kExtent = Tile_::kW;
|
||||
};
|
||||
|
||||
template <typename Tile_>
|
||||
struct GetExtent<GemmOperand::kB, Tile_> {
|
||||
static const int kExtent = Tile_::kH;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Determines the shape of a multiplicand tile in terms of strided (H) and contiguous (W)
|
||||
/// dimensions
|
||||
template <typename ThreadBlockTile_, GemmOperand::Kind Usage, MatrixLayout::Kind Layout>
|
||||
struct GemmMultiplicandTraits {
|
||||
// Only defined for A or B
|
||||
static_assert(Usage == GemmOperand::kA || Usage == GemmOperand::kB,
|
||||
"MultiplicandTileShape defined only for A or B operands.");
|
||||
|
||||
/// Shape of GEMM thread block tile (K, N, M)
|
||||
typedef ThreadBlockTile_ ThreadBlockTile;
|
||||
|
||||
/// Identifies multiplicand
|
||||
static GemmOperand::Kind const kUsage = Usage;
|
||||
|
||||
/// Layout of tile
|
||||
static MatrixLayout::Kind const kLayout = Layout;
|
||||
|
||||
// True if K is the strided dimension
|
||||
static bool const kKstrided = (kUsage == GemmOperand::kA ^ kLayout == MatrixLayout::kRowMajor);
|
||||
|
||||
/// Map the ThreadBlockShape onto (kH, kW) dimensions for A and B operand
|
||||
typedef typename platform::conditional<
|
||||
kKstrided,
|
||||
Shape<1, ThreadBlockTile::kD, GetExtent<Usage, ThreadBlockTile>::kExtent>,
|
||||
Shape<1, GetExtent<Usage, ThreadBlockTile>::kExtent, ThreadBlockTile::kD> >::type Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Project's a coordinate (K, N, M) onto inner and outer dimensions defined for each
|
||||
/// operand.
|
||||
template <GemmOperand::Kind operand, bool Kstrided = true>
|
||||
struct ProjectOperand;
|
||||
|
||||
/// Project A operand - (0, K, M)
|
||||
template <bool Kstrided>
|
||||
struct ProjectOperand<GemmOperand::kA, Kstrided> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<3> project(Coord<3> const &coord) {
|
||||
if (Kstrided) {
|
||||
return make_Coord(0, coord[0], coord[2]);
|
||||
} else {
|
||||
return make_Coord(0, coord[2], coord[0]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Project B operand - (0, K, N)
|
||||
template <bool Kstrided>
|
||||
struct ProjectOperand<GemmOperand::kB, Kstrided> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<3> project(Coord<3> const &coord) {
|
||||
if (Kstrided) {
|
||||
return make_Coord(0, coord[0], coord[1]);
|
||||
} else {
|
||||
return make_Coord(0, coord[1], coord[0]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Project C operand - (0, N, M)
|
||||
template <>
|
||||
struct ProjectOperand<GemmOperand::kC, true> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<3> project(Coord<3> const &coord) { return make_Coord(0, coord[1], coord[2]); }
|
||||
};
|
||||
|
||||
/// Project D operand - (0, N, M)
|
||||
template <>
|
||||
struct ProjectOperand<GemmOperand::kD, true> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<3> project(Coord<3> const &coord) { return make_Coord(0, coord[1], coord[2]); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
113
cutlass/gemm/gemm_shared_stream.h
Normal file
113
cutlass/gemm/gemm_shared_stream.h
Normal file
@ -0,0 +1,113 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines abstractions for managing loading and storing fragments to shared memory in the
|
||||
efficient GEMM pipeline.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/gemm/gemm_shared_tile.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The load iterator.
|
||||
typename Iterator_,
|
||||
/// The transformer to be applied after the data has been copied from shared memory.
|
||||
typename Transformer_ = Copy<typename Iterator_::Fragment> >
|
||||
|
||||
struct SharedLoadStream {
|
||||
/// The load iterator.
|
||||
typedef Iterator_ Iterator;
|
||||
/// The transformer.
|
||||
typedef Transformer_ Transformer;
|
||||
|
||||
/// The fragment that is copied from shared memory.
|
||||
typedef typename Iterator::Fragment FetchedFragment;
|
||||
/// The fragment that is obtained after the transformation by the transformer.
|
||||
typedef typename Transformer::OutputFragment TransformedFragment;
|
||||
/// Make sure the fragments match.
|
||||
static_assert((platform::is_same<FetchedFragment, typename Transformer::InputFragment>::value),
|
||||
"");
|
||||
/// The output fragment.
|
||||
typedef TransformedFragment Fragment;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
/// The iterator params.
|
||||
typename Iterator::Params iterator;
|
||||
|
||||
/// Setup the params.
|
||||
CUTLASS_HOST_DEVICE int initialize() { return iterator.initialize(); }
|
||||
};
|
||||
|
||||
/// The storage in shared memory needed by that stream.
|
||||
typedef typename Iterator::Storage SharedStorage;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE SharedLoadStream() {}
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE SharedLoadStream(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
this->initialize(params, shared_storage);
|
||||
}
|
||||
|
||||
/// Initialize the stream.
|
||||
CUTLASS_DEVICE void initialize(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
// The iterator.
|
||||
iterator = Iterator(params.iterator, shared_storage);
|
||||
// The transformer.
|
||||
transformer = Transformer();
|
||||
}
|
||||
|
||||
/// Load the data from shared memory to the fetch fragment.
|
||||
CUTLASS_DEVICE void copy(FetchedFragment &fetched) { shared_iterator_load(iterator, fetched); }
|
||||
|
||||
/// Load the data from shared memory to the fetch fragment.
|
||||
CUTLASS_DEVICE void copy(int d, FetchedFragment &fetched) {
|
||||
shared_iterator_load(iterator, fetched, d);
|
||||
}
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit(FetchedFragment &fetched, TransformedFragment &transformed) {
|
||||
transformer.transform(fetched, transformed);
|
||||
}
|
||||
|
||||
/// Increment the stage.
|
||||
CUTLASS_DEVICE void inc_stage() { iterator.inc_stage(); }
|
||||
|
||||
/// The iterator.
|
||||
Iterator iterator;
|
||||
/// The transformer.
|
||||
Transformer transformer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
417
cutlass/gemm/gemm_shared_tile.h
Normal file
417
cutlass/gemm/gemm_shared_tile.h
Normal file
@ -0,0 +1,417 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines iterators for efficiently loading and storing tiles to and from shared memory.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/gemm/gemm_operand.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerSts_>
|
||||
struct GemmSharedStoreTileAbTraits {
|
||||
/// The scalar.
|
||||
typedef typename platform::remove_const<Scalar_>::type Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar_* Pointer;
|
||||
/// The tile.
|
||||
typedef typename ReshapeTile<Tile_, kScalarsPerSts_>::Tile Tile;
|
||||
/// The threads.
|
||||
typedef Threads_ Threads;
|
||||
/// The strides to compute the base position of the thread.
|
||||
typedef Shape<0, ShapeCount<Tile>::kWc, Tile::kC, kScalarsPerSts_> ThreadsStrides;
|
||||
/// The skew.
|
||||
static int const kSkew = 0;
|
||||
/// The number of scalars per LDG/STG.
|
||||
static int const kAccessSize = kScalarsPerSts_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<1,
|
||||
Tile::kH / Threads::kH,
|
||||
Tile::kW / Threads::kW,
|
||||
Tile::kC / Threads::kC / kAccessSize>
|
||||
Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kAccessSize> Delta;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kAccessSize>
|
||||
ImmediateOffsetStrides;
|
||||
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int offset = ComputeThreadOffsetFromStrides<Threads, ThreadsStrides>::get();
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerSts_, int kSkew_>
|
||||
struct GemmSharedStoreWithSkewTileAbTraits {
|
||||
/// The scalar.
|
||||
typedef typename platform::remove_const<Scalar_>::type Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar_* Pointer;
|
||||
/// The tile without skews.
|
||||
typedef typename ReshapeTile<Tile_, kScalarsPerSts_>::Tile TileWithoutSkew;
|
||||
/// The tile.
|
||||
typedef typename ReshapeTile<Shape<Tile_::kD, Tile_::kH, Tile_::kW + kSkew_>,
|
||||
kScalarsPerSts_>::Tile Tile;
|
||||
/// The threads.
|
||||
typedef Threads_ Threads;
|
||||
/// The skew.
|
||||
static int const kSkew = kSkew_;
|
||||
/// The number of scalars per STS.
|
||||
static int const kAccessSize = kScalarsPerSts_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<1, TileWithoutSkew::kH / Threads::kW, TileWithoutSkew::kW / Threads::kH> Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, ShapeCount<Tile>::kWc, Threads::kH * kAccessSize> Delta;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, ShapeCount<Tile>::kWc, Threads::kH * kAccessSize> ImmediateOffsetStrides;
|
||||
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
|
||||
int offset = ComputeThreadOffsetFromStrides<Threads, ThreadsStrides>::get();
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
|
||||
protected:
|
||||
/// The strides to compute the base position of the thread.
|
||||
typedef Shape<0, kScalarsPerSts_, ShapeCount<Tile>::kHwc / Threads::kW> ThreadsStrides;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_,
|
||||
typename OutputTile_,
|
||||
typename Warps_,
|
||||
typename ThreadsPerWarp_,
|
||||
typename InstructionShape_,
|
||||
int kStages_,
|
||||
int kScalarsPerLds_,
|
||||
int kSkew_ = 0>
|
||||
struct GemmSharedLoadTileATraits {
|
||||
static GemmOperand::Kind const kOperand = GemmOperand::kA;
|
||||
/// The scalar.
|
||||
typedef typename platform::remove_const<Scalar_>::type Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar_* Pointer;
|
||||
/// The tile without skew.
|
||||
typedef Shape<kStages_,
|
||||
OutputTile_::kD / InstructionShape_::kD,
|
||||
GetExtent<kOperand, OutputTile_>::kExtent * InstructionShape_::kD>
|
||||
TileWithoutSkew_;
|
||||
/// The tile with skew.
|
||||
typedef Shape<kStages_, TileWithoutSkew_::kH, TileWithoutSkew_::kW + kSkew_> TileWithSkew;
|
||||
/// The tile without skew after reshaping.
|
||||
typedef typename ReshapeTile<TileWithoutSkew_, kScalarsPerLds_>::Tile TileWithoutSkew;
|
||||
/// The tile.
|
||||
typedef typename ReshapeTile<TileWithSkew, kScalarsPerLds_>::Tile Tile;
|
||||
/// The number of warps.
|
||||
typedef Warps_ Warps;
|
||||
/// The threads in a warp.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of scalars per LDG/STG.
|
||||
// static int const kScalarsPerLds = kScalarsPerLds_;
|
||||
static int const kAccessSize = kScalarsPerLds_;
|
||||
/// The skew.
|
||||
static int const kSkew = kSkew_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
|
||||
/// The number of warps.
|
||||
static int const kWarps = GetExtent<kOperand, Warps>::kExtent;
|
||||
/// The number of threads in one dimension of the warp.
|
||||
static int const kThreadsPerWarp = GetExtent<kOperand, ThreadsPerWarp>::kExtent;
|
||||
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<1, 1, TileWithoutSkew::kW / kWarps / kThreadsPerWarp /* / kScalarsPerLds*/>
|
||||
Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0>
|
||||
ImmediateOffsetStrides;
|
||||
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0> Delta;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
|
||||
// Extract the warp.
|
||||
int const warp = threadIdx.x / kWarpSize;
|
||||
// Extract the slice.
|
||||
int const slice = warp / (Warps::kH * Warps::kW);
|
||||
// Compute the row offset for each warp.
|
||||
int const warp_row = warp % Warps::kW;
|
||||
// Compute the row offset for each thread.
|
||||
int const lane_row = (threadIdx.x & 0x0e) / 2;
|
||||
// The offset.
|
||||
int const offset =
|
||||
slice * Tile::kW * Tile::kC + (warp_row * ThreadsPerWarp::kW + lane_row) * kAccessSize;
|
||||
// Embed the offset in a 4D coordinate vector.
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_,
|
||||
typename OutputTile_,
|
||||
typename Warps_,
|
||||
typename ThreadsPerWarp_,
|
||||
typename InstructionShape_,
|
||||
int kStages_,
|
||||
int kScalarsPerLds_,
|
||||
int kSkew_ = 0>
|
||||
struct GemmSharedLoadTileBTraits {
|
||||
static GemmOperand::Kind const kOperand = GemmOperand::kB;
|
||||
/// The scalar.
|
||||
typedef typename platform::remove_const<Scalar_>::type Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar_* Pointer;
|
||||
/// The tile without skew.
|
||||
typedef Shape<kStages_,
|
||||
OutputTile_::kD / InstructionShape_::kD,
|
||||
GetExtent<kOperand, OutputTile_>::kExtent * InstructionShape_::kD>
|
||||
TileWithoutSkew_;
|
||||
/// The tile with skew.
|
||||
typedef Shape<kStages_, TileWithoutSkew_::kH, TileWithoutSkew_::kW + kSkew_> TileWithSkew;
|
||||
/// The tile without skew after reshaping.
|
||||
typedef typename ReshapeTile<TileWithoutSkew_, kScalarsPerLds_>::Tile TileWithoutSkew;
|
||||
/// The tile.
|
||||
typedef typename ReshapeTile<TileWithSkew, kScalarsPerLds_>::Tile Tile;
|
||||
/// The number of warps.
|
||||
typedef Warps_ Warps;
|
||||
/// The threads in a warp.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of scalars per LDG/STG.
|
||||
static int const kAccessSize = kScalarsPerLds_;
|
||||
/// The skew.
|
||||
static int const kSkew = kSkew_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
|
||||
/// The number of warps.
|
||||
static int const kWarps = GetExtent<kOperand, Warps>::kExtent;
|
||||
/// The number of threads in one dimension of the warp.
|
||||
static int const kThreadsPerWarp = GetExtent<kOperand, ThreadsPerWarp>::kExtent;
|
||||
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<1, 1, TileWithoutSkew::kW / kWarps / kThreadsPerWarp /* / kAccessSize*/> Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0>
|
||||
ImmediateOffsetStrides;
|
||||
typedef Shape<TileWithSkew::kW * Warps::kD, 0, kWarps * kThreadsPerWarp * kAccessSize, 0> Delta;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
|
||||
// Extract the warp.
|
||||
int const warp = threadIdx.x / kWarpSize;
|
||||
// Extract the slice.
|
||||
int const slice = warp / (Warps::kH * Warps::kW);
|
||||
// The warp in the slice.
|
||||
int const warp_in_slice = warp % (Warps::kH * Warps::kW);
|
||||
// Compute the row offset for each warp.
|
||||
int const warp_col = warp_in_slice / Warps::kW;
|
||||
// Compute the row offset for each thread.
|
||||
int const lane_col = (threadIdx.x & 0x10) / 8 + (threadIdx.x & 0x01);
|
||||
// The offset.
|
||||
int const offset =
|
||||
slice * Tile::kW * Tile::kC + (warp_col * ThreadsPerWarp::kH + lane_col) * kAccessSize;
|
||||
// Embed the offset in a 4D coordinate.
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_,
|
||||
typename OutputTile_,
|
||||
typename Warps_,
|
||||
typename ThreadsPerWarp_,
|
||||
int kScalarsPerSts_,
|
||||
int kSkew_ = 0>
|
||||
struct GemmSharedStoreTileDTraits {
|
||||
/// The scalar.
|
||||
typedef typename platform::remove_const<Scalar_>::type Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar_* Pointer;
|
||||
/// The dimension of the output tile.
|
||||
typedef OutputTile_ OutputTile;
|
||||
/// The warps in the tile.
|
||||
typedef Warps_ Warps;
|
||||
/// The threads in the warps.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of scalars per LDG/STG.
|
||||
static int const kAccessSize = kScalarsPerSts_;
|
||||
/// The skew.
|
||||
static int const kSkew = kSkew_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
|
||||
/// The number of scalars per thread.
|
||||
static int const kScalarsPerThread = OutputTile_::kW / Warps::kW / ThreadsPerWarp::kW;
|
||||
/// The number of threads.
|
||||
static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;
|
||||
/// The number of scalars per row. We build a tile with 2 rows (to avoid bank conflicts).
|
||||
static int const kScalarsPerRow = kThreads / 2 * kScalarsPerThread + kSkew;
|
||||
|
||||
/// The tile.
|
||||
typedef Shape<1, 2, kScalarsPerRow / kAccessSize, kAccessSize> Tile;
|
||||
/// The number of iterations needed to store the tile.
|
||||
typedef Shape<1, 1, kScalarsPerThread / kAccessSize> Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Warps::kW * ThreadsPerWarp::kW * kAccessSize> Delta;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Warps::kW * ThreadsPerWarp::kW * kAccessSize> ImmediateOffsetStrides;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
|
||||
// The warp.
|
||||
int const warp = threadIdx.x / kWarpSize;
|
||||
|
||||
// The position of the warp in the 2D tile.
|
||||
int const warp_row = warp % Warps::kW;
|
||||
int const warp_col = warp / Warps::kW;
|
||||
|
||||
// We assume that the elements are distributed in a warps as 4 columns of 8 elements. The
|
||||
// columns are stored in threads col0=[0, 2, 4, 6, 8, 10, 12, 14], col1=[1, 3, 5, 7, .., 15],
|
||||
// col2=[16, 18, 20, ..., 30] and col3=[17, 19, ..., 31].
|
||||
int hi_halfwarp_offset = ((threadIdx.x >> 4) & 0x1) * OutputTile::kW;
|
||||
int lo_halfwarp_offset = ((threadIdx.x >> 1) & 0x7) + ThreadsPerWarp::kW * warp_row;
|
||||
|
||||
// Odd threads go to the second half of shared memory.
|
||||
int const row = threadIdx.x & 0x01;
|
||||
int col = warp_col * (ThreadsPerWarp::kH / 2) * OutputTile::kW +
|
||||
lo_halfwarp_offset * kAccessSize + hi_halfwarp_offset;
|
||||
// Embed the offset in a 4D coords.
|
||||
return make_Coord(0, 0, row * kScalarsPerRow + col, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_,
|
||||
typename OutputTile_,
|
||||
typename Warps_,
|
||||
typename ThreadsPerWarp_,
|
||||
int kTileH_,
|
||||
int kScalarsPerLds_,
|
||||
int kSkew_ = 0>
|
||||
struct GemmSharedLoadTileDTraits {
|
||||
/// The scalar.
|
||||
typedef typename platform::remove_const<Scalar_>::type Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar_* Pointer;
|
||||
/// The dimension of the output tile.
|
||||
typedef OutputTile_ OutputTile;
|
||||
/// The warps in the tile.
|
||||
typedef Warps_ Warps;
|
||||
/// The threads in the warps.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of scalars per LDG/STG.
|
||||
static int const kAccessSize = kScalarsPerLds_;
|
||||
/// The skew.
|
||||
static int const kSkew = kSkew_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
|
||||
/// The number of scalars per thread.
|
||||
static int const kScalarsPerThread = OutputTile_::kW / Warps::kW / ThreadsPerWarp::kW;
|
||||
/// The number of threads.
|
||||
static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;
|
||||
/// The number of scalars per row. We build a tile with 2 rows (to avoid bank conflicts).
|
||||
static int const kScalarsPerRow = kThreads / 2 * kScalarsPerThread + kSkew;
|
||||
|
||||
/// The tile. We have 2 rows of scalars. We use those two rows to make sure we do not have bank
|
||||
/// conflicts in the epilogue.
|
||||
typedef Shape<1, 2, kScalarsPerRow / kAccessSize, kAccessSize> Tile;
|
||||
|
||||
// Compute the number of iterations per warp in the Tile::kH dimension.
|
||||
static int const kIterationsInHPerWarp = kTileH_ / ShapeCount<Warps>::kCount;
|
||||
|
||||
// As explained above, the shared memory tile is composed of 2 rows and each rows is made of
|
||||
// kScalarsPerRow. A warp is expected to read from the 1st row, then move to the 2nd row and go
|
||||
// back to the 1st row. To model that scheme we define the Iterations shape as Shape<X, 2, ...>.
|
||||
// However, in some cases, we have only 1 iteration per warp. In that case, we must define the
|
||||
// shape as Shape<1, 1, ...>. The following code does that except that we hijack the kH dimension
|
||||
// to keep the number of elements to reduce for split-K.
|
||||
static int const kIterationsH = kIterationsInHPerWarp == 1 ? 1 : 2;
|
||||
// As soon as we know kIterationsH, it is trivial to compute kIterationsD:
|
||||
static int const kIterationsD = kIterationsInHPerWarp / kIterationsH;
|
||||
|
||||
// If we have split-K enabled, we have to jump over the elements from the "odd/even" column of
|
||||
// threads to grab the other elements.
|
||||
static int const kSplitK = OutputTile::kW * ThreadsPerWarp::kH / 2 * Warps::kH;
|
||||
|
||||
/// The number of iterations needed to store the tile.
|
||||
typedef Shape<kIterationsD, kIterationsH, OutputTile::kW / kWarpSize / kAccessSize, Warps::kD>
|
||||
Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<OutputTile::kW, kScalarsPerRow, kWarpSize * kAccessSize, kSplitK>
|
||||
ImmediateOffsetStrides;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<OutputTile::kW, kScalarsPerRow, kWarpSize * kAccessSize, kSplitK> Delta;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE Coord<4> operator()() const {
|
||||
// Each warp works on a different column.
|
||||
int const h = threadIdx.x / kWarpSize;
|
||||
// Compute the row.
|
||||
int const w = (threadIdx.x & (kWarpSize - 1)) * kAccessSize;
|
||||
int offset = 0;
|
||||
if (Iterations::kH == 1) {
|
||||
int const row = h & 0x1;
|
||||
int const col = h / 2;
|
||||
offset = row * ShapeCount<Tile>::kWc + col * OutputTile::kW * Iterations::kD + w;
|
||||
} else {
|
||||
offset = h * OutputTile::kW * Iterations::kD + w;
|
||||
}
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
818
cutlass/gemm/gemm_traits.h
Normal file
818
cutlass/gemm/gemm_traits.h
Normal file
@ -0,0 +1,818 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural properties of complete GEMM computation.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/gemm/clear_accumulators.h>
|
||||
#include <cutlass/gemm/gemm_global_stream.h>
|
||||
#include <cutlass/gemm/gemm_operand.h>
|
||||
#include <cutlass/gemm/gemm_shared_stream.h>
|
||||
#include <cutlass/gemm/identity_block_swizzle.h>
|
||||
#include <cutlass/matrix_traits.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include <cutlass/tile_iterator.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The scalar type for A.
|
||||
typename ScalarA_,
|
||||
/// The scalar type for B.
|
||||
typename ScalarB_,
|
||||
/// The scalar type for C.
|
||||
typename ScalarC_,
|
||||
/// The scalar type for D.
|
||||
typename ScalarD_,
|
||||
/// The output tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// The functor to do the math.
|
||||
typename MultiplyAdd_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
int kScalarsPerStsA_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdsA_,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
int kScalarsPerStsB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
int kScalarsPerLdsB_,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
int kScalarsPerLdgCAndStgD_,
|
||||
/// The number of scalars per STS for D.
|
||||
int kScalarsPerStsD_,
|
||||
/// The number of scalars per LDS for D.
|
||||
int kScalarsPerLdsD_,
|
||||
/// The number of stages in shared memory to do single/double/triple-buffering.
|
||||
int kStages_,
|
||||
/// Do we do the residue in the prologue?
|
||||
bool kResidueInPrologue_ = false>
|
||||
|
||||
struct GemmConfig {
|
||||
//
|
||||
/// The scalar for A.
|
||||
typedef ScalarA_ ScalarA;
|
||||
/// The scalar for B.
|
||||
typedef ScalarB_ ScalarB;
|
||||
/// The scalar for C.
|
||||
typedef ScalarC_ ScalarC;
|
||||
/// The scalar for D.
|
||||
typedef ScalarD_ ScalarD;
|
||||
|
||||
/// The tile.
|
||||
typedef OutputTile_ OutputTile;
|
||||
/// The functor to do D = A*B + C.
|
||||
typedef MultiplyAdd_ MultiplyAdd;
|
||||
/// The shape of the instruction.
|
||||
typedef typename MultiplyAdd::InstructionShape InstructionShape;
|
||||
/// The number of accumulators per warp.
|
||||
typedef typename MultiplyAdd::AccumulatorsPerWarp AccumulatorsPerWarp;
|
||||
/// The accumulators.
|
||||
typedef typename MultiplyAdd::Accumulators Accumulators;
|
||||
|
||||
/// The number of warps.
|
||||
typedef typename ShapeDiv<OutputTile, AccumulatorsPerWarp>::Shape Warps;
|
||||
/// The default warp size (32 threads per warp).
|
||||
static int const kWarpSize = cutlass::kWarpSize;
|
||||
/// The numnber of threads.
|
||||
static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;
|
||||
|
||||
/// The number of scalars per LDG/STS/LDS for A.
|
||||
static int const kScalarsPerLdgA = kScalarsPerLdgA_;
|
||||
static int const kScalarsPerStsA = kScalarsPerStsA_;
|
||||
static int const kScalarsPerLdsA = kScalarsPerLdsA_;
|
||||
|
||||
/// The number of scalars per LDG/STS/LDS for B.
|
||||
static int const kScalarsPerLdgB = kScalarsPerLdgB_;
|
||||
static int const kScalarsPerStsB = kScalarsPerStsB_;
|
||||
static int const kScalarsPerLdsB = kScalarsPerLdsB_;
|
||||
|
||||
/// The number of scalars per LDG for C.
|
||||
static int const kScalarsPerLdgC = kScalarsPerLdgCAndStgD_;
|
||||
|
||||
/// The number of scalars per STS/LDS/STG for D.
|
||||
static int const kScalarsPerStgD = kScalarsPerLdgCAndStgD_;
|
||||
static int const kScalarsPerStsD = kScalarsPerStsD_;
|
||||
static int const kScalarsPerLdsD = kScalarsPerLdsD_;
|
||||
|
||||
/// The number of accumulators that are going to be fed from one LDS A/B.
|
||||
static int const kAccumulatorsPerLdsA = kScalarsPerLdsA / InstructionShape::kD;
|
||||
static int const kAccumulatorsPerLdsB = kScalarsPerLdsB / InstructionShape::kD;
|
||||
|
||||
/// The number of stages in shared memory to implement double, triple, more-buffering.
|
||||
static int const kStages = kStages_;
|
||||
|
||||
/// Do we do the residue in the prologue?
|
||||
static bool const kResidueInPrologue = kResidueInPrologue_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind, typename GemmConfig_>
|
||||
struct GemmTileTraitsHelperA {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarA Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for A^N.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's A.
|
||||
GemmOperand::kA,
|
||||
// A is column-major.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer is float const.
|
||||
Scalar const,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kW>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer is float.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
|
||||
GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsA>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for A^N.
|
||||
typedef GemmSharedLoadTileATraits<
|
||||
// The pointer is float const.
|
||||
MultiplyAddScalar const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
GemmConfig_::kScalarsPerLdsA,
|
||||
// The skew.
|
||||
0>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarA Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for A^T.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's A.
|
||||
GemmOperand::kA,
|
||||
// A is row-major.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer is float const.
|
||||
Scalar const,
|
||||
// The tile has size MxK in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The number of scalars in 4B.
|
||||
static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
|
||||
/// The skew for A.
|
||||
static int const kSkewA = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsA /
|
||||
GlobalTileTraits::Threads::kW * kScalarsIn4B;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^T.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits <
|
||||
// The pointer is float.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
|
||||
GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS.
|
||||
GemmConfig_::kScalarsPerStsA,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
kSkewA<GemmConfig_::kScalarsPerLdsA ? GemmConfig_::kScalarsPerLdsA : kSkewA>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for A^T.
|
||||
typedef GemmSharedLoadTileATraits<
|
||||
// The pointer is float const.
|
||||
MultiplyAddScalar const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
GemmConfig_::kScalarsPerLdsA,
|
||||
// The skew.
|
||||
SharedStoreTileTraits::kSkew>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind, typename GemmConfig_>
|
||||
struct GemmTileTraitsHelperB {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarB Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for B^N.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's B.
|
||||
GemmOperand::kB,
|
||||
// B is column-major.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer is float const.
|
||||
Scalar const,
|
||||
// The tile has size MxK in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgB>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The number of scalars in 4B.
|
||||
static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
|
||||
/// The skew for B.
|
||||
static int const kSkewB = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsB /
|
||||
GlobalTileTraits::Threads::kW * kScalarsIn4B;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits <
|
||||
// The pointer is float.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxN in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
|
||||
GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS.
|
||||
GemmConfig_::kScalarsPerStsB,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
kSkewB<GemmConfig_::kScalarsPerLdsB ? GemmConfig_::kScalarsPerLdsB : kSkewB>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for B^N.
|
||||
typedef GemmSharedLoadTileBTraits<
|
||||
// The pointer is float const.
|
||||
MultiplyAddScalar const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
GemmConfig_::kScalarsPerLdsB,
|
||||
// The skew.
|
||||
SharedStoreTileTraits::kSkew>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarB Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for B^T.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's B.
|
||||
GemmOperand::kB,
|
||||
// B is row-major.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer is float const.
|
||||
Scalar const,
|
||||
// The tile has size KxN in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kH>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgB>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^T.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer is float.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxN in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
|
||||
GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsB>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for B^T.
|
||||
typedef GemmSharedLoadTileBTraits<
|
||||
// The pointer is float const.
|
||||
MultiplyAddScalar const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
GemmConfig_::kScalarsPerLdsB,
|
||||
// The skew.
|
||||
0>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmTraits_, bool kResidueInPrologue_ = GemmTraits_::kResidueInPrologue>
|
||||
struct GemmResidue {
|
||||
/// Move to residue portion.
|
||||
template <bool kIsPrologue>
|
||||
static CUTLASS_DEVICE void move_to_residue(typename GemmTraits_::GlobalLoadStreamA& stream_a,
|
||||
typename GemmTraits_::GlobalLoadStreamB& stream_b,
|
||||
typename GemmTraits_::Index k) {
|
||||
// The new code path in CUTLASS 1.0.1: We treat the residue in the prologue so we can have
|
||||
// complete main loops after that. It helps simplify the logic in the main loop.
|
||||
if (kIsPrologue) {
|
||||
stream_a.move_to_residue(k);
|
||||
stream_b.move_to_residue(k);
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile and initialize predicates.
|
||||
static CUTLASS_DEVICE void rollback(typename GemmTraits_::GlobalLoadStreamA& stream_a,
|
||||
typename GemmTraits_::GlobalLoadStreamB& stream_b) {
|
||||
stream_a.rollback();
|
||||
stream_b.rollback();
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmTraits_>
|
||||
struct GemmResidue<GemmTraits_, false> {
|
||||
/// Move to residue portion.
|
||||
template <bool kIsPrologue>
|
||||
static CUTLASS_DEVICE void move_to_residue(typename GemmTraits_::GlobalLoadStreamA& stream_a,
|
||||
typename GemmTraits_::GlobalLoadStreamB& stream_b,
|
||||
typename GemmTraits_::Index k) {
|
||||
// The index.
|
||||
typedef typename GemmTraits_::Index Index;
|
||||
// By how much we unroll the main loop.
|
||||
Index const kUnroll = static_cast<Index>(GemmTraits_::OutputTile::kD);
|
||||
|
||||
// Call the residue code. That's the same path as CUTLASS 1.0.0.
|
||||
if (kIsPrologue && k < kUnroll) {
|
||||
stream_a.residue(k, true);
|
||||
stream_b.residue(k, true);
|
||||
} else if (k <= kUnroll) {
|
||||
stream_a.residue(k, false);
|
||||
stream_b.residue(k, false);
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile and initialize predicates.
|
||||
static CUTLASS_DEVICE void rollback(typename GemmTraits_::GlobalLoadStreamA& stream_a,
|
||||
typename GemmTraits_::GlobalLoadStreamB& stream_b) {}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The GEMM configuration.
|
||||
typename GemmConfig_,
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typename GlobalLoadStreamA_,
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typename GlobalLoadStreamB_,
|
||||
/// The stream to load A from shared memory.
|
||||
typename SharedLoadStreamA_,
|
||||
/// The stream to load B from shared memory.
|
||||
typename SharedLoadStreamB_,
|
||||
/// The epilogue.
|
||||
typename Epilogue_,
|
||||
/// The block swizzle to reorganize the grid.
|
||||
typename BlockSwizzle_ = IdentityBlockSwizzle,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The tool used to clear accumulators.
|
||||
typename ClearAccumulators_ = ClearAccumulators<typename GemmConfig_::Accumulators::Scalar> >
|
||||
|
||||
struct GemmTraits {
|
||||
/// This class.
|
||||
typedef GemmTraits<GemmConfig_,
|
||||
GlobalLoadStreamA_,
|
||||
GlobalLoadStreamB_,
|
||||
SharedLoadStreamA_,
|
||||
SharedLoadStreamB_,
|
||||
Epilogue_,
|
||||
BlockSwizzle_,
|
||||
Index_,
|
||||
ClearAccumulators_>
|
||||
This_;
|
||||
|
||||
/// The configuration.
|
||||
typedef GemmConfig_ GemmConfig;
|
||||
/// The output tile.
|
||||
typedef typename GemmConfig::OutputTile OutputTile;
|
||||
/// Is the residue treated in the prologue?
|
||||
static bool const kResidueInPrologue = GemmConfig::kResidueInPrologue;
|
||||
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typedef GlobalLoadStreamA_ GlobalLoadStreamA;
|
||||
/// The layout of A.
|
||||
static MatrixLayout::Kind const kLayoutA = GlobalLoadStreamA::kLayout;
|
||||
/// The scalar for A.
|
||||
typedef typename GlobalLoadStreamA_::Scalar ScalarA;
|
||||
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typedef GlobalLoadStreamB_ GlobalLoadStreamB;
|
||||
/// The layout of B.
|
||||
static MatrixLayout::Kind const kLayoutB = GlobalLoadStreamB::kLayout;
|
||||
/// The scalar for B.
|
||||
typedef typename GlobalLoadStreamB_::Scalar ScalarB;
|
||||
|
||||
/// The iterator for A to load from shared memory.
|
||||
typedef SharedLoadStreamA_ SharedLoadStreamA;
|
||||
/// The iterator for B to load from shared memory.
|
||||
typedef SharedLoadStreamB_ SharedLoadStreamB;
|
||||
|
||||
/// The multiply-add functor.
|
||||
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
|
||||
/// The epilogue.
|
||||
typedef Epilogue_ Epilogue;
|
||||
/// The scalars in the epilogue.
|
||||
typedef typename Epilogue::ScalarC ScalarC;
|
||||
typedef typename Epilogue::ScalarD ScalarD;
|
||||
|
||||
/// The block swizzle to reorganize the grid.
|
||||
typedef BlockSwizzle_ BlockSwizzle;
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
/// Clear the accumulators.
|
||||
typedef ClearAccumulators_ ClearAccumulators;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
/// The dimensions of the GEMM.
|
||||
Index m, n, k;
|
||||
/// The params for the A stream.
|
||||
typename GlobalLoadStreamA::Params global_stream_a;
|
||||
/// The params for the B stream.
|
||||
typename GlobalLoadStreamB::Params global_stream_b;
|
||||
/// The params for the A stream from shared memory.
|
||||
typename SharedLoadStreamA::Params shared_stream_a;
|
||||
/// The params for the B stream from shared memory.
|
||||
typename SharedLoadStreamB::Params shared_stream_b;
|
||||
/// The params for the epilogue.
|
||||
typename Epilogue::Params epilogue;
|
||||
|
||||
/// Initialize the parameters.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
|
||||
// Set the problem size.
|
||||
this->m = desc.m;
|
||||
this->n = desc.n;
|
||||
this->k = desc.k;
|
||||
|
||||
// Initialize the iterator for A.
|
||||
int error_code =
|
||||
global_stream_a.initialize(desc, reinterpret_cast<ScalarA const*>(desc.d_a), desc.lda);
|
||||
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// Initialize the iterator for B.
|
||||
error_code =
|
||||
global_stream_b.initialize(desc, reinterpret_cast<ScalarB const*>(desc.d_b), desc.ldb);
|
||||
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// The epilogue.
|
||||
return epilogue.initialize(desc);
|
||||
}
|
||||
};
|
||||
|
||||
// The storage for A.
|
||||
template <typename GlobalLoadStream_, typename SharedLoadStream_>
|
||||
union StreamSharedStorage {
|
||||
// The storage needed by the global stream.
|
||||
typename GlobalLoadStream_::SharedStorage global;
|
||||
// The storage needed by the shared stream.
|
||||
typename SharedLoadStream_::SharedStorage shared;
|
||||
};
|
||||
|
||||
// The storage for the main loop + prologue.
|
||||
struct MainLoopSharedStorage {
|
||||
// The storage to shuffle the A matrix in shared memory.
|
||||
StreamSharedStorage<GlobalLoadStreamA, SharedLoadStreamA> stream_a;
|
||||
// The storage to shuffle the B matrix in shared memory.
|
||||
StreamSharedStorage<GlobalLoadStreamB, SharedLoadStreamB> stream_b;
|
||||
// The storage to clear the accumulators if needed.
|
||||
typename ClearAccumulators::SharedStorage clear;
|
||||
};
|
||||
|
||||
/// The storage in shared memory.
|
||||
union SharedStorage {
|
||||
// The storage for the main loop.
|
||||
MainLoopSharedStorage main_loop;
|
||||
// The storage for the epilogue.
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
/// Assemble the global load streams for A/B.
|
||||
struct GlobalLoadStream {
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GlobalLoadStream(Params const& params,
|
||||
SharedStorage& shared_storage,
|
||||
dim3 const& block)
|
||||
: stream_a(params.global_stream_a,
|
||||
shared_storage.main_loop.stream_a.global,
|
||||
cutlass::make_Coord(0, params.k, params.m),
|
||||
cutlass::make_Coord(0, 0, block.x)),
|
||||
stream_b(params.global_stream_b,
|
||||
shared_storage.main_loop.stream_b.global,
|
||||
cutlass::make_Coord(0, params.k, params.n),
|
||||
make_Coord(0, 0, block.y)) {}
|
||||
|
||||
/// Trigger the copies from shared memory to registers.
|
||||
CUTLASS_DEVICE void copy() {
|
||||
stream_a.copy();
|
||||
stream_b.copy();
|
||||
}
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit() {
|
||||
stream_a.commit();
|
||||
stream_b.commit();
|
||||
}
|
||||
|
||||
/// Move to residue portion.
|
||||
template <bool kIsPrologue>
|
||||
CUTLASS_DEVICE void move_to_residue(Index k) {
|
||||
GemmResidue<This_>::move_to_residue<kIsPrologue>(stream_a, stream_b, k);
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile and initialize predicates.
|
||||
CUTLASS_DEVICE void rollback() { GemmResidue<This_>::rollback(stream_a, stream_b); }
|
||||
|
||||
/// The stream for A.
|
||||
GlobalLoadStreamA stream_a;
|
||||
/// The stream for B.
|
||||
GlobalLoadStreamB stream_b;
|
||||
};
|
||||
|
||||
/// Assemble the shared load stream for A/B.
|
||||
struct SharedLoadStream {
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE SharedLoadStream(Params const& params, SharedStorage& shared_storage) {
|
||||
stream_a.initialize(params.shared_stream_a, shared_storage.main_loop.stream_a.shared);
|
||||
stream_b.initialize(params.shared_stream_b, shared_storage.main_loop.stream_b.shared);
|
||||
}
|
||||
|
||||
/// Trigger the copies from shared memory to registers.
|
||||
CUTLASS_DEVICE void copy(int step) {
|
||||
stream_a.copy(step, fetched_a[step % 2]);
|
||||
stream_b.copy(step, fetched_b[step % 2]);
|
||||
}
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit(int step) {
|
||||
stream_a.commit(fetched_a[step % 2], transformed_a[step % 2]);
|
||||
stream_b.commit(fetched_b[step % 2], transformed_b[step % 2]);
|
||||
}
|
||||
|
||||
/// The fragment A.
|
||||
CUTLASS_DEVICE typename SharedLoadStreamA::Fragment const& fragment_a(int step) const {
|
||||
return transformed_a[step % 2];
|
||||
}
|
||||
|
||||
/// The fragment B.
|
||||
CUTLASS_DEVICE typename SharedLoadStreamB::Fragment const& fragment_b(int step) const {
|
||||
return transformed_b[step % 2];
|
||||
}
|
||||
|
||||
/// Increment the stage.
|
||||
CUTLASS_DEVICE void inc_stage() {
|
||||
stream_a.inc_stage();
|
||||
stream_b.inc_stage();
|
||||
}
|
||||
|
||||
/// The stream for A.
|
||||
SharedLoadStreamA stream_a;
|
||||
/// The fragments to fetch A.
|
||||
typename SharedLoadStreamA::FetchedFragment fetched_a[2];
|
||||
/// The fragments to transform A.
|
||||
typename SharedLoadStreamA::TransformedFragment transformed_a[2];
|
||||
/// The stream for B.
|
||||
SharedLoadStreamB stream_b;
|
||||
/// The fragments to fetch B.
|
||||
typename SharedLoadStreamB::FetchedFragment fetched_b[2];
|
||||
/// The fragments to transform B.
|
||||
typename SharedLoadStreamB::TransformedFragment transformed_b[2];
|
||||
};
|
||||
|
||||
/// The memory fence for shared loads.
|
||||
static CUTLASS_DEVICE void shared_load_fence(bool in_loop) {
|
||||
if (SharedLoadStreamA::Iterator::kRequiresLoadFence ||
|
||||
SharedLoadStreamB::Iterator::kRequiresLoadFence) {
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
/// The memory fence for shared stores.
|
||||
static CUTLASS_DEVICE void shared_store_fence(bool in_loop) { __syncthreads(); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmTileTraitsHelperA_, typename GemmTileTraitsHelperB_, typename Index_>
|
||||
struct SimplifiedGemmTraitsHelper {
|
||||
/// The global iterator to load A from global memory.
|
||||
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA_::GlobalTileTraits, Index_>
|
||||
GlobalLoadIteratorA;
|
||||
/// The data converter for A before storing to shared memory.
|
||||
typedef Copy<typename GlobalLoadIteratorA::Fragment> GlobalTransformerA;
|
||||
/// The iterator to store A to shared memory.
|
||||
typedef TileStoreIterator<typename GemmTileTraitsHelperA_::SharedStoreTileTraits,
|
||||
typename GemmTileTraitsHelperA_::SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorA;
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA>
|
||||
GlobalLoadStreamA;
|
||||
|
||||
/// The global iterator to load B from global memory.
|
||||
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB_::GlobalTileTraits, Index_>
|
||||
GlobalLoadIteratorB;
|
||||
/// The data converter for B before storing to shared memory.
|
||||
typedef Copy<typename GlobalLoadIteratorB::Fragment> GlobalTransformerB;
|
||||
/// The iterator to store B to shared memory.
|
||||
typedef TileStoreIterator<typename GemmTileTraitsHelperB_::SharedStoreTileTraits,
|
||||
typename GemmTileTraitsHelperB_::SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorB;
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB>
|
||||
GlobalLoadStreamB;
|
||||
|
||||
/// The iterator to load A from shared memory.
|
||||
typedef TileLoadIterator<typename GemmTileTraitsHelperA_::SharedLoadTileTraits,
|
||||
typename GemmTileTraitsHelperA_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorA;
|
||||
/// The stream to load A from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorA> SharedLoadStreamA;
|
||||
/// The iterator to load B from shared memory.
|
||||
typedef TileLoadIterator<typename GemmTileTraitsHelperB_::SharedLoadTileTraits,
|
||||
typename GemmTileTraitsHelperB_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorB;
|
||||
/// The stream to load B from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorB> SharedLoadStreamB;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The config for the GEMM.
|
||||
typename GemmConfig_,
|
||||
/// The epilogue.
|
||||
typename Epilogue_,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
// The configuration for the A matrix.
|
||||
typename GemmTileTraitsHelperA_ = GemmTileTraitsHelperA<kLayoutA_, GemmConfig_>,
|
||||
// The configuration for the B matrix.
|
||||
typename GemmTileTraitsHelperB_ = GemmTileTraitsHelperB<kLayoutB_, GemmConfig_>,
|
||||
// The helper class to create the streams and iterators.
|
||||
typename Helper_ =
|
||||
SimplifiedGemmTraitsHelper<GemmTileTraitsHelperA_, GemmTileTraitsHelperB_, Index_> >
|
||||
struct SimplifiedGemmTraits : public GemmTraits<
|
||||
// The config.
|
||||
GemmConfig_,
|
||||
// The stream to load A from global memory to shared memory.
|
||||
typename Helper_::GlobalLoadStreamA,
|
||||
// The stream to load B from global memory to shared memory.
|
||||
typename Helper_::GlobalLoadStreamB,
|
||||
// The stream to load A from shared memory.
|
||||
typename Helper_::SharedLoadStreamA,
|
||||
// The stream to load B from shared memory.
|
||||
typename Helper_::SharedLoadStreamB,
|
||||
// The epilogue.
|
||||
Epilogue_,
|
||||
// The block swizzle to reorganize the grid.
|
||||
IdentityBlockSwizzle,
|
||||
// The index.
|
||||
Index_,
|
||||
// The tool used to clear accumulators.
|
||||
ClearAccumulators<typename GemmConfig_::Accumulators::Element> > {
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,436 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* Abstraction for enumerating \p block_task within an input matrix
|
||||
*/
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "../util/util.h"
|
||||
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* grid_raster_strategy
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief Strategies for enumerating \p block_task within an input matrix
|
||||
*/
|
||||
struct grid_raster_strategy
|
||||
{
|
||||
/// \brief Enumerants
|
||||
enum kind_t
|
||||
{
|
||||
/**
|
||||
* Default \p block_task assignment (currently ColumnMajor for N*,
|
||||
* RowMajor for TT, and TiledCohort for TN)
|
||||
*/
|
||||
Default,
|
||||
|
||||
/**
|
||||
* Column-major \p block_task assignment
|
||||
*/
|
||||
ColumnMajor,
|
||||
|
||||
/**
|
||||
* Row-major \p block_task assignment
|
||||
*/
|
||||
RowMajor,
|
||||
|
||||
/**
|
||||
* Two-level \p block_task assignment (both column-major)
|
||||
*/
|
||||
TiledCohort,
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* grid_raster
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief Abstraction for enumerating \p block_task within an input matrix
|
||||
*
|
||||
* NB: This generic class is not directly constructible. Algorithm-specific
|
||||
* template specializations will provide the API functionality prescribed here.
|
||||
*/
|
||||
template <
|
||||
int BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
|
||||
int BlockItemsX, ///< Width in columns of a block-wide tile in matrix C
|
||||
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
|
||||
matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B
|
||||
grid_raster_strategy::kind_t RasterStrategy> ///< Strategy for enumerating \p block_task within an input matrix
|
||||
struct grid_raster
|
||||
{
|
||||
//-------------------------------------------------------------------------
|
||||
// Device API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Thread block's base item coordinates (x, y) in matrix C
|
||||
int2 block_item_coords;
|
||||
|
||||
/// Constructor
|
||||
grid_raster();
|
||||
|
||||
/// Whether the thread block base coordinates are out-of-bounds for an m*n matrix C
|
||||
bool is_block_oob(int m, int n);
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Grid launch API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Compute the kernel grid extents (in thread blocks) for consuming an m*n matrix C
|
||||
static dim3 grid_dims(int m, int n);
|
||||
};
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* grid_raster (ColumnMajor specialization)
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief Abstraction for enumerating \p block_task within an input matrix
|
||||
* (ColumnMajor specialization)
|
||||
*
|
||||
* Maps thread blocksin column-major fashion
|
||||
*/
|
||||
template <
|
||||
int BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
|
||||
int BlockItemsX, ///< Width in columns of a block-wide tile in matrix C
|
||||
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< View transform enumerant for matrix B
|
||||
struct grid_raster<
|
||||
BlockItemsY,
|
||||
BlockItemsX,
|
||||
TransformA,
|
||||
TransformB,
|
||||
grid_raster_strategy::ColumnMajor> ///< Strategy for enumerating \p block_task within an input matrix
|
||||
{
|
||||
//-------------------------------------------------------------------------
|
||||
// Device API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Thread block's base item coordinates (x, y) in matrix C
|
||||
int2 block_item_coords;
|
||||
|
||||
/// Constructor
|
||||
inline __device__
|
||||
grid_raster()
|
||||
{
|
||||
// blockDim.x is the fastest changing grid dim on current architectures
|
||||
block_item_coords = make_int2(
|
||||
BlockItemsX * blockIdx.y,
|
||||
BlockItemsY * blockIdx.x);
|
||||
}
|
||||
|
||||
/// Whether the base \p block_item_coords are out-of-bounds for an m*n matrix C
|
||||
inline __device__
|
||||
bool is_block_oob(int m, int n)
|
||||
{
|
||||
// ColumnMajor never rasterizes fully out-of-bounds thread blocks
|
||||
return false;
|
||||
}
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Grid launch API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Compute the kernel grid extents (in thread blocks) for consuming an m*n matrix C
|
||||
inline __host__ __device__
|
||||
static dim3 grid_dims(int m, int n)
|
||||
{
|
||||
// blockDim.x is the fastest changing grid dim on current architectures
|
||||
return dim3(
|
||||
(m + BlockItemsY - 1) / BlockItemsY,
|
||||
(n + BlockItemsX - 1) / BlockItemsX);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* grid_raster (RowMajor specialization)
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief Abstraction for enumerating \p block_task within an input matrix
|
||||
* (RowMajor specialization)
|
||||
*
|
||||
* Enumerates \p block_task in row-major fashion
|
||||
*/
|
||||
template <
|
||||
int BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
|
||||
int BlockItemsX, ///< Width in columns of a block-wide tile in matrix C
|
||||
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< View transform enumerant for matrix B
|
||||
struct grid_raster<
|
||||
BlockItemsY,
|
||||
BlockItemsX,
|
||||
TransformA,
|
||||
TransformB,
|
||||
grid_raster_strategy::RowMajor> ///< Strategy for enumerating \p block_task within an input matrix
|
||||
{
|
||||
//-------------------------------------------------------------------------
|
||||
// Device API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Thread block's base item coordinates (x, y) in matrix C
|
||||
int2 block_item_coords;
|
||||
|
||||
/// Constructor
|
||||
inline __device__
|
||||
grid_raster()
|
||||
{
|
||||
// blockDim.x is the fastest changing grid dim on current architectures
|
||||
block_item_coords = make_int2(
|
||||
BlockItemsX * blockIdx.x,
|
||||
BlockItemsY * blockIdx.y);
|
||||
}
|
||||
|
||||
/// Whether the base \p block_item_coords are out-of-bounds for an m*n matrix C
|
||||
inline __device__
|
||||
bool is_block_oob(int m, int n)
|
||||
{
|
||||
// RowMajor never rasterizes fully out-of-bounds thread blocks
|
||||
return false;
|
||||
}
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Grid launch API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Compute the kernel grid extents (in thread blocks) for consuming an m*n matrix C
|
||||
inline __host__ __device__
|
||||
static dim3 grid_dims(int m, int n)
|
||||
{
|
||||
// blockDim.x is the fastest changing grid dim on current architectures
|
||||
return dim3(
|
||||
(n + BlockItemsX - 1) / BlockItemsX,
|
||||
(m + BlockItemsY - 1) / BlockItemsY);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* grid_raster (TiledCohort specialization)
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief Abstraction for enumerating \p block_task within an input matrix
|
||||
* (TiledCohort specialization)
|
||||
*
|
||||
* Enumerates \p block_task in column-major fashion across "cohort" tiles (where
|
||||
* cohorts are CohortBlocksY high and CohortBlocksX wide), and enumerates cohorts
|
||||
* across the matrix in column-major fashion.
|
||||
*
|
||||
* Grid layout:
|
||||
* - gridDim.y is the height of the grid in cohorts
|
||||
* - gridDim.x is the width of the grid in cohorts multiplied by the number of
|
||||
* thread blocks per cohort
|
||||
*/
|
||||
template <
|
||||
int BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
|
||||
int BlockItemsX, ///< Width in columns of a block-wide tile in matrix C
|
||||
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
|
||||
matrix_transform_t::kind_t TransformB> ///< View transform enumerant for matrix B
|
||||
struct grid_raster<
|
||||
BlockItemsY,
|
||||
BlockItemsX,
|
||||
TransformA,
|
||||
TransformB,
|
||||
grid_raster_strategy::TiledCohort> ///< Strategy for enumerating \p block_task within an input matrix
|
||||
{
|
||||
enum
|
||||
{
|
||||
/// Height in thread blocks of a grid rasterization cohort
|
||||
CohortBlocksY = 2,
|
||||
|
||||
/// Width in thread blocks of a grid rasterization cohort
|
||||
CohortBlocksX = 2,
|
||||
|
||||
/// Number of thread blocks per cohort
|
||||
BlocksPerCohort = CohortBlocksY * CohortBlocksX,
|
||||
|
||||
/// Height in items of a grid rasterization cohort
|
||||
CohortItemsY = CohortBlocksY * BlockItemsY,
|
||||
|
||||
/// Width in items of a grid rasterization cohort
|
||||
CohortItemsX = CohortBlocksX * BlockItemsX,
|
||||
|
||||
};
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Device API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Thread block's base item coordinates (x, y) in matrix C
|
||||
int2 block_item_coords;
|
||||
|
||||
/// Constructor
|
||||
inline __device__
|
||||
grid_raster()
|
||||
{
|
||||
int block_idx_cohort = blockIdx.x % BlocksPerCohort;
|
||||
int2 cohort_coords_grid = make_int2(
|
||||
blockIdx.x / BlocksPerCohort,
|
||||
blockIdx.y);
|
||||
|
||||
// Cohort is rastered in column-major order
|
||||
int2 block_coords_cohort = make_int2(
|
||||
block_idx_cohort / CohortBlocksY,
|
||||
block_idx_cohort % CohortBlocksY);
|
||||
|
||||
block_item_coords = make_int2(
|
||||
((cohort_coords_grid.x * CohortBlocksX) + block_coords_cohort.x) * BlockItemsX,
|
||||
((cohort_coords_grid.y * CohortBlocksY) + block_coords_cohort.y) * BlockItemsY);
|
||||
}
|
||||
|
||||
/// Whether the base \p block_item_coords are out-of-bounds for an m*n matrix C
|
||||
inline __device__
|
||||
bool is_block_oob(int m, int n)
|
||||
{
|
||||
/// thread blocks within the cohort may be fully out-of-bounds
|
||||
return (block_item_coords.x >= n) || (block_item_coords.y >= m);
|
||||
}
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Grid launch API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Compute the kernel grid extents (in thread blocks) for consuming an m*n matrix C
|
||||
inline __host__ __device__
|
||||
static dim3 grid_dims(int m, int n)
|
||||
{
|
||||
// Extents of C matrix in cohorts
|
||||
int2 grid_cohort_dims = make_int2(
|
||||
(n + CohortItemsX - 1) / CohortItemsX,
|
||||
(m + CohortItemsY - 1) / CohortItemsY);
|
||||
|
||||
return dim3(
|
||||
grid_cohort_dims.x * BlocksPerCohort, // gridDim.x is width of grid in cohorts * size of cohort in blocks
|
||||
grid_cohort_dims.y, // gridDim.y is height of grid in cohorts
|
||||
1); // gridDim.z is reserved for optional k-splitting
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* grid_raster (Default specializations)
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief Abstraction for enumerating \p block_task within an input matrix
|
||||
* (Default N* specialization)
|
||||
*
|
||||
* Maps thread blocksin column-major fashion
|
||||
*/
|
||||
template <
|
||||
int BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
|
||||
int BlockItemsX, ///< Width in columns of a block-wide tile in matrix C
|
||||
matrix_transform_t::kind_t TransformB> ///< View transform enumerant for matrix B
|
||||
struct grid_raster<
|
||||
BlockItemsY,
|
||||
BlockItemsX,
|
||||
matrix_transform_t::NonTranspose, ///< View transform enumerant for matrix A
|
||||
TransformB,
|
||||
grid_raster_strategy::Default> ///< Strategy for enumerating \p block_task within an input matrix
|
||||
:
|
||||
grid_raster<
|
||||
BlockItemsY,
|
||||
BlockItemsX,
|
||||
matrix_transform_t::NonTranspose,
|
||||
TransformB,
|
||||
grid_raster_strategy::ColumnMajor>
|
||||
{};
|
||||
|
||||
|
||||
/**
|
||||
* \brief Abstraction for enumerating \p block_task within an input matrix
|
||||
* (Default TT specialization)
|
||||
*
|
||||
* Maps thread blocksin row-major fashion
|
||||
*/
|
||||
template <
|
||||
int BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
|
||||
int BlockItemsX> ///< Width in columns of a block-wide tile in matrix C
|
||||
struct grid_raster<
|
||||
BlockItemsY,
|
||||
BlockItemsX,
|
||||
matrix_transform_t::Transpose, ///< View transform enumerant for matrix A
|
||||
matrix_transform_t::Transpose, ///< View transform enumerant for matrix B
|
||||
grid_raster_strategy::Default> ///< Strategy for enumerating \p block_task within an input matrix
|
||||
:
|
||||
grid_raster<
|
||||
BlockItemsY,
|
||||
BlockItemsX,
|
||||
matrix_transform_t::Transpose,
|
||||
matrix_transform_t::Transpose,
|
||||
grid_raster_strategy::RowMajor>
|
||||
{};
|
||||
|
||||
|
||||
/**
|
||||
* \brief Abstraction for enumerating \p block_task within an input matrix
|
||||
* (Default TN specialization)
|
||||
*
|
||||
* Maps thread blocksin blocked cohorts
|
||||
*/
|
||||
template <
|
||||
int BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
|
||||
int BlockItemsX> ///< Width in columns of a block-wide tile in matrix C
|
||||
struct grid_raster<
|
||||
BlockItemsY,
|
||||
BlockItemsX,
|
||||
matrix_transform_t::Transpose, ///< View transform enumerant for matrix A
|
||||
matrix_transform_t::NonTranspose, ///< View transform enumerant for matrix B
|
||||
grid_raster_strategy::Default> ///< Strategy for enumerating \p block_task within an input matrix
|
||||
:
|
||||
grid_raster<
|
||||
BlockItemsY,
|
||||
BlockItemsX,
|
||||
matrix_transform_t::Transpose,
|
||||
matrix_transform_t::NonTranspose,
|
||||
grid_raster_strategy::TiledCohort>
|
||||
{};
|
||||
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
90
cutlass/gemm/hgemm_global_tile.h
Normal file
90
cutlass/gemm/hgemm_global_tile.h
Normal file
@ -0,0 +1,90 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tile traits used to construct global tile iterator for HGEMM. This is intended to
|
||||
partition the thread block-level tile into 2D subtiles loaded by the threads and facilitate
|
||||
memory accesses larger than 16 bits.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/matrix_traits.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GemmOperand::Kind kOperand_,
|
||||
MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename Tile_,
|
||||
typename Threads_,
|
||||
int kAccessSize_>
|
||||
struct HgemmCrosswiseGlobalTileTraits : public GemmGlobalTileTraits<
|
||||
// Which GEMM operand?
|
||||
kOperand_,
|
||||
// The layout.
|
||||
kLayout_,
|
||||
// The scalar.
|
||||
Scalar_,
|
||||
// The tile.
|
||||
Tile_,
|
||||
// The threads.
|
||||
Threads_,
|
||||
// The number of scalars per LDG/STG.
|
||||
kAccessSize_> {
|
||||
/// The base class.
|
||||
typedef GemmGlobalTileTraits<kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_> Base;
|
||||
/// The threads.
|
||||
typedef typename Base::Threads Threads;
|
||||
/// The threads strides.
|
||||
typedef Shape<1, 2, Base::Tile::kC> ThreadsDelta;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<Base::Threads::kH * 2, 1, Base::Threads::kW, Base::kAccessSize> Delta;
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<Base::Tile::kH / Base::Threads::kH / 2,
|
||||
2,
|
||||
Base::Tile::kW / Base::Threads::kW,
|
||||
Base::Tile::kC / Base::kAccessSize>
|
||||
Iterations;
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
|
||||
int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
|
||||
|
||||
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
104
cutlass/gemm/hgemm_multiply_add.h
Normal file
104
cutlass/gemm/hgemm_multiply_add.h
Normal file
@ -0,0 +1,104 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Specialization implementing multiply-add operation on half-precision floating point
|
||||
fragments.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
|
||||
#include <cutlass/gemm/thread_multiply_add.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Template performing matrix multiply-add operation within a thread
|
||||
template <typename AccumulatorsPerThread_, typename ThreadsPerWarp_>
|
||||
struct ThreadMultiplyAdd<AccumulatorsPerThread_, ThreadsPerWarp_, half, half, half> {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<1, 1, 2, 1> InstructionShape;
|
||||
/// The number of accumulators per thread.
|
||||
typedef AccumulatorsPerThread_ AccumulatorsPerThread;
|
||||
/// The number of threads per warp.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of accumulators per warp.
|
||||
typedef typename ShapeMul<AccumulatorsPerThread, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
|
||||
/// The type for A.
|
||||
typedef half ScalarA;
|
||||
/// The fragment for A.
|
||||
typedef Fragment<ScalarA, AccumulatorsPerThread::kW> FragmentA;
|
||||
/// The type for B.
|
||||
typedef half ScalarB;
|
||||
/// The fragment for B.
|
||||
typedef Fragment<ScalarB, AccumulatorsPerThread::kH> FragmentB;
|
||||
/// The type for C and D.
|
||||
typedef half ScalarC;
|
||||
/// The accumulators.
|
||||
typedef Fragment<half, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW> Accumulators;
|
||||
|
||||
/// Make sure there's an even number of elements in both dimensions.
|
||||
static_assert(AccumulatorsPerThread::kH % 2 == 0, "Invalid size");
|
||||
static_assert(AccumulatorsPerThread::kW % 2 == 0, "Invalid size");
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ThreadMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b + c.
|
||||
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
|
||||
FragmentB const& b,
|
||||
Accumulators const& c,
|
||||
Accumulators& d) {
|
||||
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
|
||||
// The inputs.
|
||||
__half2 const* a_half2 = reinterpret_cast<__half2 const*>(&a[0]);
|
||||
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
|
||||
__half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);
|
||||
|
||||
// The output.
|
||||
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
|
||||
|
||||
for (int j = 0; j < AccumulatorsPerThread::kH / 2; ++j) {
|
||||
for (int i = 0; i < AccumulatorsPerThread::kW / 2; ++i) {
|
||||
// The offsets in the output fragment.
|
||||
int const k0 = (2 * j + 0) * (AccumulatorsPerThread::kW / 2) + i;
|
||||
int const k1 = (2 * j + 1) * (AccumulatorsPerThread::kW / 2) + i;
|
||||
|
||||
// Compute the product a[i] * b[j].H0_H0.
|
||||
d_half2[k0] = __hfma2(a_half2[i], __low2half2(b_half2[j]), c_half2[k0]);
|
||||
// Compute the product a[i] * b[j].H1_H1.
|
||||
d_half2[k1] = __hfma2(a_half2[i], __high2half2(b_half2[j]), c_half2[k1]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
94
cutlass/gemm/hgemm_swizzle.h
Normal file
94
cutlass/gemm/hgemm_swizzle.h
Normal file
@ -0,0 +1,94 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Transposes a tile of 16b elements. Used by HGEMM to construct a K-strided layout in
|
||||
shared memory for multiplicands.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cutlass/fragment.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GlobalIterator_>
|
||||
struct HgemmSwizzle {
|
||||
/// The global iterator.
|
||||
typedef GlobalIterator_ GlobalIterator;
|
||||
/// The source fragment.
|
||||
typedef typename GlobalIterator::Fragment Fragment;
|
||||
/// The shape of the source fragment.
|
||||
typedef typename GlobalIterator::FragmentShape FragmentShape;
|
||||
|
||||
/// The input fragment.
|
||||
typedef Fragment InputFragment;
|
||||
/// The output fragment.
|
||||
typedef Fragment OutputFragment;
|
||||
|
||||
/// The src/dst must be half fragments.
|
||||
static_assert((platform::is_same<typename Fragment::Element, half>::value), "Works on half");
|
||||
|
||||
/// The number of elements must be a multiple of 2.
|
||||
static_assert(FragmentShape::kH == 2 && ShapeCount<FragmentShape>::kWc == 2, "Not multiple of 2");
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE HgemmSwizzle() {}
|
||||
|
||||
/// Transform a fragment.
|
||||
CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) {
|
||||
// Expose src/dst as int arrays.
|
||||
int const* src_int = reinterpret_cast<int const*>(&src[0]);
|
||||
int* dst_int = reinterpret_cast<int*>(&dst[0]);
|
||||
|
||||
// Transpose the data.
|
||||
for (int d = 0; d < FragmentShape::kD; ++d) {
|
||||
// The indices to read two consecutive "rows".
|
||||
int const i0 = 2 * d + 0;
|
||||
int const i1 = 2 * d + 1;
|
||||
|
||||
int a0 = src_int[i0];
|
||||
int a1 = src_int[i1];
|
||||
|
||||
int b0, b1;
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(a0), "r"(a1));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x7632;" : "=r"(b1) : "r"(a0), "r"(a1));
|
||||
|
||||
// The indices to store with "strides".
|
||||
int const j0 = 0 * (ShapeCount<FragmentShape>::kDhw / 2) + d;
|
||||
int const j1 = 1 * (ShapeCount<FragmentShape>::kDhw / 2) + d;
|
||||
|
||||
dst_int[j0] = b0;
|
||||
dst_int[j1] = b1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
397
cutlass/gemm/hgemm_traits.h
Normal file
397
cutlass/gemm/hgemm_traits.h
Normal file
@ -0,0 +1,397 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defies structural properties of half-precision GEMM computation.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/gemm_epilogue.h>
|
||||
#include <cutlass/gemm/gemm_epilogue_traits.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/gemm/gemm_shared_tile.h>
|
||||
#include <cutlass/gemm/gemm_traits.h>
|
||||
#include <cutlass/gemm/hgemm_global_tile.h>
|
||||
#include <cutlass/gemm/hgemm_multiply_add.h>
|
||||
#include <cutlass/gemm/hgemm_swizzle.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_ = 2,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_ = 2>
|
||||
struct HgemmConfig
|
||||
: public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
half,
|
||||
/// The scalar type for B.
|
||||
half,
|
||||
/// The scalar type for C.
|
||||
half,
|
||||
/// The scalar type for D.
|
||||
half,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, half, half, half>,
|
||||
/// The number of scalars per LDG for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDS for A.
|
||||
8,
|
||||
/// The number of scalars per LDG for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
8,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
2,
|
||||
/// The number of scalars per STS for D.
|
||||
8,
|
||||
/// The number of scalars per LDS for D.
|
||||
2,
|
||||
/// The number of stages in shared memory.
|
||||
2> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
|
||||
struct HgemmTransformerA {};
|
||||
|
||||
template <typename Iterator_>
|
||||
struct HgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
|
||||
typedef Convert<typename Iterator_::Fragment, typename Iterator_::Fragment> Transformer;
|
||||
};
|
||||
|
||||
template <typename Iterator_>
|
||||
struct HgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
|
||||
typedef HgemmSwizzle<Iterator_> Transformer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
|
||||
struct HgemmTransformerB {};
|
||||
|
||||
template <typename Iterator_>
|
||||
struct HgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
|
||||
typedef Convert<typename Iterator_::Fragment, typename Iterator_::Fragment> Transformer;
|
||||
};
|
||||
|
||||
template <typename Iterator_>
|
||||
struct HgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
|
||||
typedef HgemmSwizzle<Iterator_> Transformer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
|
||||
struct HgemmTileTraitsHelperA : public GemmTileTraitsHelperA<kLayout_, GemmConfig_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct HgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_>
|
||||
: public GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
/// The base config.
|
||||
typedef GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> Base;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for A^T.
|
||||
typedef HgemmCrosswiseGlobalTileTraits<
|
||||
GemmOperand::kA,
|
||||
// The layout.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer.
|
||||
half const,
|
||||
// The tile has size MxK in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as (threads / K ) x K (the traits may reorganize).
|
||||
Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc)
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkewA = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^T.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits <
|
||||
// The pointer.
|
||||
half,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
|
||||
GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
|
||||
// The threads are distributed as warps x 32(the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
2,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
kSkewA<GemmConfig_::kScalarsPerLdsA ? GemmConfig_::kScalarsPerLdsA : kSkewA>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for A^T.
|
||||
typedef GemmSharedLoadTileATraits<
|
||||
// The pointer.
|
||||
half const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
8,
|
||||
// The skew.
|
||||
SharedStoreTileTraits::kSkew>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
|
||||
struct HgemmTileTraitsHelperB : public GemmTileTraitsHelperB<kLayout_, GemmConfig_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct HgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_>
|
||||
: public GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
/// The base config.
|
||||
typedef GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> Base;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for B^N.
|
||||
typedef HgemmCrosswiseGlobalTileTraits<
|
||||
GemmOperand::kB,
|
||||
// The layout.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer.
|
||||
half const,
|
||||
// The tile has size KxN in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc)
|
||||
GemmConfig_::kScalarsPerLdgB>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew for B.
|
||||
static int const kSkewB = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits <
|
||||
// The pointer.
|
||||
half,
|
||||
// The tile has size KxN in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
|
||||
GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
2,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
kSkewB<GemmConfig_::kScalarsPerLdsB ? GemmConfig_::kScalarsPerLdsB : kSkewB>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for B^N.
|
||||
typedef GemmSharedLoadTileBTraits<
|
||||
// The pointer.
|
||||
half const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
8,
|
||||
// The skew.
|
||||
SharedStoreTileTraits::kSkew>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
typename OutputTile_,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_ = Shape<8, 8, 16>,
|
||||
/// The number of halfs loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 2,
|
||||
/// The number of halfs loaded in one LDG for B.
|
||||
int kScalarsPerLdgB_ = 2,
|
||||
/// The index.
|
||||
typename Index_ = int>
|
||||
struct HgemmTraitsHelper {
|
||||
/// The HGEMM config.
|
||||
typedef HgemmConfig<OutputTile_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_>
|
||||
GemmConfig;
|
||||
/// The GEMM config for A.
|
||||
typedef HgemmTileTraitsHelperA<kLayoutA_, GemmConfig> GemmTileTraitsHelperA;
|
||||
/// The GEMM config for B.
|
||||
typedef HgemmTileTraitsHelperB<kLayoutB_, GemmConfig> GemmTileTraitsHelperB;
|
||||
|
||||
/// The iterator to load A from global memory.
|
||||
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA::GlobalTileTraits, Index_>
|
||||
GlobalLoadIteratorA;
|
||||
/// The default transformer for A.
|
||||
typedef typename HgemmTransformerA<GemmTileTraitsHelperA::kLayout,
|
||||
GlobalLoadIteratorA>::Transformer GlobalTransformerA;
|
||||
/// The iterator to store A to shared memory.
|
||||
typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
|
||||
typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorA;
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA>
|
||||
GlobalLoadStreamA;
|
||||
|
||||
/// The iterator to load B from global memory.
|
||||
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB::GlobalTileTraits, Index_>
|
||||
GlobalLoadIteratorB;
|
||||
// The default transformer for B.
|
||||
typedef typename HgemmTransformerB<GemmTileTraitsHelperB::kLayout,
|
||||
GlobalLoadIteratorB>::Transformer GlobalTransformerB;
|
||||
/// The iterator to store B to shared memory.
|
||||
typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
|
||||
typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorB;
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB>
|
||||
GlobalLoadStreamB;
|
||||
|
||||
/// The iterator to load A from shared memory
|
||||
typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
|
||||
typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorA;
|
||||
/// The stream to load A from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorA> SharedLoadStreamA;
|
||||
/// The iterator to load B from shared memory.
|
||||
typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
|
||||
typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorB;
|
||||
/// The stream to load B from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorB> SharedLoadStreamB;
|
||||
|
||||
/// The functor to do the multiply-add in the main loop.
|
||||
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
|
||||
/// The object to clear accumulators.
|
||||
typedef ClearAccumulators<typename MultiplyAdd::ScalarC> ClearAccumulators;
|
||||
|
||||
/// The traits class for the epilogue.
|
||||
typedef SimplifiedGemmEpilogueTraits<GemmConfig, EpilogueFunctor_, Index_> GemmEpilogueTraits;
|
||||
/// The epilogue.
|
||||
typedef GemmEpilogue<GemmEpilogueTraits> Epilogue;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
typename OutputTile_ = Shape<8, 128, 128>,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<half>,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_ = Shape<8, 8, 16>,
|
||||
/// The number of halfs loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 2,
|
||||
/// The number of halfs loaded in one LDG for B.
|
||||
int kScalarsPerLdgB_ = 2,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The helper class.
|
||||
typename Helper_ = HgemmTraitsHelper<kLayoutA_,
|
||||
kLayoutB_,
|
||||
OutputTile_,
|
||||
EpilogueFunctor_,
|
||||
AccumulatorsPerThread_,
|
||||
kScalarsPerLdgA_,
|
||||
kScalarsPerLdgB_,
|
||||
Index_> >
|
||||
struct HgemmTraits : public GemmTraits<
|
||||
// The config.
|
||||
typename Helper_::GemmConfig,
|
||||
// The stream to load A from global memory to shared memory.
|
||||
typename Helper_::GlobalLoadStreamA,
|
||||
// The stream to load B from global memory to shared memory.
|
||||
typename Helper_::GlobalLoadStreamB,
|
||||
// The stream to load A from shared memory.
|
||||
typename Helper_::SharedLoadStreamA,
|
||||
// The stream to load B from shared memory.
|
||||
typename Helper_::SharedLoadStreamB,
|
||||
// The epilogue.
|
||||
typename Helper_::Epilogue,
|
||||
// The block swizzle to reorganize the grid.
|
||||
IdentityBlockSwizzle,
|
||||
// The index.
|
||||
Index_,
|
||||
// The tool used to clear accumulators.
|
||||
typename Helper_::ClearAccumulators> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
48
cutlass/gemm/identity_block_swizzle.h
Normal file
48
cutlass/gemm/identity_block_swizzle.h
Normal file
@ -0,0 +1,48 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defies functors for mapping blockIdx to partitions of the GEMM computation.
|
||||
|
||||
Currently, we only implement an identity mapping.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct IdentityBlockSwizzle {
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE IdentityBlockSwizzle() {}
|
||||
|
||||
/// Swizzle the block index.
|
||||
CUTLASS_DEVICE dim3 swizzle() { return blockIdx; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
320
cutlass/gemm/igemm_epilogue.h
Normal file
320
cutlass/gemm/igemm_epilogue.h
Normal file
@ -0,0 +1,320 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines the epilogue phase of the GEMM computation for IGEMM, supporting integer and
|
||||
floating-point output matrix formats.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/fragment.h>
|
||||
#include <cutlass/gemm/gemm_global_stream.h>
|
||||
#include <cutlass/gemm/gemm_shared_stream.h>
|
||||
#include <cutlass/gemm/igemm_global_tile.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include <cutlass/tile_iterator.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int kElements_>
|
||||
struct IgemmFloatToInt8Converter {
|
||||
/// The input fragment.
|
||||
typedef Fragment<float, kElements_> InputFragment;
|
||||
/// The output fragment.
|
||||
typedef Fragment<int8_t, kElements_> OutputFragment;
|
||||
|
||||
// We are packing 4 floats into int32 registers so we need kElements to be multiple of 4.
|
||||
static_assert(kElements_ % 4 == 0, "kElements must be multiple of 4");
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE IgemmFloatToInt8Converter() {}
|
||||
|
||||
/// Transform a fragment.
|
||||
CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
|
||||
transform(src, 0, dst);
|
||||
}
|
||||
|
||||
/// Transform a fragment.
|
||||
template <typename Fragment_>
|
||||
CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
|
||||
// The inputs.
|
||||
float4 const* src_f4 = reinterpret_cast<float4 const*>(&src[0]);
|
||||
// The outputs.
|
||||
int* dst_int = reinterpret_cast<int*>(&dst[0]);
|
||||
|
||||
// Iterate over the floats and pack them together to produce ints.
|
||||
for (int i = 0; i < kElements_ / 4; ++i) {
|
||||
// Read the float4.
|
||||
float4 f4 = src_f4[i];
|
||||
|
||||
// Clamp the 4 elements of the floats to the [-128, +127] range.
|
||||
float x = fmaxf(-128.f, fminf(127.f, f4.x));
|
||||
float y = fmaxf(-128.f, fminf(127.f, f4.y));
|
||||
float z = fmaxf(-128.f, fminf(127.f, f4.z));
|
||||
float w = fmaxf(-128.f, fminf(127.f, f4.w));
|
||||
|
||||
// Convert to integers.
|
||||
int ix = (int)x;
|
||||
int iy = (int)y;
|
||||
int iz = (int)z;
|
||||
int iw = (int)w;
|
||||
|
||||
// Extract the lower bytes to build an int32 with 4 int8.
|
||||
asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(ix) : "r"(iy));
|
||||
asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(iz) : "r"(iw));
|
||||
asm volatile("prmt.b32 %0, %0, %1, 0x5410;" : "+r"(ix) : "r"(iz));
|
||||
|
||||
// Store the int.
|
||||
dst_int[i] = ix;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename InputScalar_, typename OutputFragment_>
|
||||
struct IgemmGlobalStoreTransformer {
|
||||
typedef Convert<Fragment<InputScalar_, OutputFragment_::kElements>, OutputFragment_> Transformer;
|
||||
};
|
||||
|
||||
template <int kElements_>
|
||||
struct IgemmGlobalStoreTransformer<float, Fragment<int8_t, kElements_> > {
|
||||
typedef IgemmFloatToInt8Converter<kElements_> Transformer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int kElements_>
|
||||
struct IgemmInt8ToFloatConverter {
|
||||
/// The input fragment.
|
||||
typedef Fragment<int8_t, kElements_> InputFragment;
|
||||
/// The output fragment.
|
||||
typedef Fragment<float, kElements_> OutputFragment;
|
||||
|
||||
// We are unpacking 4 int8s from int32.
|
||||
static_assert(kElements_ % 4 == 0, "kElements must be multiple of 4");
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE IgemmInt8ToFloatConverter() {}
|
||||
|
||||
/// Transform a fragment.
|
||||
CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
|
||||
transform(src, 0, dst);
|
||||
}
|
||||
|
||||
/// Transform a fragment.
|
||||
template <typename Fragment_>
|
||||
CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
|
||||
// The inputs.
|
||||
int const* src_int = reinterpret_cast<int const*>(&src[0]);
|
||||
// The outputs.
|
||||
float4* dst_f4 = reinterpret_cast<float4*>(&dst[0]);
|
||||
|
||||
// Iterate over the int8 and unpack them together to produce floats.
|
||||
for (int i = 0; i < kElements_ / 4; ++i) {
|
||||
// Read the int.
|
||||
int ix, iy, iz, iw = src_int[i];
|
||||
|
||||
// Extract the 4 bytes.
|
||||
asm volatile("prmt.b32 %0, 0x0, %1, 0x4440;" : "=r"(ix) : "r"(iw));
|
||||
asm volatile("prmt.b32 %0, 0x0, %1, 0x4441;" : "=r"(iy) : "r"(iw));
|
||||
asm volatile("prmt.b32 %0, 0x0, %1, 0x4442;" : "=r"(iz) : "r"(iw));
|
||||
asm volatile("prmt.b32 %0, 0x0, %1, 0x4443;" : "=r"(iw) : "r"(iw));
|
||||
|
||||
// The floats.
|
||||
float fx, fy, fz, fw;
|
||||
|
||||
// Convert to floats (make sure we generate I2F.F32.S8).
|
||||
asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fx) : "r"(ix));
|
||||
asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fy) : "r"(iy));
|
||||
asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fz) : "r"(iz));
|
||||
asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fw) : "r"(iw));
|
||||
|
||||
// Store the float4.
|
||||
dst_f4[i] = make_float4(fx, fy, fz, fw);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename InputFragment_, typename OutputScalar_>
|
||||
struct IgemmGlobalLoadTransformer {
|
||||
typedef Convert<InputFragment_, Fragment<OutputScalar_, InputFragment_::kElements> > Transformer;
|
||||
};
|
||||
|
||||
template <int kElements_>
|
||||
struct IgemmGlobalLoadTransformer<Fragment<int8_t, kElements_>, float> {
|
||||
typedef IgemmInt8ToFloatConverter<kElements_> Transformer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename InputScalar_, typename OutputFragment_>
|
||||
struct IgemmSharedStoreTransformer {
|
||||
typedef Convert<Fragment<InputScalar_, OutputFragment_::kElements>, OutputFragment_> Transformer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename IgemmConfig_, typename EpilogueFunctor_, typename Index_>
|
||||
struct IgemmEpilogueTraitsHelper
|
||||
: public GemmEpilogueTraitsHelper<IgemmConfig_, EpilogueFunctor_, Index_> {
|
||||
/// The base class.
|
||||
typedef GemmEpilogueTraitsHelper<IgemmConfig_, EpilogueFunctor_, Index_> Base;
|
||||
/// The config.
|
||||
typedef IgemmConfig_ IgemmConfig;
|
||||
|
||||
/// The scalar type of the epilogue.
|
||||
typedef typename Base::Scalar Scalar;
|
||||
/// The iterations.
|
||||
typedef typename Base::Iterations Iterations;
|
||||
/// The iterations strides.
|
||||
typedef typename Base::Delta Delta;
|
||||
|
||||
/// The traits class for the iterator.
|
||||
typedef typename Base::GlobalLoadTileTraits GlobalLoadTileTraits;
|
||||
/// The iterator to store to shared memory.
|
||||
typedef GemmGlobalIteratorCd<GlobalLoadTileTraits> GlobalLoadIteratorC;
|
||||
/// The fragment that needs to be produced by the load iterator.
|
||||
typedef typename GlobalLoadIteratorC::Fragment GlobalFragmentC;
|
||||
/// The transformer from loaded data to math fragment.
|
||||
typedef
|
||||
typename IgemmGlobalLoadTransformer<GlobalFragmentC, Scalar>::Transformer GlobalTransformerC;
|
||||
|
||||
/// The traits class for the iterator.
|
||||
typedef typename Base::GlobalStoreTileTraits GlobalStoreTileTraits;
|
||||
/// The iterator to store to shared memory.
|
||||
typedef GemmGlobalIteratorCd<GlobalStoreTileTraits> GlobalStoreIteratorD;
|
||||
/// The fragment that needs to be passed to that store iterator.
|
||||
typedef typename GlobalStoreIteratorD::Fragment GlobalFragmentD;
|
||||
/// The transformer from accumulators to shared memory fragments.
|
||||
typedef
|
||||
typename IgemmGlobalStoreTransformer<Scalar, GlobalFragmentD>::Transformer GlobalTransformerD;
|
||||
|
||||
/// The traits class for the shared iterator to store D to shared memory.
|
||||
typedef typename Base::SharedStoreTileTraits SharedStoreTileTraits;
|
||||
/// The shared iterator to store D to shared memory.
|
||||
typedef TileStoreIterator<SharedStoreTileTraits,
|
||||
typename SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kGlobal>
|
||||
SharedStoreIteratorD;
|
||||
/// The fragment that needs to be passed to that store iterator.
|
||||
typedef typename SharedStoreIteratorD::Fragment SharedStoreFragmentD;
|
||||
/// The transformer from accumulators to shared memory fragments.
|
||||
typedef typename IgemmSharedStoreTransformer<typename IgemmConfig::Accumulators::Element,
|
||||
SharedStoreFragmentD>::Transformer
|
||||
SharedStoreTransformerD;
|
||||
/// The traits class for the shared iterator to load D from shared memory.
|
||||
typedef typename Base::SharedLoadTileTraits SharedLoadTileTraits;
|
||||
/// The shared iterator to load D from shared memory.
|
||||
typedef TileLoadIterator<SharedLoadTileTraits,
|
||||
typename SharedLoadTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorD;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The config.
|
||||
typename IgemmConfig_,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The helper class to assemble the traits.
|
||||
typename Helper_ = IgemmEpilogueTraitsHelper<IgemmConfig_, EpilogueFunctor_, Index_> >
|
||||
struct IgemmEpilogueTraits : public GemmEpilogueTraits<
|
||||
// The output tile.
|
||||
typename IgemmConfig_::OutputTile,
|
||||
// The accumulators.
|
||||
typename IgemmConfig_::Accumulators,
|
||||
// The global iterator for C.
|
||||
typename Helper_::GlobalLoadIteratorC,
|
||||
// The transformer for C.
|
||||
typename Helper_::GlobalTransformerC,
|
||||
// The transformer for D.
|
||||
typename Helper_::GlobalTransformerD,
|
||||
// The global iterator for D.
|
||||
typename Helper_::GlobalStoreIteratorD,
|
||||
// The iterator to store D to shared memory.
|
||||
typename Helper_::SharedStoreIteratorD,
|
||||
// The shared store transformer for D.
|
||||
typename Helper_::SharedStoreTransformerD,
|
||||
// The iterator to load D from shared memory.
|
||||
typename Helper_::SharedLoadIteratorD,
|
||||
// The iterations.
|
||||
typename Helper_::Iterations,
|
||||
// The strides between iterations.
|
||||
typename Helper_::Delta,
|
||||
// The functor to be used in the epilogue.
|
||||
EpilogueFunctor_,
|
||||
// The index.
|
||||
Index_> {
|
||||
/// Do we output in int8?
|
||||
static bool const kInt8Output =
|
||||
platform::is_same<typename IgemmConfig_::ScalarC, int8_t>::value != 0;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmEpilogueTraits_, bool = GemmEpilogueTraits_::kInt8Output>
|
||||
struct IgemmEpilogue : public GemmEpilogue<GemmEpilogueTraits_> {
|
||||
/// The base class.
|
||||
typedef GemmEpilogue<GemmEpilogueTraits_> Base;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const& params_,
|
||||
typename Base::SharedStorage& shared_storage_,
|
||||
typename Base::Index m_,
|
||||
typename Base::Index n_)
|
||||
: Base(params_, shared_storage_, m_, n_) {}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmEpilogueTraits_>
|
||||
struct IgemmEpilogue<GemmEpilogueTraits_, true> : public GemmEpilogue<GemmEpilogueTraits_> {
|
||||
/// The base class.
|
||||
typedef GemmEpilogue<GemmEpilogueTraits_> Base;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const& params_,
|
||||
typename Base::SharedStorage& shared_storage_,
|
||||
typename Base::Index m_,
|
||||
typename Base::Index n_)
|
||||
: Base(params_, shared_storage_, m_, n_) {}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
161
cutlass/gemm/igemm_global_tile.h
Normal file
161
cutlass/gemm/igemm_global_tile.h
Normal file
@ -0,0 +1,161 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements tile iterators to partition the thread block tile into 2D subtiles and
|
||||
efficiently load each. Applies permute transformation to construct 'interleaved K-strided'
|
||||
data layout in which 4-element dot products from the same K index are arranged in consecutive
|
||||
locations within shared memory.
|
||||
|
||||
Supports efficient loads from shared memory to target the DP4A instruction.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/matrix_traits.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <GemmOperand::Kind kOperand_,
|
||||
MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename Tile_,
|
||||
typename Threads_,
|
||||
int kAccessSize_>
|
||||
struct IgemmGlobalTileTraits : public GemmGlobalTileTraits<
|
||||
// Which GEMM operand?
|
||||
kOperand_,
|
||||
// The layout.
|
||||
kLayout_,
|
||||
// The scalar.
|
||||
Scalar_,
|
||||
// The tile.
|
||||
Tile_,
|
||||
// The threads.
|
||||
Threads_,
|
||||
// The number of scalars per LDG/STG.
|
||||
kAccessSize_> {
|
||||
/// The base class.
|
||||
typedef GemmGlobalTileTraits<kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_> Base;
|
||||
/// The threads.
|
||||
typedef typename Base::Threads Threads;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<Base::Threads::kH * 4, 1, Base::Threads::kW, Base::kAccessSize> Delta;
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<Base::Tile::kH / Base::Threads::kH / 4,
|
||||
4,
|
||||
Base::Tile::kW / Base::Threads::kW,
|
||||
Base::Tile::kC / Base::kAccessSize>
|
||||
Iterations;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
|
||||
int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
|
||||
|
||||
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
/// The threads strides.
|
||||
typedef Shape<1, 4, Base::Tile::kC> ThreadsDelta;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Deprecated. Please use IgemmGlobalTileTraits instead.
|
||||
|
||||
template <GemmOperand::Kind kOperand_,
|
||||
MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename Tile_,
|
||||
typename Threads_,
|
||||
int kAccessSize_>
|
||||
struct IgemmContiguousGlobalTileTraits
|
||||
: public IgemmGlobalTileTraits<kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TileTraits_, typename Index_ = int>
|
||||
struct IgemmGlobalIteratorAb : public GemmGlobalIteratorAb<TileTraits_, Index_> {
|
||||
/// The base class.
|
||||
typedef GemmGlobalIteratorAb<TileTraits_, Index_> Base;
|
||||
/// The functor to compute the thread offset.
|
||||
typedef typename TileTraits_::ThreadOffset ThreadOffset;
|
||||
|
||||
/// Constructor.
|
||||
CUTLASS_DEVICE IgemmGlobalIteratorAb(typename Base::Params const& _params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: Base(_params, bounds, block, thread_offset_func), in_residue_(false), mask_(0xffffffff) {
|
||||
// The number of elements read in a single iteration.
|
||||
int const kBlock = TileTraits_::Tile::kW * TileTraits_::kAccessSize;
|
||||
// The residue.
|
||||
int const kResidue = (int)(bounds[1] % kBlock);
|
||||
|
||||
// Compute the number of elements that are valid.
|
||||
int const left = kResidue - Base::thread_offset[2];
|
||||
if (left > 0 && left < 4) {
|
||||
mask_ = (1u << (8 * left)) - 1u;
|
||||
}
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
Base::get(value, d, h, w, c);
|
||||
if (in_residue_) {
|
||||
reinterpret_cast<uint32_t&>(value) &= mask_;
|
||||
}
|
||||
}
|
||||
|
||||
/// Move to residue portion.
|
||||
CUTLASS_DEVICE void move_to_residue(typename Base::Index k) {
|
||||
Base::move_to_residue(k);
|
||||
in_residue_ = true;
|
||||
}
|
||||
|
||||
/// Move back to the beginning of the first tile.
|
||||
CUTLASS_DEVICE void rollback() {
|
||||
Base::rollback();
|
||||
in_residue_ = false;
|
||||
}
|
||||
|
||||
/// Are we in the residue?
|
||||
bool in_residue_;
|
||||
/// The mask to clean up the values.
|
||||
uint32_t mask_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
89
cutlass/gemm/igemm_multiply_add.h
Normal file
89
cutlass/gemm/igemm_multiply_add.h
Normal file
@ -0,0 +1,89 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements matrix multiply accumulate operation of 8-bit integer data using DP4A
|
||||
instruction.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
|
||||
#include <cutlass/gemm/thread_multiply_add.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Template performing matrix multiply-add operation within a thread
|
||||
template <typename AccumulatorsPerThread_, typename ThreadsPerWarp_>
|
||||
struct ThreadMultiplyAdd<AccumulatorsPerThread_, ThreadsPerWarp_, int8_t, int8_t, int> {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<4, 1, 1> InstructionShape;
|
||||
/// The number of accumulators per thread.
|
||||
typedef AccumulatorsPerThread_ AccumulatorsPerThread;
|
||||
/// The number of threads per warp.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of accumulators per warp.
|
||||
typedef typename ShapeMul<AccumulatorsPerThread, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
|
||||
/// The type for A.
|
||||
typedef int8_t ScalarA;
|
||||
/// The fragment for A.
|
||||
typedef Fragment<ScalarA, AccumulatorsPerThread::kW * 4> FragmentA;
|
||||
/// The type for B.
|
||||
typedef int8_t ScalarB;
|
||||
/// The fragment for B.
|
||||
typedef Fragment<ScalarB, AccumulatorsPerThread::kH * 4> FragmentB;
|
||||
/// The type for C and D.
|
||||
typedef int ScalarC;
|
||||
/// The accumulators.
|
||||
typedef Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW> Accumulators;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ThreadMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b + c.
|
||||
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
|
||||
FragmentB const& b,
|
||||
Accumulators const& c,
|
||||
Accumulators& d) {
|
||||
// The inputs.
|
||||
int const* a_int = reinterpret_cast<int const*>(&a[0]);
|
||||
int const* b_int = reinterpret_cast<int const*>(&b[0]);
|
||||
|
||||
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
|
||||
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
|
||||
asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
|
||||
: "=r"(d[j * AccumulatorsPerThread::kW + i])
|
||||
: "r"(a_int[i]), "r"(b_int[j]), "r"(c[j * AccumulatorsPerThread::kW + i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
115
cutlass/gemm/igemm_swizzle.h
Normal file
115
cutlass/gemm/igemm_swizzle.h
Normal file
@ -0,0 +1,115 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Transposes a fragment of data containing packed 8-bit integer elements.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GlobalIterator_>
|
||||
struct IgemmSwizzle {
|
||||
/// The global iterator.
|
||||
typedef GlobalIterator_ GlobalIterator;
|
||||
/// The source fragment.
|
||||
typedef typename GlobalIterator::Fragment Fragment;
|
||||
/// The shape of the source fragment.
|
||||
typedef typename GlobalIterator::FragmentShape FragmentShape;
|
||||
|
||||
/// The source fragment.
|
||||
typedef Fragment InputFragment;
|
||||
/// The destination fragment.
|
||||
typedef Fragment OutputFragment;
|
||||
|
||||
/// The src/dst must be int8 fragments.
|
||||
static_assert((platform::is_same<typename Fragment::Element, int8_t>::value), "Works on int8");
|
||||
|
||||
/// The number of elements must be a multiple of 4.
|
||||
static_assert(FragmentShape::kH % 4 == 0 && ShapeCount<FragmentShape>::kWc % 4 == 0,
|
||||
"Not multiple of 4");
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE IgemmSwizzle() {}
|
||||
|
||||
/// Transform a fragment.
|
||||
CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) {
|
||||
// Expose src/dst as int arrays.
|
||||
int const* src_int = reinterpret_cast<int const*>(&src[0]);
|
||||
int* dst_int = reinterpret_cast<int*>(&dst[0]);
|
||||
|
||||
// Transpose the data.
|
||||
for (int d = 0; d < FragmentShape::kD; ++d) {
|
||||
for (int h = 0; h < FragmentShape::kH / 4; ++h) {
|
||||
for (int w = 0; w < ShapeCount<FragmentShape>::kWc / 4; ++w) {
|
||||
int const i0 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
|
||||
(4 * h + 0) * (ShapeCount<FragmentShape>::kWc / 4) + w;
|
||||
int const i1 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
|
||||
(4 * h + 1) * (ShapeCount<FragmentShape>::kWc / 4) + w;
|
||||
int const i2 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
|
||||
(4 * h + 2) * (ShapeCount<FragmentShape>::kWc / 4) + w;
|
||||
int const i3 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
|
||||
(4 * h + 3) * (ShapeCount<FragmentShape>::kWc / 4) + w;
|
||||
|
||||
int a0 = src_int[i0];
|
||||
int a1 = src_int[i1];
|
||||
int a2 = src_int[i2];
|
||||
int a3 = src_int[i3];
|
||||
|
||||
int b0, b1, b2, b3, c0;
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(b0) : "r"(a0), "r"(a1));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(c0) : "r"(a2), "r"(a3));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(b0), "r"(c0));
|
||||
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(b1) : "r"(a0), "r"(a1));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(c0) : "r"(a2), "r"(a3));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b1) : "r"(b1), "r"(c0));
|
||||
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(b2) : "r"(a0), "r"(a1));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(c0) : "r"(a2), "r"(a3));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b2) : "r"(b2), "r"(c0));
|
||||
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(b3) : "r"(a0), "r"(a1));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(c0) : "r"(a2), "r"(a3));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b3) : "r"(b3), "r"(c0));
|
||||
|
||||
dst_int[i0] = b0;
|
||||
dst_int[i1] = b1;
|
||||
dst_int[i2] = b2;
|
||||
dst_int[i3] = b3;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
539
cutlass/gemm/igemm_traits.h
Normal file
539
cutlass/gemm/igemm_traits.h
Normal file
@ -0,0 +1,539 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defies structural properties of mixed-precision integer GEMM. Multiplicands are assumed
|
||||
to be packed 8bit integers, accumulators are assumed to be 32b signed integers, and output
|
||||
formats vary.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/gemm_epilogue.h>
|
||||
#include <cutlass/gemm/gemm_epilogue_traits.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/gemm/gemm_shared_tile.h>
|
||||
#include <cutlass/gemm/gemm_traits.h>
|
||||
#include <cutlass/gemm/igemm_epilogue.h>
|
||||
#include <cutlass/gemm/igemm_global_tile.h>
|
||||
#include <cutlass/gemm/igemm_multiply_add.h>
|
||||
#include <cutlass/gemm/igemm_swizzle.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// The output type.
|
||||
typename ScalarD_,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_>
|
||||
struct IgemmConfig
|
||||
: public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
int8_t,
|
||||
/// The scalar type for B.
|
||||
int8_t,
|
||||
/// The scalar type for C.
|
||||
ScalarD_,
|
||||
/// The scalar type for D.
|
||||
ScalarD_,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, int8_t, int8_t, int>,
|
||||
/// The number of scalars per LDG for A.
|
||||
4,
|
||||
/// The number of scalars per STS for A.
|
||||
4,
|
||||
/// The number of scalars per LDS for A.
|
||||
16,
|
||||
/// The number of scalars per LDG for B.
|
||||
4,
|
||||
/// The number of scalars per STS for B.
|
||||
4,
|
||||
/// The number of scalars per LDS for B.
|
||||
16,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
1,
|
||||
/// The number of scalars per STS for D.
|
||||
4,
|
||||
/// The number of scalars per LDS for D.
|
||||
1,
|
||||
/// The number of stages in shared memory.
|
||||
2,
|
||||
/// Enable the code path that deals with the residue in epilogue.
|
||||
true> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename OutputTile_, typename AccumulatorsPerThread_>
|
||||
struct IgemmConfig<OutputTile_, int8_t, AccumulatorsPerThread_>
|
||||
: public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
int8_t,
|
||||
/// The scalar type for B.
|
||||
int8_t,
|
||||
/// The scalar type for C.
|
||||
int8_t,
|
||||
/// The scalar type for D.
|
||||
int8_t,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, int8_t, int8_t, int>,
|
||||
/// The number of scalars per LDG for A.
|
||||
4,
|
||||
/// The number of scalars per STS for A.
|
||||
4,
|
||||
/// The number of scalars per LDS for A.
|
||||
16,
|
||||
/// The number of scalars per LDG for B.
|
||||
4,
|
||||
/// The number of scalars per STS for B.
|
||||
4,
|
||||
/// The number of scalars per LDS for B.
|
||||
16,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
4,
|
||||
/// The number of scalars per STS for D.
|
||||
4,
|
||||
/// The number of scalars per LDS for D.
|
||||
4,
|
||||
/// The number of stages in shared memory.
|
||||
2,
|
||||
/// Enable the code path that deals with the residue in epilogue.
|
||||
true> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_, typename Index_>
|
||||
struct IgemmTileTraitsHelperA : public GemmTileTraitsHelperA<kLayout_, GemmConfig_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_, typename Index_>
|
||||
struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_, Index_>
|
||||
: public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
/// The base config.
|
||||
typedef GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> Base;
|
||||
|
||||
/// The number of scalars per LDG/STS/LDS for A.
|
||||
static int const kScalarsPerStsA = 16;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for A^N.
|
||||
typedef IgemmGlobalTileTraits<
|
||||
GemmOperand::kA,
|
||||
// The layout.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer is float const.
|
||||
int8_t const,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kW>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
// The iterator.
|
||||
typedef GemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer is float.
|
||||
int8_t,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kW * 4>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
kScalarsPerStsA>
|
||||
SharedStoreTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_, typename Index_>
|
||||
struct IgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Index_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef int8_t Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef int8_t MultiplyAddScalar;
|
||||
|
||||
/// The number of scalars per LDG/STS/LDS for A.
|
||||
static int const kScalarsPerStsA = 16;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for A^T.
|
||||
typedef IgemmGlobalTileTraits<
|
||||
GemmOperand::kA,
|
||||
// The layout.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer is float const.
|
||||
int8_t const,
|
||||
// The tile has size NxK in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
// The iterator.
|
||||
typedef IgemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^N.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits<
|
||||
// The pointer is int8.
|
||||
int8_t,
|
||||
// The tile has size KxN in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kW * 4>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS.
|
||||
kScalarsPerStsA,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
16>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for A^N.
|
||||
typedef GemmSharedLoadTileATraits<
|
||||
// The pointer is float const.
|
||||
int8_t const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
16,
|
||||
// The skew.
|
||||
SharedStoreTileTraits::kSkew>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_, typename Index_>
|
||||
struct IgemmTileTraitsHelperB : public GemmTileTraitsHelperB<kLayout_, GemmConfig_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_, typename Index_>
|
||||
struct IgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Index_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef int8_t Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef int8_t MultiplyAddScalar;
|
||||
|
||||
/// The number of scalars per LDG/STS/LDS for B.
|
||||
static int const kScalarsPerStsB = 16;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for B^T.
|
||||
typedef IgemmGlobalTileTraits<
|
||||
GemmOperand::kB,
|
||||
// The layout.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer is float const.
|
||||
int8_t const,
|
||||
// The tile has size NxK in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgB>
|
||||
GlobalTileTraits;
|
||||
|
||||
// The iterator.
|
||||
typedef IgemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
typedef GemmSharedStoreWithSkewTileAbTraits<
|
||||
// The pointer is int8.
|
||||
int8_t,
|
||||
// The tile has size KxN in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kH * 4>,
|
||||
// The threads are distributed as (threads / K) x K (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS.
|
||||
kScalarsPerStsB,
|
||||
// The skew to avoid bank conflicts added in the tile W dimension.
|
||||
16>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for B^N.
|
||||
typedef GemmSharedLoadTileBTraits<
|
||||
// The pointer is float const.
|
||||
int8_t const,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The number of threads per warp.
|
||||
typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
|
||||
// The shape of the FMA instruction.
|
||||
typename GemmConfig_::InstructionShape,
|
||||
// The number of stages.
|
||||
GemmConfig_::kStages,
|
||||
// The number of scalars per LDS.
|
||||
16,
|
||||
// The skew.
|
||||
SharedStoreTileTraits::kSkew>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_, typename Index_>
|
||||
struct IgemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_, Index_>
|
||||
: public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
/// The base config.
|
||||
typedef GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> Base;
|
||||
|
||||
/// The number of scalars per LDG/STS/LDS for B.
|
||||
static int const kScalarsPerStsB = 16;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for B^T.
|
||||
typedef IgemmGlobalTileTraits<
|
||||
GemmOperand::kB,
|
||||
// The layout.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer is float const.
|
||||
int8_t const,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kH>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgB>
|
||||
GlobalTileTraits;
|
||||
|
||||
// The iterator.
|
||||
typedef GemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer is float.
|
||||
int8_t,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kH * 4>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
kScalarsPerStsB>
|
||||
SharedStoreTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
|
||||
struct IgemmTransformerA {};
|
||||
|
||||
template <typename Iterator_>
|
||||
struct IgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
|
||||
typedef Copy<typename Iterator_::Fragment> Transformer;
|
||||
};
|
||||
|
||||
template <typename Iterator_>
|
||||
struct IgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
|
||||
typedef IgemmSwizzle<Iterator_> Transformer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
|
||||
struct IgemmTransformerB {};
|
||||
|
||||
template <typename Iterator_>
|
||||
struct IgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
|
||||
typedef Copy<typename Iterator_::Fragment> Transformer;
|
||||
};
|
||||
|
||||
template <typename Iterator_>
|
||||
struct IgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
|
||||
typedef IgemmSwizzle<Iterator_> Transformer;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
typename OutputTile_,
|
||||
/// The output type.
|
||||
typename ScalarD_,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_ = Shape<32, 8, 8>,
|
||||
/// The index.
|
||||
typename Index_ = int>
|
||||
struct IgemmTraitsHelper {
|
||||
/// The IGEMM config.
|
||||
typedef IgemmConfig<OutputTile_, ScalarD_, AccumulatorsPerThread_> GemmConfig;
|
||||
/// The GEMM config for A.
|
||||
typedef IgemmTileTraitsHelperA<kLayoutA_, GemmConfig, Index_> GemmTileTraitsHelperA;
|
||||
/// The GEMM config for B.
|
||||
typedef IgemmTileTraitsHelperB<kLayoutB_, GemmConfig, Index_> GemmTileTraitsHelperB;
|
||||
|
||||
/// The iterator to load A from global memory.
|
||||
typedef typename GemmTileTraitsHelperA::GlobalLoadIterator GlobalLoadIteratorA;
|
||||
|
||||
/// The default transformer for A.
|
||||
typedef typename IgemmTransformerA<GemmTileTraitsHelperA::kLayout,
|
||||
GlobalLoadIteratorA>::Transformer GlobalTransformerA;
|
||||
/// The iterator to store A to shared memory.
|
||||
typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
|
||||
typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorA;
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA>
|
||||
GlobalLoadStreamA;
|
||||
|
||||
/// The iterator to load B from global memory.
|
||||
typedef typename GemmTileTraitsHelperB::GlobalLoadIterator GlobalLoadIteratorB;
|
||||
|
||||
// The default transformer for B.
|
||||
typedef typename IgemmTransformerB<GemmTileTraitsHelperB::kLayout,
|
||||
GlobalLoadIteratorB>::Transformer GlobalTransformerB;
|
||||
/// The iterator to store B to shared memory.
|
||||
typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
|
||||
typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorB;
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB>
|
||||
GlobalLoadStreamB;
|
||||
|
||||
/// The iterator to load A from shared memory.
|
||||
typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
|
||||
typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorA;
|
||||
/// The stream to load A from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorA, Copy<typename SharedLoadIteratorA::Fragment> >
|
||||
SharedLoadStreamA;
|
||||
/// The iterator to load B from shared memory.
|
||||
typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
|
||||
typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorB;
|
||||
/// The stream to load B from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorB, Copy<typename SharedLoadIteratorB::Fragment> >
|
||||
SharedLoadStreamB;
|
||||
|
||||
/// The multiply-add functor.
|
||||
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
|
||||
/// The object to clear accumulators.
|
||||
typedef ClearAccumulators<typename MultiplyAdd::ScalarC> ClearAccumulators;
|
||||
|
||||
/// The epilogue.
|
||||
typedef IgemmEpilogue<IgemmEpilogueTraits<GemmConfig, EpilogueFunctor_> > Epilogue;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename ScalarD_>
|
||||
struct IgemmEpilogueScalar {
|
||||
typedef float Scalar;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct IgemmEpilogueScalar<int> {
|
||||
typedef int Scalar;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
typename OutputTile_ = Shape<32, 128, 128>,
|
||||
/// The output type.
|
||||
typename ScalarD_ = int,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<typename IgemmEpilogueScalar<ScalarD_>::Scalar>,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_ = Shape<32, 8, 8>,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The helper class.
|
||||
typename Helper_ = IgemmTraitsHelper<kLayoutA_,
|
||||
kLayoutB_,
|
||||
OutputTile_,
|
||||
ScalarD_,
|
||||
EpilogueFunctor_,
|
||||
AccumulatorsPerThread_,
|
||||
Index_> >
|
||||
struct IgemmTraits : public GemmTraits<
|
||||
// The config.
|
||||
typename Helper_::GemmConfig,
|
||||
// The stream to load A from global memory to shared memory.
|
||||
typename Helper_::GlobalLoadStreamA,
|
||||
// The stream to load B from global memory to shared memory.
|
||||
typename Helper_::GlobalLoadStreamB,
|
||||
// The stream to load A from shared memory.
|
||||
typename Helper_::SharedLoadStreamA,
|
||||
// The stream to load B from shared memory.
|
||||
typename Helper_::SharedLoadStreamB,
|
||||
// The epilogue.
|
||||
typename Helper_::Epilogue,
|
||||
// The block swizzle to reorganize the grid.
|
||||
IdentityBlockSwizzle,
|
||||
// The index.
|
||||
Index_,
|
||||
// The tool used to clear accumulators.
|
||||
typename Helper_::ClearAccumulators> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,310 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* Abstraction for coordinating inter-block k-splitting
|
||||
*/
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "../util/util.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Storage and initialization
|
||||
******************************************************************************/
|
||||
|
||||
enum
|
||||
{
|
||||
NumFlagsSplitK = 4096
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* Global K-split semaphore flags
|
||||
*
|
||||
* TODO: use demand-allocated storage to provide copies for concurrent streams
|
||||
*/
|
||||
__device__ int d_flags_split_k[NumFlagsSplitK];
|
||||
|
||||
|
||||
/**
|
||||
* Preparation kernel for zero-initializing semaphore flags
|
||||
*/
|
||||
__global__ void prepare_kernel(int *d_flags_split_k)
|
||||
{
|
||||
int tid = (blockIdx.x * blockDim.x) + threadIdx.x;
|
||||
if (tid < NumFlagsSplitK)
|
||||
d_flags_split_k[tid] = 0;
|
||||
}
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* k_split_control
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief Abstraction for coordinating inter-block k-splitting
|
||||
*/
|
||||
struct k_split_control
|
||||
{
|
||||
/// Extent of a thread block's partition along the GEMM K-axis
|
||||
int split_k;
|
||||
|
||||
/// Whether or not to use a semaphore for inter-block k-splitting.
|
||||
bool use_semaphore;
|
||||
|
||||
/// Pointer to semaphore
|
||||
int *d_flags;
|
||||
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Device API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Return the thread block's starting coordinate (k) within the
|
||||
* multiplicand matrices
|
||||
*/
|
||||
inline __device__
|
||||
int block_begin_item_k()
|
||||
{
|
||||
return blockIdx.z * split_k;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Return the thread block's ending coordinate (k) within the multiplicand
|
||||
* matrices (one-past)
|
||||
*/
|
||||
inline __device__
|
||||
int block_end_item_k(int dim_k)
|
||||
{
|
||||
int next_start_k = block_begin_item_k() + split_k;
|
||||
return __NV_STD_MIN(next_start_k, dim_k);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Whether the thread block is a secondary accumulator in an inter-block
|
||||
* k-splitting scheme
|
||||
*/
|
||||
inline __device__
|
||||
bool is_secondary_accumulator()
|
||||
{
|
||||
return (blockIdx.z > 0);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Wait for predecessor thread block(s) to produce the exclusive
|
||||
* partial-sums for this block-wide tile
|
||||
*/
|
||||
inline __device__
|
||||
void wait()
|
||||
{
|
||||
// Wait on semaphore
|
||||
if ((use_semaphore) && (blockIdx.z > 0))
|
||||
{
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
int bid = (blockIdx.y * gridDim.x) + blockIdx.x;
|
||||
int hash = bid % NumFlagsSplitK;
|
||||
int found;
|
||||
int looking = blockIdx.z;
|
||||
while (true)
|
||||
{
|
||||
asm volatile ("ld.global.cg.u32 %0, [%1];\n" : "=r"(found) : "l"(d_flags + hash));
|
||||
|
||||
if (found == looking)
|
||||
break;
|
||||
|
||||
/// Fence to keep load from being hoisted from the loop
|
||||
__syncwarp(0x00000001);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Signal the successor thread_block(s) that the inclusive partial-sums
|
||||
* from this block-wide tile are available
|
||||
*/
|
||||
inline __device__
|
||||
void signal()
|
||||
{
|
||||
if (use_semaphore)
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
int bid = (blockIdx.y * gridDim.x) + blockIdx.x;
|
||||
int hash = bid % NumFlagsSplitK;
|
||||
int val = blockIdx.z + 1;
|
||||
|
||||
asm volatile ("st.global.cg.u32 [%0], %1;\n" : : "l"(d_flags + hash), "r"(val));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Grid launch API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*/
|
||||
inline
|
||||
k_split_control(
|
||||
int *d_flags,
|
||||
int sm_count,
|
||||
int max_sm_occupancy,
|
||||
int dim_k,
|
||||
int block_tile_items_k,
|
||||
dim3 block_dims,
|
||||
dim3 &grid_dims) ///< [in,out]
|
||||
:
|
||||
d_flags(d_flags),
|
||||
split_k(dim_k)
|
||||
{
|
||||
// Compute wave efficiency
|
||||
float wave_efficiency = get_wave_efficiency(
|
||||
sm_count,
|
||||
max_sm_occupancy,
|
||||
block_dims,
|
||||
grid_dims);
|
||||
|
||||
// Update split-k if wave efficiency is less than some threshold
|
||||
if (wave_efficiency < 0.9)
|
||||
{
|
||||
int num_threadblocks = grid_dims.x * grid_dims.y * grid_dims.z;
|
||||
|
||||
// Ideal number of thread blocks in grid
|
||||
int ideal_threadblocks = lcm(sm_count, num_threadblocks);
|
||||
|
||||
// Desired number of partitions to split K-axis into
|
||||
int num_partitions = ideal_threadblocks / num_threadblocks;
|
||||
|
||||
// Compute new k-split share
|
||||
int new_split_k = (dim_k + num_partitions - 1) / num_partitions;
|
||||
|
||||
// Round split_k share to the nearest block_task_policy_t::BlockItemsK
|
||||
new_split_k = round_nearest(new_split_k, block_tile_items_k);
|
||||
|
||||
// Recompute k-splitting factor with new_split_k
|
||||
num_partitions = (dim_k + new_split_k - 1) / new_split_k;
|
||||
|
||||
// Update grid dims and k if we meet the minimum number of iterations worth the overhead of splitting
|
||||
int min_iterations_k = 8;
|
||||
|
||||
if (((new_split_k / block_tile_items_k) > min_iterations_k) && // We're going to go through at least this many k iterations
|
||||
(sm_count * max_sm_occupancy < NumFlagsSplitK)) // We have enough semaphore flags allocated
|
||||
{
|
||||
grid_dims.z = num_partitions;
|
||||
split_k = new_split_k;
|
||||
}
|
||||
}
|
||||
|
||||
use_semaphore = (grid_dims.z > 1);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Initializer
|
||||
*/
|
||||
cudaError_t prepare(
|
||||
cudaStream_t stream, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
|
||||
bool debug_synchronous) ///< Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console if DEBUG is defined. Default is \p false.
|
||||
|
||||
{
|
||||
cudaError error = cudaSuccess;
|
||||
|
||||
if (use_semaphore)
|
||||
{
|
||||
int block_threads = 128;
|
||||
int grid_dims = (NumFlagsSplitK + block_threads - 1) / block_threads;
|
||||
|
||||
prepare_kernel<<<grid_dims, block_threads, 0, stream>>>(d_flags);
|
||||
|
||||
// Check for failure to launch
|
||||
if (CUDA_PERROR_DEBUG(error = cudaPeekAtLastError()))
|
||||
return error;
|
||||
|
||||
// Sync the stream if specified to flush runtime errors
|
||||
if (debug_synchronous && (CUDA_PERROR_DEBUG(error = cudaStreamSynchronize(stream))))
|
||||
return error;
|
||||
}
|
||||
|
||||
return error;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Compute the efficiency of dispatch wave quantization
|
||||
*/
|
||||
float get_wave_efficiency(
|
||||
int sm_count,
|
||||
int max_sm_occupancy,
|
||||
dim3 block_dims,
|
||||
dim3 grid_dims)
|
||||
{
|
||||
// Heuristic for how many warps are needed to saturate an SM for a given
|
||||
// multiply-accumulate genre. (NB: We could make this more rigorous by
|
||||
// specializing on data types and SM width)
|
||||
int saturating_warps_per_sm = 16;
|
||||
|
||||
int num_threadblocks = grid_dims.x * grid_dims.y * grid_dims.z;
|
||||
int threads_per_threadblock = block_dims.x * block_dims.y;
|
||||
int warps_per_threadblock = threads_per_threadblock / 32;
|
||||
int saturating_threadblocks_per_sm = (saturating_warps_per_sm + warps_per_threadblock - 1) / warps_per_threadblock;
|
||||
|
||||
int saturating_residency = sm_count * saturating_threadblocks_per_sm;
|
||||
int full_waves = num_threadblocks / saturating_residency;
|
||||
int remainder_threadblocks = num_threadblocks % saturating_residency;
|
||||
int total_waves = (remainder_threadblocks == 0) ? full_waves : full_waves + 1;
|
||||
|
||||
float last_wave_saturating_efficiency = float(remainder_threadblocks) / saturating_residency;
|
||||
|
||||
return (float(full_waves) + last_wave_saturating_efficiency) / total_waves;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
85
cutlass/gemm/linear_scaling.h
Normal file
85
cutlass/gemm/linear_scaling.h
Normal file
@ -0,0 +1,85 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements the BLAS linear scaling function alpha*AB + beta*C
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment_multiply_add.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Functor to compute linear combination of fragments
|
||||
template <typename Scalar_, typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_> >
|
||||
struct LinearScaling {
|
||||
// The scalar.
|
||||
typedef Scalar_ Scalar;
|
||||
// The adapater.
|
||||
typedef FragmentMultiplyAdd_ FragmentMultiplyAdd;
|
||||
|
||||
/// The parameters.
|
||||
struct Params {
|
||||
/// The alpha/beta scaling params.
|
||||
Scalar alpha, beta;
|
||||
|
||||
/// Initialize the parameters.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
|
||||
alpha = desc.alpha;
|
||||
beta = desc.beta;
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE LinearScaling(Params const& params) : alpha(params.alpha), beta(params.beta) {}
|
||||
|
||||
/// Evaluate the functor.
|
||||
template <typename FragmentA_, typename FragmentB_>
|
||||
CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_& output) {
|
||||
FragmentMultiplyAdd mad;
|
||||
mad.multiply(alpha, accum, output);
|
||||
}
|
||||
|
||||
/// Evaluate the functor.
|
||||
template <typename FragmentA_, typename FragmentB_>
|
||||
CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_ const& old, FragmentB_& output) {
|
||||
FragmentMultiplyAdd mad;
|
||||
FragmentB_ tmp;
|
||||
mad.multiply(beta, old, tmp);
|
||||
mad.multiply_add(alpha, accum, tmp, output);
|
||||
}
|
||||
|
||||
/// The alpha/beta scaling factors.
|
||||
Scalar alpha, beta;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
127
cutlass/gemm/sgemm_traits.h
Normal file
127
cutlass/gemm/sgemm_traits.h
Normal file
@ -0,0 +1,127 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defies structural properties of single-precision GEMM.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/gemm_epilogue.h>
|
||||
#include <cutlass/gemm/gemm_epilogue_traits.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/gemm/gemm_shared_tile.h>
|
||||
#include <cutlass/gemm/gemm_traits.h>
|
||||
#include <cutlass/gemm/thread_multiply_add.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_ = 1>
|
||||
struct SgemmConfig
|
||||
: public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
float,
|
||||
/// The scalar type for B.
|
||||
float,
|
||||
/// The scalar type for C.
|
||||
float,
|
||||
/// The scalar type for D.
|
||||
float,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, float, float, float>,
|
||||
/// The number of scalars per LDG for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDS for A.
|
||||
4,
|
||||
/// The number of scalars per LDG for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
4,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
1,
|
||||
/// The number of scalars per STS for D.
|
||||
4,
|
||||
/// The number of scalars per LDS for D.
|
||||
1,
|
||||
/// The number of stages in shared memory.
|
||||
2> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
typename OutputTile_ = Shape<8, 128, 128>,
|
||||
/// The functor to use in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<float>,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_ = Shape<8, 8, 8>,
|
||||
/// The number of floats loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of floats loaded in one LDG for B.
|
||||
int kScalarsPerLdgB_ = 1,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The SGEMM config.
|
||||
typename GemmConfig_ =
|
||||
SgemmConfig<OutputTile_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_>,
|
||||
/// The traits class for the epilogue.
|
||||
typename GemmEpilogueTraits_ =
|
||||
SimplifiedGemmEpilogueTraits<GemmConfig_, EpilogueFunctor_, Index_> >
|
||||
struct SgemmTraits : public SimplifiedGemmTraits<
|
||||
// The layout for A.
|
||||
kLayoutA_,
|
||||
// The layout for B.
|
||||
kLayoutB_,
|
||||
// The config.
|
||||
GemmConfig_,
|
||||
// The epilogue.
|
||||
GemmEpilogue<GemmEpilogueTraits_>,
|
||||
// The index.
|
||||
Index_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,469 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* Thread-level multiply-accumulate abstraction
|
||||
*/
|
||||
|
||||
#include "../util/util.h"
|
||||
#include "dp_accummulate.h"
|
||||
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* thread_accumulator (generic specialization)
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief Thread-level multiply-accumulate abstraction (generic specialization)
|
||||
*
|
||||
* The thread_accumulator class maintains a MxN tile of accumulators in
|
||||
* registers to which MxNxK matrix products of two thread tiles A (MxK)
|
||||
* and B (KxN) can be added, where:
|
||||
* M = ThreadItemsY
|
||||
* N = ThreadItemsX
|
||||
* K = sizeof(dp_vector_t) / sizeof(value_t).
|
||||
*
|
||||
* In order to leverage architecture-specific "dot-product accumulate" ISA
|
||||
* operations, K is dictated by the thread_accumulator class in the form of
|
||||
* the member-type dp_vector_t, which defines a K-component vector of value_t.
|
||||
* The multiplicand inputs A and B are provided as arrays of dp_vector_t having
|
||||
* extents ThreadItemsY and ThreadItemsX, respectively. (In the single
|
||||
* component "dp1" scenario where dp_vector_t == value_t and thus K == 1, the
|
||||
* multiplication is simply the outer product of two vectors.)
|
||||
*
|
||||
* The accumulators are zero-initialized in a two-phase process (construction +
|
||||
* initialization) that requires shared storage in the form of the member-type
|
||||
* scratch_storage_t during construction. (A single scratch_storage_t instance
|
||||
* can be uniformly referenced across all threads in the block during
|
||||
* construction *if* the block is synchronized between construction and
|
||||
* initialization.)
|
||||
*
|
||||
* NB: This generic class is not directly constructible. Architecture- and
|
||||
* algorithm-specific template specializations will provide the API
|
||||
* functionality prescribed here.
|
||||
*/
|
||||
template <
|
||||
int ThreadItemsY, ///< Height of thread tile in accum_t
|
||||
int ThreadItemsX, ///< Width of thread tile in accum_t
|
||||
typename value_t, ///< Multiplicand value type
|
||||
typename accum_t, ///< Accumulator value type
|
||||
int ACCUM_BYTES = ///< Size in bytes of accum_t
|
||||
sizeof(accum_t),
|
||||
arch_family_t::kind_t ArchFamily = ///< Architectural family enumerant
|
||||
CUTLASS_ARCH_FAMILY>
|
||||
struct thread_accumulator
|
||||
{
|
||||
protected:
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Constants and types
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Specialized dot-product traits type
|
||||
typedef dp_accummulate<value_t, accum_t> dp_accum_traits_t;
|
||||
|
||||
|
||||
public:
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Member types
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Dot-product vector type
|
||||
typedef typename dp_accum_traits_t::dp_vector_t dp_vector_t;
|
||||
|
||||
/// Scratch storage layout
|
||||
struct scratch_storage_t {};
|
||||
|
||||
|
||||
protected:
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Data members
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Thread's tile of accumulators
|
||||
accum_t accumulators[ThreadItemsY][ThreadItemsX];
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Utility methods
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Compute a multiply-add at accumulator coordinates (x, y)
|
||||
*/
|
||||
inline __device__
|
||||
void mad_xy(
|
||||
dp_vector_t (&tile_a)[ThreadItemsY],
|
||||
dp_vector_t (&tile_b)[ThreadItemsX],
|
||||
int x,
|
||||
int y)
|
||||
{
|
||||
dp_accum_traits_t::mad(
|
||||
accumulators[y][x],
|
||||
tile_a[y],
|
||||
tile_b[x],
|
||||
accumulators[y][x]);
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Constructor API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Constructor
|
||||
inline __device__
|
||||
thread_accumulator(
|
||||
scratch_storage_t &scratch)
|
||||
{}
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Accumulator API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* \brief Zero-initialize thread accumulators.
|
||||
*
|
||||
* If a common reference to a single block-wide shared instance of scratch_storage_t
|
||||
* is used during construction, the block must be synchronized after construction
|
||||
* but prior to the invocation of init().
|
||||
*/
|
||||
inline __device__
|
||||
void init()
|
||||
{
|
||||
#pragma unroll
|
||||
for (int y = 0; y < ThreadItemsY; ++y) {
|
||||
#pragma unroll
|
||||
for (int x = 0; x < ThreadItemsX; ++x)
|
||||
{
|
||||
accumulators[y][x] = accum_t(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Retrieve the accumulator at thread tile coordinates (x, y)
|
||||
*/
|
||||
inline __device__
|
||||
accum_t get(int x, int y)
|
||||
{
|
||||
// Accumulators are row-major
|
||||
return accumulators[y][x];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* \brief Compute the product of tile_a and tile_b and add the result to
|
||||
* the tile of accumulators.
|
||||
*/
|
||||
inline __device__
|
||||
void multiply_accumulate(
|
||||
dp_vector_t (&tile_a)[ThreadItemsY],
|
||||
dp_vector_t (&tile_b)[ThreadItemsX])
|
||||
{
|
||||
// Simply traverse the accumulator tile in row-major order
|
||||
#pragma unroll
|
||||
for (int y = 0; y < ThreadItemsY; ++y)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int x = 0; x < ThreadItemsX; ++x)
|
||||
{
|
||||
mad_xy(tile_a, tile_b, x, y);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* thread_accumulator (__half->__half specialization)
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief Thread-level multiply-accumulate abstraction (__half->__half specialization)
|
||||
*
|
||||
* NB: Because we use the 2-item SIMD instruction HFMA2:
|
||||
* - ThreadItemsX must be an even multiple of 2
|
||||
* - ThreadItemsY must be an even multiple of 2
|
||||
*
|
||||
*/
|
||||
template <
|
||||
int ThreadItemsY, ///< Height in rows of thread tile in C
|
||||
int ThreadItemsX, ///< Width in columns of thread tile in C
|
||||
arch_family_t::kind_t ArchFamily> ///< Architectural family enumerant
|
||||
struct thread_accumulator<
|
||||
ThreadItemsY,
|
||||
ThreadItemsX,
|
||||
__half, ///< Multiplicand value type (matrices A and B)
|
||||
__half, ///< Accumulator value type (matrix C and scalars)
|
||||
2, ///< Size in bytes of accum_t
|
||||
ArchFamily>
|
||||
{
|
||||
protected:
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Constants and types
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Constants
|
||||
enum
|
||||
{
|
||||
/// Height of thread tile in column-major uint32_t SIMD pairs along Y dimension
|
||||
ThreadTilePairsY = divide_assert<ThreadItemsY, 2>::value,
|
||||
|
||||
/// Width of thread tile in column-major uint32_t SIMD pairs along X dimension
|
||||
ThreadTilePairsX = ThreadItemsX,
|
||||
|
||||
/// Number of SIMD pairs in thread's slice of block-wide tile multiplicand A
|
||||
ThreadPairsA = divide_assert<ThreadItemsY, 2>::value,
|
||||
|
||||
/// Number of SIMD pairs in thread's slice of block-wide tile multiplicand B
|
||||
ThreadPairsB = divide_assert<ThreadItemsX, 2>::value,
|
||||
};
|
||||
|
||||
public:
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Member types
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Dot-product vector type
|
||||
typedef __half dp_vector_t;
|
||||
|
||||
/// Scratch storage layout
|
||||
struct scratch_storage_t {};
|
||||
|
||||
|
||||
private:
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Members
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Thread's tile of C accumulator pairs (the uint32_t SIMD pairs are
|
||||
/// column-major, the 2D tile layout is also column-major)
|
||||
uint32_t accumulator_pairs[ThreadTilePairsX][ThreadTilePairsY];
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Utility methods
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Compute an HFMA2 MAD
|
||||
*/
|
||||
inline __device__ void mad(
|
||||
uint32_t &d,
|
||||
const uint32_t &a,
|
||||
const uint32_t &b,
|
||||
const uint32_t &c)
|
||||
{
|
||||
|
||||
asm volatile ("fma.rn.f16x2 %0, %1, %2, %3;\n"
|
||||
: "=r"(d) : "r"(a), "r"(b), "r"(c));
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Compute an HFMA2 MAD with replicated b.lo:
|
||||
* d{hi} = a{hi} * b{lo} + c{hi};
|
||||
* d{lo} = a{lo} * b{lo} + c{lo};
|
||||
*/
|
||||
inline __device__ void mad_replicate_low(
|
||||
uint32_t &d,
|
||||
const uint32_t &a,
|
||||
const uint32_t &b,
|
||||
const uint32_t &c)
|
||||
{
|
||||
// Replicate low halves of b
|
||||
uint32_t replicate;
|
||||
asm volatile (
|
||||
"{"
|
||||
" .reg .b16 b_low,b_high;\n"
|
||||
" mov.b32 {b_low,b_high}, %1;\n"
|
||||
" mov.b32 %0, {b_low,b_low};\n"
|
||||
"}" : "=r"(replicate) : "r"(b));
|
||||
|
||||
mad(d, a, replicate, c);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Compute an HFMA2 MAD with replicated b.hi:
|
||||
* d{hi} = a{hi} * b{hi} + c{hi};
|
||||
* d{lo} = a{lo} * b{hi} + c{lo};
|
||||
*/
|
||||
inline __device__ void mad_replicate_high(
|
||||
uint32_t &d,
|
||||
const uint32_t &a,
|
||||
const uint32_t &b,
|
||||
const uint32_t &c)
|
||||
{
|
||||
// Replicate high halves of b
|
||||
uint32_t replicate;
|
||||
asm volatile (
|
||||
"{"
|
||||
" .reg .b16 b_low,b_high;\n"
|
||||
" mov.b32 {b_low,b_high}, %1;\n"
|
||||
" mov.b32 %0, {b_high,b_high};\n"
|
||||
"}" : "=r"(replicate) : "r"(b));
|
||||
|
||||
mad(d, a, replicate, c);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Compute a multiply-add at accumulator SIMD-pair coordinates (pair_x, pair_y)
|
||||
*/
|
||||
inline __device__
|
||||
void mad_xy_even(
|
||||
uint32_t (&pairs_tile_a)[ThreadPairsA],
|
||||
uint32_t (&pairs_tile_b)[ThreadPairsB],
|
||||
int pair_x,
|
||||
int pair_y)
|
||||
{
|
||||
// Even column: use low half of the b pair
|
||||
mad_replicate_low(
|
||||
accumulator_pairs[pair_x][pair_y],
|
||||
pairs_tile_a[pair_y],
|
||||
pairs_tile_b[pair_x / 2],
|
||||
accumulator_pairs[pair_x][pair_y]);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Compute a multiply-add at accumulator SIMD-pair coordinates (pair_x, pair_y)
|
||||
*/
|
||||
inline __device__
|
||||
void mad_xy_odd(
|
||||
uint32_t (&pairs_tile_a)[ThreadPairsA],
|
||||
uint32_t (&pairs_tile_b)[ThreadPairsB],
|
||||
int pair_x,
|
||||
int pair_y)
|
||||
{
|
||||
// Odd column: use high half of the b pair
|
||||
mad_replicate_high(
|
||||
accumulator_pairs[pair_x][pair_y],
|
||||
pairs_tile_a[pair_y],
|
||||
pairs_tile_b[pair_x / 2],
|
||||
accumulator_pairs[pair_x][pair_y]);
|
||||
}
|
||||
|
||||
|
||||
public:
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Constructor API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Constructor
|
||||
inline __device__
|
||||
thread_accumulator(
|
||||
scratch_storage_t &scratch)
|
||||
{}
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Accumulator API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Zero-initialize thread accumulators.
|
||||
*/
|
||||
inline __device__
|
||||
void init()
|
||||
{
|
||||
#pragma unroll
|
||||
for (int y = 0; y < ThreadTilePairsY; ++y)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int x = 0; x < ThreadTilePairsX; ++x)
|
||||
{
|
||||
accumulator_pairs[x][y] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Retrieve the accumulator at thread tile coordinates (x, y)
|
||||
*/
|
||||
inline __device__
|
||||
__half get(int x, int y)
|
||||
{
|
||||
// SIMD pairs are column-major
|
||||
uint32_t pair = accumulator_pairs[x][y / 2];
|
||||
|
||||
return reinterpret_cast<__half (&)[2]>(pair)[y % 2];
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* \brief Compute the product of pairs_tile_a and pairs_tile_b and add the result to
|
||||
* the tile of accumulators.
|
||||
*/
|
||||
inline __device__
|
||||
void multiply_accumulate(
|
||||
dp_vector_t (&tile_a)[ThreadItemsY],
|
||||
dp_vector_t (&tile_b)[ThreadItemsX])
|
||||
{
|
||||
typedef uint32_t pairs_tile_a_t[ThreadPairsA];
|
||||
typedef uint32_t pairs_tile_b_t[ThreadPairsB];
|
||||
|
||||
// Alias slices in pairs
|
||||
pairs_tile_a_t &pairs_tile_a = reinterpret_cast<pairs_tile_a_t&>(tile_a);
|
||||
pairs_tile_b_t &pairs_tile_b = reinterpret_cast<pairs_tile_b_t&>(tile_b);
|
||||
|
||||
// Simply traverse the accumulator tile in column-major order
|
||||
#pragma unroll
|
||||
for (int x = 0; x < ThreadTilePairsX; ++x)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int y = 0; y < ThreadTilePairsY; ++y)
|
||||
{
|
||||
mad_xy_even(pairs_tile_a, pairs_tile_b, x, y);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
84
cutlass/gemm/thread_multiply_add.h
Normal file
84
cutlass/gemm/thread_multiply_add.h
Normal file
@ -0,0 +1,84 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template implementing matrix multiply-add operations on fragments.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Template performing matrix multiply-add operation within a thread
|
||||
template <typename AccumulatorsPerThread_,
|
||||
typename ThreadsPerWarp_,
|
||||
typename ScalarA_,
|
||||
typename ScalarB_,
|
||||
typename ScalarC_>
|
||||
struct ThreadMultiplyAdd {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<1, 1, 1, 1> InstructionShape;
|
||||
/// The number of accumulators per thread.
|
||||
typedef AccumulatorsPerThread_ AccumulatorsPerThread;
|
||||
/// The number of threads per warp.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of accumulators per warp.
|
||||
typedef typename ShapeMul<AccumulatorsPerThread, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
|
||||
/// The type for A.
|
||||
typedef ScalarA_ ScalarA;
|
||||
/// The fragment for A.
|
||||
typedef Fragment<ScalarA, AccumulatorsPerThread::kW> FragmentA;
|
||||
/// The type for B.
|
||||
typedef ScalarB_ ScalarB;
|
||||
/// The fragment for B.
|
||||
typedef Fragment<ScalarB, AccumulatorsPerThread::kH> FragmentB;
|
||||
/// The type for C and D.
|
||||
typedef ScalarC_ ScalarC;
|
||||
/// The accumulators.
|
||||
typedef Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16> Accumulators;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ThreadMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b + c.
|
||||
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
|
||||
FragmentB const& b,
|
||||
Accumulators const& c,
|
||||
Accumulators& d) {
|
||||
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
|
||||
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
|
||||
d[j * AccumulatorsPerThread::kW + i] = a[i] * b[j] + c[j * AccumulatorsPerThread::kW + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -1,215 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* Thread-level multiply-accumulate abstraction
|
||||
* (Volta 4B accum_t specialization)
|
||||
*/
|
||||
|
||||
#include <mma.h>
|
||||
|
||||
#include "../util/util.h"
|
||||
#include "dp_accummulate.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
|
||||
/*!
|
||||
*\brief matrix_layout to perform conversion between Cutlass types and WMMA types
|
||||
*/
|
||||
template <matrix_transform_t::kind_t>
|
||||
struct matrix_layout;
|
||||
|
||||
/// Maps matrix_transform_t::NonTranspose to nvcuda::wmma::mem_col_major
|
||||
template <>
|
||||
struct matrix_layout<matrix_transform_t::NonTranspose>
|
||||
{
|
||||
/// Type tag in nvcuda::wmma namespace
|
||||
typedef nvcuda::wmma::col_major tag;
|
||||
|
||||
/// Column major layout
|
||||
static const nvcuda::wmma::layout_t kind = nvcuda::wmma::mem_col_major;
|
||||
|
||||
/// Cutlass matrix transform kind
|
||||
static const matrix_transform_t::kind_t cutlass_kind = matrix_transform_t::NonTranspose;
|
||||
};
|
||||
|
||||
/// Maps matrix_transform_t::NonTranspose to nvcuda::wmma::mem_row_major
|
||||
template <>
|
||||
struct matrix_layout<matrix_transform_t::Transpose>
|
||||
{
|
||||
/// Type tag in nvcuda::wmma namespace
|
||||
typedef nvcuda::wmma::row_major tag;
|
||||
|
||||
/// Column major layout
|
||||
static const nvcuda::wmma::layout_t kind = nvcuda::wmma::mem_row_major;
|
||||
|
||||
/// Cutlass matrix transform kind
|
||||
static const matrix_transform_t::kind_t cutlass_kind = matrix_transform_t::Transpose;
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Warp-synchronous matrix multiply-accumulate abstraction
|
||||
*
|
||||
* wmma_accumulator maps the CUDA WMMA API onto the GEMM structure
|
||||
*/
|
||||
template <
|
||||
int WarpItemsY, /// Number of rows of the warp's accumulator tile
|
||||
int WarpItemsX, /// Number of columns of the warp's accumulator tile
|
||||
int WmmaItemsY, /// Number of rows in a single WMMA operation
|
||||
int WmmaItemsX, /// Number of columns in a single WMMA operation
|
||||
int WmmaItemsK, /// Inner dimension of WMMA operation
|
||||
typename value_a_t, /// Type of A operand
|
||||
typename value_b_t, /// Type of B operand
|
||||
typename accum_t, /// Type of source and destination accumulators
|
||||
matrix_transform_t::kind_t TransformA, /// Layout of A operand
|
||||
matrix_transform_t::kind_t TransformB /// Layout of B operand
|
||||
>
|
||||
struct wmma_accumulator
|
||||
{
|
||||
public:
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Constants and types
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
enum
|
||||
{
|
||||
/// Number of WMMA blocks in warp row
|
||||
WmmaBlocksX = divide_assert<WarpItemsX, WmmaItemsX>::value,
|
||||
|
||||
/// Number of WMMA blocks in a warp column
|
||||
WmmaBlocksY = divide_assert<WarpItemsY, WmmaItemsY>::value,
|
||||
};
|
||||
|
||||
/// Fragment type for matrix operand A
|
||||
typedef nvcuda::wmma::fragment<
|
||||
nvcuda::wmma::matrix_a,
|
||||
WmmaItemsY,
|
||||
WmmaItemsX,
|
||||
WmmaItemsK,
|
||||
value_a_t,
|
||||
typename matrix_layout<TransformA>::tag>
|
||||
fragment_a_t;
|
||||
|
||||
/// Fragment type for matrix operand B
|
||||
typedef nvcuda::wmma::fragment<
|
||||
nvcuda::wmma::matrix_b,
|
||||
WmmaItemsY,
|
||||
WmmaItemsX,
|
||||
WmmaItemsK,
|
||||
value_b_t,
|
||||
typename matrix_layout<TransformB>::tag>
|
||||
fragment_b_t;
|
||||
|
||||
/// Fragment type for accumulator
|
||||
typedef nvcuda::wmma::fragment<
|
||||
nvcuda::wmma::accumulator,
|
||||
WmmaItemsY,
|
||||
WmmaItemsX,
|
||||
WmmaItemsK,
|
||||
accum_t>
|
||||
accumulator_t;
|
||||
|
||||
/// Scratch storage layout
|
||||
struct scratch_storage_t
|
||||
{
|
||||
/// Initialization vector
|
||||
uint4 zero_slab;
|
||||
};
|
||||
|
||||
public:
|
||||
//-------------------------------------------------------------------------
|
||||
// Data members
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Thread's tile of accumulators
|
||||
accumulator_t accumulators[WmmaBlocksX][WmmaBlocksY];
|
||||
|
||||
public:
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Constructor API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Constructor initializes accumulators to zero
|
||||
inline __device__
|
||||
wmma_accumulator()
|
||||
{
|
||||
init();
|
||||
}
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Accumulator API
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* \brief Zero-initialize thread accumulators.
|
||||
*/
|
||||
inline __device__
|
||||
void init()
|
||||
{
|
||||
#pragma unroll
|
||||
for (int x = 0; x < WmmaBlocksX; ++x)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int y = 0; y < WmmaBlocksY; ++y)
|
||||
{
|
||||
nvcuda::wmma::fill_fragment(accumulators[x][y], accum_t(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Compute the product of tile_a and tile_b and add the result to
|
||||
* the tile of accumulators.
|
||||
*/
|
||||
inline __device__
|
||||
void multiply_accumulate(
|
||||
fragment_a_t (&tile_a)[WmmaBlocksY],
|
||||
fragment_b_t (&tile_b)[WmmaBlocksX])
|
||||
{
|
||||
#pragma unroll
|
||||
for (int x = 0; x < WmmaBlocksX; ++x)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int y = 0; y < WmmaBlocksY; ++y)
|
||||
{
|
||||
nvcuda::wmma::mma_sync(accumulators[x][y], tile_a[y], tile_b[x], accumulators[x][y]);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
161
cutlass/gemm/wmma_gemm_epilogue_traits.h
Normal file
161
cutlass/gemm/wmma_gemm_epilogue_traits.h
Normal file
@ -0,0 +1,161 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structural properties of WMMA GEMM's epilogue phase.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/wmma_matrix.h>
|
||||
#ifdef CUTLASS_USE_WMMA_API
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/gemm/gemm_global_stream.h>
|
||||
#include <cutlass/gemm/gemm_shared_stream.h>
|
||||
#include <cutlass/gemm/linear_scaling.h>
|
||||
#include <cutlass/gemm/wmma_gemm_global_tile.h>
|
||||
#include <cutlass/gemm/wmma_gemm_shared_tile.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include <cutlass/tile_iterator.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_, typename EpilogueFunctor_, typename Index_ = int>
|
||||
struct WmmaGemmEpilogueTraitsHelper {
|
||||
/// The scalar.
|
||||
typedef typename EpilogueFunctor_::Scalar Scalar;
|
||||
/// The output tile.
|
||||
typedef typename GemmConfig_::OutputTile OutputTile;
|
||||
|
||||
/// The number of WMMAs in the H dimension.
|
||||
static int const kWmmasPerH =
|
||||
GemmConfig_::AccumulatorsPerWarp::kH / GemmConfig_::InstructionShape::kH;
|
||||
/// The number of iterations in the epilogue. That's the number of "horizontal" WMMAs.
|
||||
typedef Shape<1, 1, kWmmasPerH> Iterations;
|
||||
// The iteration strides in the H/W dimension.
|
||||
typedef Shape<0, 0, 0> Delta;
|
||||
/// The functor to do the math in the epilogue.
|
||||
typedef EpilogueFunctor_ Functor;
|
||||
|
||||
/// The traits class to build the iterator to store to shared memory for D.
|
||||
typedef WmmaGemmSharedStoreTileDTraits<
|
||||
// The output layout.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer is float.
|
||||
typename Functor::Scalar,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The shape of the instruction.
|
||||
typename GemmConfig_::InstructionShape>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
typedef WmmaMatrix<GemmOperand::kC,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Scalar,
|
||||
typename GemmConfig_::InstructionShape>
|
||||
WmmaMatrix;
|
||||
|
||||
/// The iterator to store D to shared memory.
|
||||
typedef TileStoreIterator<SharedStoreTileTraits,
|
||||
typename SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared,
|
||||
Index_,
|
||||
WmmaMatrix,
|
||||
IteratorFragment::kWmmaMatrix>
|
||||
SharedStoreIteratorD;
|
||||
|
||||
/// The shared store transformer for D.
|
||||
typedef Copy<typename SharedStoreIteratorD::Fragment> SharedStoreTransformerD;
|
||||
|
||||
/// The traits class to build the iterator to load from shared memory for D.
|
||||
typedef WmmaGemmSharedLoadTileDTraits<
|
||||
// The pointer.
|
||||
typename Functor::Scalar,
|
||||
// The tile size.
|
||||
typename SharedStoreIteratorD::Tile,
|
||||
// The number of threads.
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDS.
|
||||
GemmConfig_::kScalarsPerLdsD>
|
||||
SharedLoadTileTraits;
|
||||
|
||||
/// The iterator to load D from shared memory.
|
||||
typedef TileLoadIterator<SharedLoadTileTraits,
|
||||
typename SharedLoadTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorD;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for C^N.
|
||||
typedef WmmaGemmGlobalIteratorCdTraits<
|
||||
// The pointer is float const.
|
||||
typename GemmConfig_::ScalarC const,
|
||||
// The tile has size (N / Iterations)xM in GEMM's terminology.
|
||||
Shape<1,
|
||||
GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
|
||||
GemmConfig_::OutputTile::kW>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgC>
|
||||
GlobalLoadTileTraits;
|
||||
|
||||
/// The iterator to load C.
|
||||
typedef WmmaGemmGlobalIteratorCd<GlobalLoadTileTraits, Index_> GlobalLoadIteratorC;
|
||||
/// The transformer for C.
|
||||
typedef Copy<typename GlobalLoadIteratorC::Fragment> GlobalTransformerC;
|
||||
|
||||
/// The traits class to build the iterator to store data to global memory for D^N.
|
||||
typedef WmmaGemmGlobalIteratorCdTraits<
|
||||
// The pointer is float.
|
||||
typename GemmConfig_::ScalarD,
|
||||
// The tile has size (N / Iterations)xM in GEMM's terminology.
|
||||
Shape<1,
|
||||
GemmConfig_::OutputTile::kH / ShapeCount<Iterations>::kCount,
|
||||
GemmConfig_::OutputTile::kW>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerStgD>
|
||||
GlobalStoreTileTraits;
|
||||
|
||||
/// The iterator to store D.
|
||||
typedef WmmaGemmGlobalIteratorCd<GlobalStoreTileTraits, Index_> GlobalStoreIteratorD;
|
||||
/// The transformer for D.
|
||||
typedef Copy<typename GlobalStoreIteratorD::Fragment> GlobalTransformerD;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
#endif // defined CUTLASS_USE_WMMA_API
|
||||
211
cutlass/gemm/wmma_gemm_global_tile.h
Normal file
211
cutlass/gemm/wmma_gemm_global_tile.h
Normal file
@ -0,0 +1,211 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines tile iterator traits for loading thread block-level tile from global memory.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, typename Tile_, typename Threads_, int kAccessSize_>
|
||||
struct WmmaGemmGlobalIteratorCdTraits : public GemmGlobalTileTraits<GemmOperand::kC,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Scalar_,
|
||||
Tile_,
|
||||
Threads_,
|
||||
kAccessSize_> {
|
||||
/// The base class.
|
||||
typedef GemmGlobalTileTraits<GemmOperand::kC,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Scalar_,
|
||||
Tile_,
|
||||
Threads_,
|
||||
kAccessSize_>
|
||||
Base;
|
||||
|
||||
/// Override the strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Base::Delta::kW, Base::Delta::kC> Delta;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int thread_offset_h = threadIdx.x / Base::Threads::kW;
|
||||
int thread_offset_w = threadIdx.x % Base::Threads::kW * Base::ThreadsDelta::kW;
|
||||
|
||||
return make_Coord(0, thread_offset_h, thread_offset_w, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TileTraits_, typename Index_ = int>
|
||||
struct WmmaGemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
typename TileTraits_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kGlobal,
|
||||
Index_> {
|
||||
/// This class.
|
||||
typedef WmmaGemmGlobalIteratorCd<TileTraits_, Index_> This_;
|
||||
/// The traits.
|
||||
typedef TileTraits_ Traits;
|
||||
/// The base class.
|
||||
typedef TileIteratorBase<Traits,
|
||||
typename TileTraits_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kGlobal,
|
||||
Index_>
|
||||
Base;
|
||||
/// Override the strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Base::Delta::kW, Base::Delta::kC> ImmediateOffsetStrides;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
|
||||
|
||||
/// The scalar.
|
||||
typedef typename TileTraits_::Scalar Scalar;
|
||||
/// The pointer.
|
||||
typedef typename TileTraits_::Pointer Pointer;
|
||||
/// The threads.
|
||||
typedef typename TileTraits_::Threads Threads;
|
||||
/// The index.
|
||||
typedef Index_ Index;
|
||||
/// The thread offset functor.
|
||||
typedef typename TileTraits_::ThreadOffset ThreadOffset;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
/// The pointer.
|
||||
Pointer pointer;
|
||||
/// The stride in the H dimension to setup the thread in the block.
|
||||
Index stride_h;
|
||||
/// The strides to increment the pointer.
|
||||
Index inc_h, inc_advance;
|
||||
/// The column offset to compute the predicate for the columns.
|
||||
Index predicate_offset;
|
||||
/// The strides to increment the predicate offset.
|
||||
Index predicate_inc_h, predicate_inc_advance;
|
||||
|
||||
/// Setup the params.
|
||||
CUTLASS_HOST_DEVICE int initialize(
|
||||
Pointer pointer, Index ld, Index n, Index epilogue_stride_w, Index epilogue_delta_w) {
|
||||
// The pointer.
|
||||
this->pointer = pointer;
|
||||
// Setup the base stride. One "group of threads" per column.
|
||||
stride_h = ld;
|
||||
// Each thread output 1 column per iteration. .
|
||||
inc_h = ld * TileTraits_::Threads::kH;
|
||||
inc_advance = inc_h + epilogue_stride_w;
|
||||
|
||||
predicate_offset = n;
|
||||
predicate_inc_h = TileTraits_::Threads::kH;
|
||||
predicate_inc_advance = predicate_inc_h + epilogue_delta_w;
|
||||
|
||||
// It worked.
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
Params params;
|
||||
|
||||
Coord<4> thread_offset;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE WmmaGemmGlobalIteratorCd() {}
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE WmmaGemmGlobalIteratorCd(Params const& params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block,
|
||||
int const pointer_offset = 0,
|
||||
int const pred_offset = 0,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
|
||||
: params(params) {
|
||||
thread_offset = thread_offset_func();
|
||||
// Each warp works on a different column of the tile.
|
||||
int const h = thread_offset[1] + block[1];
|
||||
// Each lane writes a different element.
|
||||
int const w = thread_offset[2] + block[2];
|
||||
// Setup the pointer.
|
||||
this->params.pointer += ((h * params.stride_h + w) + pointer_offset);
|
||||
|
||||
// Prepare the vector of predicates.
|
||||
for (int i = 0; i < Base::Iterations::kW; ++i) {
|
||||
predicates.set(i, w + i * Base::Delta::kW < bounds[2]);
|
||||
}
|
||||
this->params.predicate_offset -= (h + pred_offset);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Load<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::load(value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Increment the pointer in the C dimension.
|
||||
CUTLASS_DEVICE void inc_c() {}
|
||||
/// Increment the pointer in the W dimension.
|
||||
CUTLASS_DEVICE void inc_w() {}
|
||||
/// Increment the pointer in the H dimension.
|
||||
CUTLASS_DEVICE void inc_h() {
|
||||
params.pointer += params.inc_h;
|
||||
params.predicate_offset -= params.predicate_inc_h;
|
||||
}
|
||||
/// Increment the pointer in the D dimension.
|
||||
CUTLASS_DEVICE void inc_d() {}
|
||||
/// Increment the pointer to move to the next iteration.
|
||||
CUTLASS_DEVICE void inc_advance() {
|
||||
params.pointer += params.inc_advance;
|
||||
params.predicate_offset -= params.predicate_inc_advance;
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void set(typename Base::AccessType const& value, int d, int h, int w, int c) {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, 0);
|
||||
Store<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::store(
|
||||
value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Test the predicate.
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
|
||||
return predicates.at(w) && params.predicate_offset > 0;
|
||||
}
|
||||
|
||||
/// The predicates for the row.
|
||||
cutlass::PredicateVector<Base::Iterations::kW> predicates;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
108
cutlass/gemm/wmma_gemm_multiply_add.h
Normal file
108
cutlass/gemm/wmma_gemm_multiply_add.h
Normal file
@ -0,0 +1,108 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements warp-level matrix multiply-accumulate operation using CUDA WMMA API.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/wmma_matrix.h>
|
||||
#ifdef CUTLASS_USE_WMMA_API
|
||||
#include <cutlass/fragment.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <MatrixLayout::Kind kLayoutA_,
|
||||
typename ScalarA_,
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
typename ScalarB_,
|
||||
MatrixLayout::Kind kLayoutC_,
|
||||
typename ScalarC_,
|
||||
typename AccumulatorsPerWarp_,
|
||||
typename InstructionShape_>
|
||||
struct WmmaGemmMultiplyAdd {
|
||||
/// The shape of the instruction.
|
||||
typedef InstructionShape_ InstructionShape;
|
||||
/// The number of threads per warp. That's a dummy configuration.
|
||||
typedef Shape<1, InstructionShape_::kH, InstructionShape_::kW> ThreadsPerWarp;
|
||||
/// The dimensions.
|
||||
typedef AccumulatorsPerWarp_ AccumulatorsPerWarp;
|
||||
/// The type for A.
|
||||
typedef ScalarA_ ScalarA;
|
||||
/// The type for B.
|
||||
typedef ScalarB_ ScalarB;
|
||||
/// The type for C and D.
|
||||
typedef ScalarC_ ScalarC;
|
||||
/// The number of iterations.
|
||||
typedef typename ShapeDiv<AccumulatorsPerWarp, InstructionShape>::Shape Iterations;
|
||||
|
||||
/// The element for A.
|
||||
typedef WmmaMatrix<GemmOperand::kA, kLayoutA_, ScalarA, InstructionShape> ElementA;
|
||||
/// The fragment for A.
|
||||
typedef Fragment<ElementA, Iterations::kW> FragmentA;
|
||||
|
||||
/// The element for B.
|
||||
typedef WmmaMatrix<GemmOperand::kB, kLayoutB_, ScalarB, InstructionShape> ElementB;
|
||||
/// The fragment for B.
|
||||
typedef Fragment<ElementB, Iterations::kH> FragmentB;
|
||||
|
||||
/// The element for C.
|
||||
typedef WmmaMatrix<GemmOperand::kC, kLayoutC_, ScalarC, InstructionShape> ElementC;
|
||||
/// The fragment for C.
|
||||
typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b.
|
||||
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
|
||||
FragmentB const& b,
|
||||
Accumulators const& c,
|
||||
Accumulators& d) {
|
||||
for (int j = 0; j < Iterations::kH; ++j) {
|
||||
for (int i = 0; i < Iterations::kW; ++i) {
|
||||
// The input elements.
|
||||
ElementA const& elt_a = a[i];
|
||||
ElementB const& elt_b = b[j];
|
||||
ElementC const& elt_c = c[j * Iterations::kW + i];
|
||||
|
||||
// The output element.
|
||||
ElementC& elt_d = d[j * Iterations::kW + i];
|
||||
|
||||
// The wmma instruction.
|
||||
nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
#endif // defined CUTLASS_USE_WMMA_API
|
||||
240
cutlass/gemm/wmma_gemm_shared_tile.h
Normal file
240
cutlass/gemm/wmma_gemm_shared_tile.h
Normal file
@ -0,0 +1,240 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines iterator traits for efficiently loading and storing fragment to and from shared
|
||||
memory, specialized for WMMA GEMM.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/wmma_matrix.h>
|
||||
#ifdef CUTLASS_USE_WMMA_API
|
||||
|
||||
#include <cutlass/gemm/gemm_operand.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
template <class>
|
||||
struct Debug {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename Tile_,
|
||||
typename Warps_,
|
||||
int kWarpStride_,
|
||||
typename Iterations_,
|
||||
typename Delta_,
|
||||
typename WmmaShape_>
|
||||
struct WmmaGemmSharedLoadTileATraits {
|
||||
/// The operand.
|
||||
static GemmOperand::Kind const kOperand = GemmOperand::kA;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = kLayout_;
|
||||
/// The scalar.
|
||||
typedef Scalar_ Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar const* Pointer;
|
||||
/// The access size
|
||||
static int const kAccessSize = 1;
|
||||
/// The tile with skew.
|
||||
typedef Tile_ Tile;
|
||||
/// The number of warps.
|
||||
typedef Warps_ Warps;
|
||||
/// The warps strides.
|
||||
static int const kWarpStride = kWarpStride_;
|
||||
/// The number of iterations.
|
||||
typedef Iterations_ Iterations;
|
||||
/// The strides between iterations.
|
||||
typedef Delta_ Delta;
|
||||
/// The strides between iterations.
|
||||
typedef Delta_ ImmediateOffsetStrides;
|
||||
/// The shape of the WMMA instruction.
|
||||
typedef WmmaShape_ WmmaShape;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
/// ThreadOffset
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
// The warp id.
|
||||
int const warp = threadIdx.x / kWarpSize;
|
||||
// The offset.
|
||||
int const offset = warp % Warps::kW * kWarpStride;
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename Tile_,
|
||||
typename Warps_,
|
||||
int kWarpStride_,
|
||||
typename Iterations_,
|
||||
typename Delta_,
|
||||
typename WmmaShape_>
|
||||
struct WmmaGemmSharedLoadTileBTraits {
|
||||
/// The operand.
|
||||
static GemmOperand::Kind const kOperand = GemmOperand::kB;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = kLayout_;
|
||||
/// The scalar.
|
||||
typedef Scalar_ Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar const* Pointer;
|
||||
/// The access size
|
||||
static int const kAccessSize = 1;
|
||||
/// The tile with skew.
|
||||
typedef Tile_ Tile;
|
||||
/// The number of warps.
|
||||
typedef Warps_ Warps;
|
||||
/// The warps strides.
|
||||
static int const kWarpStride = kWarpStride_;
|
||||
/// The number of iterations.
|
||||
typedef Iterations_ Iterations;
|
||||
/// The strides between iterations.
|
||||
typedef Delta_ Delta;
|
||||
/// The strides between iterations.
|
||||
typedef Delta_ ImmediateOffsetStrides;
|
||||
/// The shape of the WMMA instruction.
|
||||
typedef WmmaShape_ WmmaShape;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
/// ThreadOffset
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
// The warp id.
|
||||
int const warp = threadIdx.x / kWarpSize;
|
||||
// The offset.
|
||||
int const offset = warp / Warps::kW * kWarpStride;
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename OutputTile_,
|
||||
typename Warps_,
|
||||
typename WmmaShape_,
|
||||
int kSkew_ = 0>
|
||||
struct WmmaGemmSharedStoreTileDTraits {
|
||||
/// The operand.
|
||||
static GemmOperand::Kind const kOperand = GemmOperand::kC;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = kLayout_;
|
||||
/// The scalar.
|
||||
typedef Scalar_ Scalar;
|
||||
// The access size
|
||||
static int const kAccessSize = 1;
|
||||
/// The pointer.
|
||||
typedef Scalar* Pointer;
|
||||
/// The number of warps.
|
||||
typedef Warps_ Warps;
|
||||
/// The shape of the WMMA instruction.
|
||||
typedef WmmaShape_ WmmaShape;
|
||||
/// The skew.
|
||||
static int const kSkew = kSkew_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
/// The tile with skew.
|
||||
typedef Shape<1, Warps_::kH * WmmaShape_::kH, OutputTile_::kW + kSkew_> Tile;
|
||||
/// The number of iterations needed to store the tile.
|
||||
typedef Shape<1, 1, OutputTile_::kW / Warps::kW / WmmaShape_::kW> Iterations;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Warps::kW * WmmaShape_::kW, 0> Delta;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Warps::kW * WmmaShape_::kW, 0> ImmediateOffsetStrides;
|
||||
|
||||
/// ThreadOffset
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
// The warp id.
|
||||
int const warp = threadIdx.x / kWarpSize;
|
||||
// The starting column.
|
||||
int const h = warp / Warps::kW * WmmaShape::kH;
|
||||
// The w.
|
||||
int const w = warp % Warps::kW * WmmaShape::kW;
|
||||
// The offset.
|
||||
int const offset = h * Tile::kW + w;
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerLds_>
|
||||
struct WmmaGemmSharedLoadTileDTraits {
|
||||
/// The scalar.
|
||||
typedef Scalar_ Scalar;
|
||||
/// The pointer.
|
||||
typedef Scalar const* Pointer;
|
||||
/// The access size
|
||||
static int const kAccessSize = kScalarsPerLds_;
|
||||
/// The tile.
|
||||
typedef typename ReshapeTile<Tile_, kScalarsPerLds_>::Tile Tile;
|
||||
/// The threads.
|
||||
typedef typename ReshapeThreads<Tile, Threads_>::Threads Threads;
|
||||
/// The threads strides.
|
||||
typedef Shape<1, Tile::kW * Tile::kC, Tile::kC> ThreadsStrides;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
|
||||
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_> Delta;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_>
|
||||
ImmediateOffsetStrides;
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<1, Tile::kH / Threads::kH, Tile::kW / Threads::kW, Tile::kC / kScalarsPerLds_>
|
||||
Iterations;
|
||||
|
||||
/// ThreadOffset
|
||||
struct ThreadOffset {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
// The offset.
|
||||
int const offset = ComputeThreadOffsetFromStrides<Threads, ThreadsStrides>::get();
|
||||
return make_Coord(0, 0, offset, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
#endif // defined CUTLASS_USE_WMMA_API
|
||||
574
cutlass/gemm/wmma_gemm_traits.h
Normal file
574
cutlass/gemm/wmma_gemm_traits.h
Normal file
@ -0,0 +1,574 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defies structural properties of GEMM targeting WMMA API in CUDA.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/wmma_matrix.h>
|
||||
#ifdef CUTLASS_USE_WMMA_API
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/gemm_epilogue.h>
|
||||
#include <cutlass/gemm/gemm_epilogue_traits.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/gemm/gemm_shared_tile.h>
|
||||
#include <cutlass/gemm/gemm_traits.h>
|
||||
#include <cutlass/gemm/wmma_gemm_epilogue_traits.h>
|
||||
#include <cutlass/gemm/wmma_gemm_global_tile.h>
|
||||
#include <cutlass/gemm/wmma_gemm_multiply_add.h>
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// The output type.
|
||||
typename ScalarC_,
|
||||
/// The accumulator type.
|
||||
typename Accumulator_,
|
||||
/// The number of accumulators per warp.
|
||||
typename AccumulatorsPerWarp_,
|
||||
/// The shape of the WMMA instruction.
|
||||
typename InstructionShape_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_>
|
||||
struct WmmaGemmConfig : public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
half,
|
||||
/// The scalar type for B.
|
||||
half,
|
||||
/// The scalar type for C.
|
||||
ScalarC_,
|
||||
/// The scalar type for D.
|
||||
ScalarC_,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
WmmaGemmMultiplyAdd<kLayoutA_,
|
||||
half,
|
||||
kLayoutB_,
|
||||
half,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Accumulator_,
|
||||
AccumulatorsPerWarp_,
|
||||
InstructionShape_>,
|
||||
/// The number of scalars per LDG for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDS for A.
|
||||
8,
|
||||
/// The number of scalars per LDG for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
8,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
16 / sizeof(ScalarC_),
|
||||
/// The number of scalars per STS for D.
|
||||
16 / sizeof(ScalarC_),
|
||||
/// The number of scalars per LDS for D.
|
||||
16 / sizeof(ScalarC_),
|
||||
/// The number of stages in shared memory.
|
||||
1> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperA {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_>
|
||||
: public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
/// The base config.
|
||||
typedef GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> Base;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkew = 16 / sizeof(typename Base::MultiplyAddScalar);
|
||||
/// The shared tile size.
|
||||
typedef Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kD,
|
||||
GemmConfig_::OutputTile::kW + kSkew>
|
||||
Tile;
|
||||
|
||||
/// WMMA matrix
|
||||
typedef WmmaMatrix<GemmOperand::kA,
|
||||
MatrixLayout::kColumnMajor,
|
||||
typename Base::MultiplyAddScalar,
|
||||
typename GemmConfig_::InstructionShape>
|
||||
WmmaMatrix;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer.
|
||||
typename Base::MultiplyAddScalar,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Tile,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename Base::GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsA>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The number of elements loaded in one LDG.
|
||||
static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
|
||||
/// The number of scalars loaded per iteration.
|
||||
static int const kScalarsPerIteration = Tile::kW * GemmConfig_::InstructionShape::kD;
|
||||
/// The traits class to build the iterator to load from shared memory for A.
|
||||
typedef WmmaGemmSharedLoadTileATraits<
|
||||
// The layout of the matrix.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer.
|
||||
typename Base::MultiplyAddScalar,
|
||||
// The output tile size.
|
||||
Tile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The strides between warps.
|
||||
GemmConfig_::InstructionShape::kW,
|
||||
// The number of iterations to load the data.
|
||||
Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
|
||||
// The stride between iterations.
|
||||
Shape<kScalarsPerIteration, 0, kScalarsPerW, 0>,
|
||||
// The shape of the instruction.
|
||||
typename GemmConfig_::InstructionShape>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarA Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
|
||||
|
||||
/// WMMA matrix
|
||||
typedef WmmaMatrix<GemmOperand::kA,
|
||||
MatrixLayout::kRowMajor,
|
||||
MultiplyAddScalar,
|
||||
typename GemmConfig_::InstructionShape>
|
||||
WmmaMatrix;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for A^T.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's A.
|
||||
GemmOperand::kA,
|
||||
// A is row-major.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer is float const.
|
||||
Scalar const,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkew = 16 / sizeof(MultiplyAddScalar);
|
||||
/// The tile.
|
||||
typedef Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kW,
|
||||
GemmConfig_::OutputTile::kD + kSkew>
|
||||
Tile;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Tile,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsA>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The number of elements loaded in one LDG.
|
||||
static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
|
||||
/// The traits class to build the iterator to load from shared memory for A.
|
||||
typedef WmmaGemmSharedLoadTileATraits<
|
||||
// The layout of the matrix.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile in shared memory.
|
||||
Tile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The strides between warps.
|
||||
GemmConfig_::InstructionShape::kW * Tile::kW,
|
||||
// The number of iterations to load the data.
|
||||
Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
|
||||
// The stride between iterations.
|
||||
Shape<GemmConfig_::InstructionShape::kD, 0, kScalarsPerW * Tile::kW>,
|
||||
// The shape of the instruction.
|
||||
typename GemmConfig_::InstructionShape>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperB {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_>
|
||||
: public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
/// The base config.
|
||||
typedef GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> Base;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkew = 16 / sizeof(typename Base::MultiplyAddScalar);
|
||||
/// The shared tile size.
|
||||
typedef Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kD,
|
||||
GemmConfig_::OutputTile::kH + kSkew>
|
||||
Tile;
|
||||
|
||||
/// WMMA matrix
|
||||
typedef WmmaMatrix<GemmOperand::kB,
|
||||
MatrixLayout::kRowMajor,
|
||||
typename Base::MultiplyAddScalar,
|
||||
typename GemmConfig_::InstructionShape>
|
||||
WmmaMatrix;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^T.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer.
|
||||
typename Base::MultiplyAddScalar,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Tile,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename Base::GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsB>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The number of elements loaded in one LDG.
|
||||
static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
|
||||
/// The number of scalars loaded per iteration.
|
||||
static int const kScalarsPerIteration = Tile::kW * GemmConfig_::InstructionShape::kD;
|
||||
/// The traits class to build the iterator to load from shared memory for B.
|
||||
typedef WmmaGemmSharedLoadTileBTraits<
|
||||
// The layout of the matrix.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer.
|
||||
typename Base::MultiplyAddScalar,
|
||||
// The output tile size.
|
||||
Tile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The strides between warps.
|
||||
GemmConfig_::InstructionShape::kH,
|
||||
// The number of iterations to load the data.
|
||||
Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
|
||||
// The stride between iterations.
|
||||
Shape<kScalarsPerIteration, 0, kScalarsPerW, 0>,
|
||||
// The shape of the instruction.
|
||||
typename GemmConfig_::InstructionShape>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarB Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
|
||||
|
||||
/// WMMA matrix
|
||||
typedef WmmaMatrix<GemmOperand::kB,
|
||||
MatrixLayout::kColumnMajor,
|
||||
MultiplyAddScalar,
|
||||
typename GemmConfig_::InstructionShape>
|
||||
WmmaMatrix;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for B^N.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's B.
|
||||
GemmOperand::kB,
|
||||
// A is row-major.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer is float const.
|
||||
Scalar const,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgB>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkew = 16 / sizeof(MultiplyAddScalar);
|
||||
/// The tile.
|
||||
typedef Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kH,
|
||||
GemmConfig_::OutputTile::kD + kSkew>
|
||||
Tile;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Tile,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsB>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The number of elements loaded in one LDG.
|
||||
static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
|
||||
/// The traits class to build the iterator to load from shared memory for B.
|
||||
typedef WmmaGemmSharedLoadTileBTraits<
|
||||
// The layout of the matrix.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile in shared memory.
|
||||
Tile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The strides between warps.
|
||||
GemmConfig_::InstructionShape::kH * Tile::kW,
|
||||
// The number of iterations to load the data.
|
||||
Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
|
||||
// The stride between iterations.
|
||||
Shape<GemmConfig_::InstructionShape::kD, 0, kScalarsPerW * Tile::kW>,
|
||||
// The shape of the instruction.
|
||||
typename GemmConfig_::InstructionShape>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
typename OutputTile_,
|
||||
/// The output type.
|
||||
typename ScalarC_,
|
||||
/// The accumulator type.
|
||||
typename Accumulator_,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_,
|
||||
/// The number of accumulators per warp.
|
||||
typename AccumulatorsPerWarp_,
|
||||
/// The shape of the WMMA instruction.
|
||||
typename InstructionShape_,
|
||||
/// The number of halfs loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_,
|
||||
/// The number of halfs loaded in one LDG for B.
|
||||
int kScalarsPerLdgB_,
|
||||
/// The index.
|
||||
typename Index_>
|
||||
struct WmmaGemmTraitsHelper {
|
||||
/// The WMMA GEMM config.
|
||||
typedef WmmaGemmConfig<kLayoutA_,
|
||||
kLayoutB_,
|
||||
OutputTile_,
|
||||
ScalarC_,
|
||||
Accumulator_,
|
||||
AccumulatorsPerWarp_,
|
||||
InstructionShape_,
|
||||
kScalarsPerLdgA_,
|
||||
kScalarsPerLdgB_>
|
||||
GemmConfig;
|
||||
|
||||
/// The GEMM config for A.
|
||||
typedef WmmaGemmTileTraitsHelperA<kLayoutA_, GemmConfig> GemmTileTraitsHelperA;
|
||||
/// The GEMM config for B.
|
||||
typedef WmmaGemmTileTraitsHelperB<kLayoutB_, GemmConfig> GemmTileTraitsHelperB;
|
||||
|
||||
/// The iterator to load A from global memory.
|
||||
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA::GlobalTileTraits, Index_>
|
||||
GlobalLoadIteratorA;
|
||||
/// The default transformer for A.
|
||||
typedef Copy<typename GlobalLoadIteratorA::Fragment> GlobalTransformerA;
|
||||
/// The iterator to store A to shared memory.
|
||||
typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
|
||||
typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorA;
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA>
|
||||
GlobalLoadStreamA;
|
||||
|
||||
/// The iterator to load B from global memory.
|
||||
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperB::GlobalTileTraits, Index_>
|
||||
GlobalLoadIteratorB;
|
||||
// The default transformer for B.
|
||||
typedef Copy<typename GlobalLoadIteratorB::Fragment> GlobalTransformerB;
|
||||
/// The iterator to store B to shared memory.
|
||||
typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
|
||||
typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorB;
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB>
|
||||
GlobalLoadStreamB;
|
||||
|
||||
/// The iterator to load A from shared memory.
|
||||
typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
|
||||
typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared,
|
||||
Index_,
|
||||
typename GemmTileTraitsHelperA::WmmaMatrix,
|
||||
IteratorFragment::kWmmaMatrix>
|
||||
SharedLoadIteratorA;
|
||||
/// The stream to load A from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorA> SharedLoadStreamA;
|
||||
/// The iterator to load B from shared memory.
|
||||
typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
|
||||
typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared,
|
||||
Index_,
|
||||
typename GemmTileTraitsHelperB::WmmaMatrix,
|
||||
IteratorFragment::kWmmaMatrix>
|
||||
SharedLoadIteratorB;
|
||||
/// The stream to load B from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorB> SharedLoadStreamB;
|
||||
|
||||
/// The functor to do the multiply-add in the main loop.
|
||||
typedef typename GemmConfig::MultiplyAdd MultiplyAdd;
|
||||
/// The object to clear accumulators.
|
||||
typedef ClearAccumulators<typename MultiplyAdd::ScalarC> ClearAccumulators;
|
||||
|
||||
/// The helper to create the epilogue traits.
|
||||
typedef WmmaGemmEpilogueTraitsHelper<GemmConfig, EpilogueFunctor_, Index_> EpilogueTraitsHelper;
|
||||
/// The traits class for the epilogue.
|
||||
typedef SimplifiedGemmEpilogueTraits<GemmConfig, EpilogueFunctor_, Index_, EpilogueTraitsHelper>
|
||||
GemmEpilogueTraits;
|
||||
/// The epilogue.
|
||||
typedef GemmEpilogue<GemmEpilogueTraits> Epilogue;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename OutputTile_, typename DefaultShape_ = Shape<64, 32, 64> >
|
||||
struct WmmaGemmAccumulatorsPerWarp {
|
||||
typedef typename ShapeMin<OutputTile_, DefaultShape_>::Shape Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_ = Shape<64, 128, 128>,
|
||||
/// The output type.
|
||||
typename ScalarC_ = float,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<ScalarC_>,
|
||||
/// The accumulator type.
|
||||
typename Accumulator_ = ScalarC_,
|
||||
/// The number of accumulators per warp.
|
||||
typename AccumulatorsPerWarp_ = typename WmmaGemmAccumulatorsPerWarp<OutputTile_>::Shape,
|
||||
/// The shape of the WMMA instruction.
|
||||
typename InstructionShape_ = Shape<16, 16, 16>,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_ = 8,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_ = 8,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The helper class.
|
||||
typename Helper_ = WmmaGemmTraitsHelper<kLayoutA_,
|
||||
kLayoutB_,
|
||||
OutputTile_,
|
||||
ScalarC_,
|
||||
Accumulator_,
|
||||
EpilogueFunctor_,
|
||||
AccumulatorsPerWarp_,
|
||||
InstructionShape_,
|
||||
kScalarsPerLdgA_,
|
||||
kScalarsPerLdgB_,
|
||||
Index_> >
|
||||
struct WmmaGemmTraits : public GemmTraits<
|
||||
// The config.
|
||||
typename Helper_::GemmConfig,
|
||||
// The stream to load A from global memory to shared memory.
|
||||
typename Helper_::GlobalLoadStreamA,
|
||||
// The stream to load B from global memory to shared memory.
|
||||
typename Helper_::GlobalLoadStreamB,
|
||||
// The stream to load A from shared memory.
|
||||
typename Helper_::SharedLoadStreamA,
|
||||
// The stream to load B from shared memory.
|
||||
typename Helper_::SharedLoadStreamB,
|
||||
// The epilogue.
|
||||
typename Helper_::Epilogue,
|
||||
// The block swizzle to reorganize the grid.
|
||||
IdentityBlockSwizzle,
|
||||
// The index.
|
||||
Index_,
|
||||
// The tool used to clear accumulators.
|
||||
typename Helper_::ClearAccumulators> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
#endif // defined CUTLASS_USE_WMMA_API
|
||||
318
cutlass/iterator_access.h
Normal file
318
cutlass/iterator_access.h
Normal file
@ -0,0 +1,318 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Free functions for loading and storing to implementations of tile iteartor concepts.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment_load_store.h>
|
||||
#include <cutlass/load_store.h>
|
||||
#include <cutlass/predicate_vector.h>
|
||||
#include <cutlass/shape.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Loads a fragment from an input iterator
|
||||
template <typename InputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment) {
|
||||
typename InputIterator::FragmentIterator frag_iterator(fragment);
|
||||
for (int d = 0; d < InputIterator::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
|
||||
if (iterator.valid(d, h, w, c)) {
|
||||
iterator.get(reinterpret_cast<typename InputIterator::AccessType &>(
|
||||
frag_iterator.at(d, h, w, c)),
|
||||
d,
|
||||
h,
|
||||
w,
|
||||
c);
|
||||
}
|
||||
}
|
||||
if (w < InputIterator::Iterations::kW - 1) {
|
||||
iterator.inc_w();
|
||||
}
|
||||
}
|
||||
if (h < InputIterator::Iterations::kH - 1) {
|
||||
iterator.inc_h();
|
||||
}
|
||||
}
|
||||
if (d < InputIterator::Iterations::kD - 1) {
|
||||
iterator.inc_d();
|
||||
}
|
||||
}
|
||||
iterator.inc_advance();
|
||||
}
|
||||
|
||||
/// Loads a fragment from a shared memory input iterator
|
||||
template <typename InputIterator, typename Fragment>
|
||||
CUTLASS_DEVICE void shared_iterator_load(InputIterator &iterator, Fragment &fragment) {
|
||||
typename InputIterator::FragmentIterator frag_iterator(fragment);
|
||||
for (int d = 0; d < InputIterator::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename InputIterator::ImmediateOffsetStrides>::get(
|
||||
d, h, w, c);
|
||||
|
||||
FragmentLoad<InputIterator::kIteratorFragment,
|
||||
InputIterator::Tile::kC,
|
||||
typename InputIterator::Scalar,
|
||||
InputIterator::kMemorySpace,
|
||||
typename InputIterator::FragmentElement,
|
||||
InputIterator::Tile::kW>::load(frag_iterator.at(d, h, w, c),
|
||||
iterator.data(),
|
||||
offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from a shared memory input iterator
|
||||
template <typename InputIterator, typename Fragment>
|
||||
CUTLASS_DEVICE void shared_iterator_load(InputIterator &iterator, Fragment &fragment, int d) {
|
||||
typename InputIterator::FragmentIterator frag_iterator(fragment);
|
||||
for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename InputIterator::ImmediateOffsetStrides>::get(
|
||||
d, h, w, c);
|
||||
|
||||
FragmentLoad<InputIterator::kIteratorFragment,
|
||||
InputIterator::Tile::kC,
|
||||
typename InputIterator::Scalar,
|
||||
InputIterator::kMemorySpace,
|
||||
typename InputIterator::FragmentElement,
|
||||
InputIterator::Tile::kW>::load(frag_iterator.at(0, h, w, c),
|
||||
iterator.data(),
|
||||
offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from an input iterator, masked by a predicate iterator
|
||||
template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_load_post_increment(InputIterator &iterator,
|
||||
Fragment &fragment,
|
||||
typename InputIterator::Index offset,
|
||||
ConstPredicateAdapter predicate_adapter) {
|
||||
for (int d = 0; d < InputIterator::Iterations::kD; ++d, iterator.inc_d()) {
|
||||
for (int h = 0; h < InputIterator::Iterations::kH; ++h, iterator.inc_h()) {
|
||||
for (int w = 0; w < InputIterator::Iterations::kW; ++w, iterator.inc_w()) {
|
||||
if (predicate_adapter.at(d, h, w, 0)) {
|
||||
int idx = InputIterator::Tile::kC *
|
||||
(w + InputIterator::Iterations::kW * (h + InputIterator::Iterations::kH * d));
|
||||
|
||||
Load<typename Fragment::Element, InputIterator::Tile::kC, InputIterator::kMemorySpace>::
|
||||
load(reinterpret_cast<typename InputIterator::AccessType &>(fragment[idx]),
|
||||
iterator.data(),
|
||||
offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from an input iterator
|
||||
template <typename InputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_load_post_increment(InputIterator &iterator,
|
||||
Fragment &fragment,
|
||||
typename InputIterator::Index offset = 0) {
|
||||
TrivialPredicateTileAdapter pred;
|
||||
iterator_load_post_increment(iterator, fragment, offset, pred);
|
||||
}
|
||||
|
||||
/// Loads a fragment from an input iterator
|
||||
template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_load_post_increment(InputIterator &iterator,
|
||||
Fragment &fragment,
|
||||
ConstPredicateAdapter pred_it) {
|
||||
iterator_load_post_increment(iterator, fragment, 0, pred_it);
|
||||
}
|
||||
|
||||
template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &_iterator,
|
||||
Fragment &fragment,
|
||||
typename InputIterator::Index offset,
|
||||
ConstPredicateAdapter predicate_adapter) {
|
||||
InputIterator iterator(_iterator);
|
||||
iterator_load_post_increment(iterator, fragment, offset, predicate_adapter);
|
||||
}
|
||||
|
||||
/// Loads a fragment from an input iterator
|
||||
template <typename InputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &iterator,
|
||||
Fragment &fragment,
|
||||
typename InputIterator::Index offset = 0) {
|
||||
TrivialPredicateTileAdapter pred;
|
||||
iterator_load(iterator, fragment, offset, pred);
|
||||
}
|
||||
|
||||
/// Loads a fragment from an input iterator
|
||||
template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &iterator,
|
||||
Fragment &fragment,
|
||||
ConstPredicateAdapter pred_it) {
|
||||
iterator_load(iterator, fragment, 0, pred_it);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Stores a fragment to an output iterator
|
||||
template <typename OutputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment) {
|
||||
typename OutputIterator::FragmentIterator frag_iterator(fragment);
|
||||
for (int d = 0; d < OutputIterator::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < OutputIterator::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < OutputIterator::Iterations::kW; ++w) {
|
||||
if (iterator.valid(d, h, w, 0)) {
|
||||
iterator.set(reinterpret_cast<typename OutputIterator::AccessType const &>(
|
||||
frag_iterator.at(d, h, w, 0)),
|
||||
d,
|
||||
h,
|
||||
w,
|
||||
0);
|
||||
}
|
||||
if (w < OutputIterator::Iterations::kW - 1) {
|
||||
iterator.inc_w();
|
||||
}
|
||||
}
|
||||
if (h < OutputIterator::Iterations::kH - 1) {
|
||||
iterator.inc_h();
|
||||
}
|
||||
}
|
||||
if (d < OutputIterator::Iterations::kD - 1) {
|
||||
iterator.inc_d();
|
||||
}
|
||||
}
|
||||
iterator.inc_advance();
|
||||
}
|
||||
|
||||
/// Stores a fragment to a shared memory output iterator
|
||||
template <typename OutputIterator, typename Fragment>
|
||||
CUTLASS_DEVICE void shared_iterator_store(OutputIterator &iterator, Fragment const &fragment) {
|
||||
typename OutputIterator::FragmentConstIterator frag_iterator(fragment);
|
||||
for (int d = 0; d < OutputIterator::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < OutputIterator::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < OutputIterator::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < OutputIterator::Iterations::kC; ++c) {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename OutputIterator::ImmediateOffsetStrides>::get(
|
||||
d, h, w, c);
|
||||
|
||||
FragmentStore<OutputIterator::kIteratorFragment,
|
||||
OutputIterator::Tile::kC,
|
||||
typename OutputIterator::Scalar,
|
||||
OutputIterator::kMemorySpace,
|
||||
typename OutputIterator::FragmentElement,
|
||||
OutputIterator::Tile::kW>::store(frag_iterator.at(d, h, w, c),
|
||||
iterator.data(),
|
||||
offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Stores a fragment to an output iterator, masked by a predicate iterator
|
||||
template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_store_post_increment(OutputIterator &iterator,
|
||||
Fragment const &fragment,
|
||||
typename OutputIterator::Index offset,
|
||||
ConstPredicateAdapter predicate_adapter) {
|
||||
for (int d = 0; d < OutputIterator::Iterations::kD; ++d, iterator.inc_d()) {
|
||||
for (int h = 0; h < OutputIterator::Iterations::kH; ++h, iterator.inc_h()) {
|
||||
for (int w = 0; w < OutputIterator::Iterations::kW; ++w, iterator.inc_w()) {
|
||||
if (predicate_adapter.at(d, h, w, 0)) {
|
||||
int idx = OutputIterator::Tile::kC *
|
||||
(w + OutputIterator::Iterations::kW * (h + OutputIterator::Iterations::kH * d));
|
||||
|
||||
Store<typename Fragment::Element,
|
||||
OutputIterator::Tile::kC,
|
||||
OutputIterator::kMemorySpace>::
|
||||
store(reinterpret_cast<typename OutputIterator::AccessType const &>(fragment[idx]),
|
||||
iterator.data(),
|
||||
offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores a fragment to an output iterator
|
||||
template <typename OutputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_store_post_increment(OutputIterator &iterator,
|
||||
Fragment const &fragment,
|
||||
typename OutputIterator::Index offset = 0) {
|
||||
TrivialPredicateTileAdapter pred;
|
||||
iterator_store_post_increment(iterator, fragment, offset, pred);
|
||||
}
|
||||
|
||||
/// Stores a fragment to an output iterator
|
||||
template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_store_post_increment(OutputIterator &iterator,
|
||||
Fragment const &fragment,
|
||||
ConstPredicateAdapter pred_it) {
|
||||
iterator_store_post_increment(iterator, fragment, 0, pred_it);
|
||||
}
|
||||
|
||||
/// Stores a fragment to an output iterator, masked by a predicate iterator
|
||||
template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &_iterator,
|
||||
Fragment const &fragment,
|
||||
typename OutputIterator::Index offset,
|
||||
ConstPredicateAdapter predicate_adapter) {
|
||||
OutputIterator iterator(_iterator);
|
||||
iterator_store_post_increment(iterator, fragment, offset, predicate_adapter);
|
||||
}
|
||||
|
||||
/// Stores a fragment to an output iterator
|
||||
template <typename OutputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &iterator,
|
||||
Fragment const &fragment,
|
||||
typename OutputIterator::Index offset = 0) {
|
||||
TrivialPredicateTileAdapter pred;
|
||||
iterator_store(iterator, fragment, offset, pred);
|
||||
}
|
||||
|
||||
/// Stores a fragment to an output iterator
|
||||
template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &iterator,
|
||||
Fragment const &fragment,
|
||||
ConstPredicateAdapter pred_it) {
|
||||
iterator_store(iterator, fragment, 0, pred_it);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
222
cutlass/load_store.h
Normal file
222
cutlass/load_store.h
Normal file
@ -0,0 +1,222 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines abstractions for efficiently loading and storing vectors to memory.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/vector.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Enum to specify which memory space data resides in.
|
||||
*/
|
||||
struct MemorySpace {
|
||||
enum Kind {
|
||||
kGeneric, // Data accessed through pointer dereferencing
|
||||
kShared, // Data resides in shared memory
|
||||
kGlobal // Data resides in global memory
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_,
|
||||
int Lanes_,
|
||||
MemorySpace::Kind Memory_,
|
||||
bool = (Lanes_ > 1),
|
||||
size_t = (sizeof(Scalar_) * Lanes_)>
|
||||
struct Load {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
|
||||
dst = reinterpret_cast<AccessType const*>(&pointer[offset])[0];
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Load<Scalar_, Lanes_, Memory_, true, 4> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
|
||||
dst.registers[0] = reinterpret_cast<uint32_t const*>(&pointer[offset])[0];
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Load<Scalar_, Lanes_, Memory_, true, 8> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
|
||||
uint2 tmp = reinterpret_cast<uint2 const*>(&pointer[offset])[0];
|
||||
dst.registers[0] = tmp.x;
|
||||
dst.registers[1] = tmp.y;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <MemorySpace::Kind Memory_>
|
||||
struct Load<double, 2, Memory_, true, 16> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<double, 2>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, double const* pointer, int offset) {
|
||||
double2 tmp = reinterpret_cast<double2 const*>(&pointer[offset])[0];
|
||||
dst[0] = tmp.x;
|
||||
dst[1] = tmp.y;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(__CUDACC_VERSION_MAJOR) && __CUDACC_VERSION_MAJOR < 10
|
||||
// WAR bug in NVCC where the upper and lower half of the register end up being the same
|
||||
template <MemorySpace::Kind Memory_>
|
||||
struct Load<half, 8, Memory_, true, 16> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<half, 8>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, half const* pointer, int offset) {
|
||||
int2 tmp = reinterpret_cast<int2 const*>(&pointer[offset])[0];
|
||||
dst.registers[0] = tmp.x;
|
||||
dst.registers[1] = tmp.y;
|
||||
|
||||
tmp = reinterpret_cast<int2 const*>(&pointer[offset + 4])[0];
|
||||
dst.registers[2] = tmp.x;
|
||||
dst.registers[3] = tmp.y;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Load<Scalar_, Lanes_, Memory_, true, 16> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
|
||||
uint4 tmp = reinterpret_cast<uint4 const*>(&pointer[offset])[0];
|
||||
dst.registers[0] = tmp.x;
|
||||
dst.registers[1] = tmp.y;
|
||||
dst.registers[2] = tmp.z;
|
||||
dst.registers[3] = tmp.w;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_,
|
||||
int Lanes_,
|
||||
MemorySpace::Kind Memory_,
|
||||
bool = (Lanes_ > 1),
|
||||
size_t = (sizeof(Scalar_) * Lanes_)>
|
||||
struct Store {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
|
||||
pointer[offset] = src;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Store<Scalar_, Lanes_, Memory_, true, 4> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
|
||||
uint32_t* addr = reinterpret_cast<uint32_t*>(&pointer[offset]);
|
||||
addr[0] = src.registers[0];
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Store<Scalar_, Lanes_, Memory_, true, 8> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
|
||||
uint2* addr = reinterpret_cast<uint2*>(&pointer[offset]);
|
||||
addr[0] = make_uint2(src.registers[0], src.registers[1]);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <MemorySpace::Kind Memory_>
|
||||
struct Store<double, 2, Memory_, true, 16> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<double, 2>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& src, double* pointer, int offset) {
|
||||
double2* addr = reinterpret_cast<double2*>(&pointer[offset]);
|
||||
addr[0] = make_double2(src[0], src[1]);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Store<Scalar_, Lanes_, Memory_, true, 16> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
|
||||
uint4* addr = reinterpret_cast<uint4*>(&pointer[offset]);
|
||||
addr[0] = make_uint4(src.registers[0], src.registers[1], src.registers[2], src.registers[3]);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
48
cutlass/matrix_traits.h
Normal file
48
cutlass/matrix_traits.h
Normal file
@ -0,0 +1,48 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines properties of matrices used to denote layout and operands to GEMM kernels.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Describes layouts of matrices
|
||||
struct MatrixLayout {
|
||||
enum Kind { kRowMajor, kColumnMajor };
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Gemm operand - D = A * B + C
|
||||
struct GemmOperand {
|
||||
enum Kind { kA, kB, kC, kD };
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
493
cutlass/predicate_vector.h
Normal file
493
cutlass/predicate_vector.h
Normal file
@ -0,0 +1,493 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines container classes and iterators for managing a statically sized vector
|
||||
of boolean predicates.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/shape.h>
|
||||
|
||||
#include <cutlass/util/platform.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup predicate_vector_concept Predicate Vector Concept
|
||||
@{
|
||||
|
||||
Implementations of \ref predicate_vector_concept contain an ordered set of boolean predicates which
|
||||
may be used as conditionals in other device-side operations. Both random access and iterators
|
||||
offering sequential access are provided.
|
||||
|
||||
@par Predicate Vector
|
||||
A \ref predicate_vector_concept satisfies the following expressions
|
||||
- <b>at(int idx)</b> - returns the value of the indexed predicate
|
||||
- <b>set(int idx, bool value)</b> - sets the value of the indexed predicate
|
||||
- <b>begin()</b> - returns a \ref predicate_iterator_concept pointing to the first predicate
|
||||
|
||||
@}
|
||||
*/
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup predicate_iterator_concept Predicate Iterator Concept
|
||||
@{
|
||||
|
||||
Implementations of \ref predicate_iterator_concept enables accessing and traversing elements of a
|
||||
bit vector.
|
||||
|
||||
@par Const Predicate Iterator
|
||||
A const \ref predicate_iterator_concept satisfies the following expressions
|
||||
- <b>++it</b> increments the iterator to the next predicate
|
||||
- <b>*it</b> returns the value of the currently pointed-to predicate
|
||||
|
||||
@par Mutable Predicate Iterator
|
||||
A \ref predicate_iterator_concept that is non-const <b>also</b> satisfies the following expressions
|
||||
- <b>it.set(bool value)</b> sets the value of the currently pointed-to predicate
|
||||
|
||||
@}
|
||||
*/
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup predicate_tile_adapter Predicate Tile Adapter Concept
|
||||
@{
|
||||
|
||||
Implementations of \ref predicate_tile_adapter provide a mapping between a the elements of a \ref
|
||||
tile_traits_concept and a \ref predicate_vector_concept.
|
||||
|
||||
@par Predicate Tile Adapter
|
||||
A \ref predicate_tile_adapter satisfies the following expressions
|
||||
- <b>at(int d, int h, int w, int c)</b> - returns the value of a predicate corresponding to the
|
||||
access (d, h, w, c) within the tile.
|
||||
|
||||
@}
|
||||
*/
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Statically sized array of bits implementing @concept{predicate_vector_concept}.
|
||||
template <
|
||||
/// Number of predicates conatined in predicate vector
|
||||
int kPredicates_,
|
||||
/// Number of predicates contained in each byte of internal storage
|
||||
int kPredicatesPerByte_ = 4,
|
||||
/// Location of first predicate within byte of internal storage
|
||||
int kPredicateStart_ = 0>
|
||||
struct PredicateVector {
|
||||
/// Number of bits stored by the PredicateVector
|
||||
static int const kPredicates = kPredicates_;
|
||||
|
||||
/// Number of bits stored within each byte of the predicate bit vector
|
||||
static int const kPredicatesPerByte = kPredicatesPerByte_;
|
||||
|
||||
/// First bit withing each byte containing predicates
|
||||
static int const kPredicateStart = kPredicateStart_;
|
||||
|
||||
// Make sure no one tries to put more than 8 bits in a byte :)
|
||||
static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte");
|
||||
// Make sure the "offsetted" bits fit in one byte.
|
||||
static_assert(kPredicateStart + kPredicatesPerByte < 8,
|
||||
"The offsetted predicates must fit within an actual byte.");
|
||||
|
||||
/// Storage type of individual elements
|
||||
typedef uint32_t Storage;
|
||||
|
||||
/// Number of bytes needed
|
||||
static int const kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte;
|
||||
|
||||
/// Number of storage elements needed
|
||||
static int const kWordCount = (kBytes + sizeof(Storage) - 1) / sizeof(Storage);
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Words of bit vector
|
||||
Storage storageData[kWordCount];
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Computes the word and bit corresponding to a logical predicate index
|
||||
CUTLASS_HOST_DEVICE void computeStorageOffset(int &word, int &bit, int idx) const {
|
||||
CUTLASS_ASSERT(idx < kPredicates);
|
||||
|
||||
int byte = (idx / kPredicatesPerByte);
|
||||
int bit_offset = (idx % kPredicatesPerByte);
|
||||
|
||||
word = byte / sizeof(Storage);
|
||||
int byte_offset = (byte % sizeof(Storage));
|
||||
|
||||
bit = byte_offset * 8 + bit_offset + kPredicateStart;
|
||||
}
|
||||
|
||||
/// Accesses a given word with optional assertions
|
||||
CUTLASS_HOST_DEVICE Storage &storage(int word) {
|
||||
CUTLASS_ASSERT(word < kWordCount);
|
||||
return storageData[word];
|
||||
}
|
||||
|
||||
/// Accesses a given word with optional assertions
|
||||
CUTLASS_HOST_DEVICE Storage const &storage(int word) const {
|
||||
CUTLASS_ASSERT(word < kWordCount);
|
||||
return storageData[word];
|
||||
}
|
||||
|
||||
public:
|
||||
//
|
||||
// Iterator
|
||||
//
|
||||
|
||||
/**
|
||||
* @brief A const iterator implementing \ref predicate_iterator_concept enabling sequential
|
||||
* read-only access to prediactes.
|
||||
* @concept{predicate_iterator_concept}
|
||||
*/
|
||||
class ConstIterator {
|
||||
/// Reference to PredicateVector instance
|
||||
PredicateVector const &vec_;
|
||||
|
||||
/// Index into PredicateVector
|
||||
int bit_;
|
||||
|
||||
public:
|
||||
/// Copy constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator(ConstIterator const &it) : vec_(it.vec_), bit_(it.bit_) {}
|
||||
|
||||
///
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator(PredicateVector const &_vec, int _start = 0) : vec_(_vec), bit_(_start) {}
|
||||
|
||||
/// Pre-increment
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator &operator++() {
|
||||
++bit_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Pre-decrement
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator &operator--() {
|
||||
--bit_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Post-increment
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator operator++(int) {
|
||||
ConstIterator ret(*this);
|
||||
ret.bit_++;
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Post-decrement
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator operator--(int) {
|
||||
ConstIterator ret(*this);
|
||||
ret.bit_--;
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Returns true if iterators point to the same bit
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator==(ConstIterator const &it) const { return bit_ == it.bit_; }
|
||||
|
||||
/// Returns false if iterators point to the same bit
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator!=(ConstIterator const &it) const { return bit_ != it.bit_; }
|
||||
|
||||
/// Dereferences iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator*() const { return vec_[bit_]; }
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief An iterator implementing \ref predicate_iterator_concept enabling sequential
|
||||
* read and write access to predicates.
|
||||
* @concept{predicate_iterator_concept}
|
||||
*/
|
||||
class Iterator {
|
||||
/// Reference to PredicateVector instance
|
||||
PredicateVector &vec_;
|
||||
|
||||
/// Index into PredicateVector
|
||||
int bit_;
|
||||
|
||||
public:
|
||||
/// Copy constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Iterator(Iterator const &it) : vec_(it.vec_), bit_(it.bit_) {}
|
||||
|
||||
/// Constructs an iterator from a PredicateVector
|
||||
CUTLASS_HOST_DEVICE
|
||||
Iterator(PredicateVector &_vec, int _start = 0) : vec_(_vec), bit_(_start) {}
|
||||
|
||||
/// Pre-increment
|
||||
CUTLASS_HOST_DEVICE
|
||||
Iterator &operator++() {
|
||||
++bit_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Pre-decrement
|
||||
CUTLASS_HOST_DEVICE
|
||||
Iterator &operator--() {
|
||||
--bit_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Post-increment
|
||||
CUTLASS_HOST_DEVICE
|
||||
Iterator operator++(int) {
|
||||
Iterator ret(*this);
|
||||
ret.bit_++;
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Post-decrement
|
||||
CUTLASS_HOST_DEVICE
|
||||
Iterator operator--(int) {
|
||||
Iterator ret(*this);
|
||||
ret.bit_--;
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Returns true if iterators point to the same bit
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator==(Iterator const &it) const { return bit_ == it.bit_; }
|
||||
|
||||
/// Returns false if iterators point to the same bit
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator!=(Iterator const &it) const { return bit_ != it.bit_; }
|
||||
|
||||
/// Gets the bit at the pointed to location
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool get() { return vec_[bit_]; }
|
||||
|
||||
/// Dereferences iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator*() const { return vec_[bit_]; }
|
||||
|
||||
/// Sets the bit at the pointed to location
|
||||
CUTLASS_HOST_DEVICE
|
||||
void set(bool value = true) { vec_.set(bit_, value); }
|
||||
};
|
||||
|
||||
/// Iterator that always returns true
|
||||
struct TrivialIterator {
|
||||
/// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TrivialIterator() {}
|
||||
|
||||
/// Copy constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TrivialIterator(Iterator const &it) {}
|
||||
|
||||
/// Constructs an iterator from a PredicateVector
|
||||
CUTLASS_HOST_DEVICE
|
||||
TrivialIterator(PredicateVector const &_vec) {}
|
||||
|
||||
/// Pre-increment
|
||||
CUTLASS_HOST_DEVICE
|
||||
TrivialIterator &operator++() { return *this; }
|
||||
|
||||
/// Post-increment
|
||||
CUTLASS_HOST_DEVICE
|
||||
TrivialIterator operator++(int) { return *this; }
|
||||
|
||||
/// Dereferences iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator*() const { return true; }
|
||||
};
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Initialize the predicate vector
|
||||
CUTLASS_HOST_DEVICE PredicateVector(bool value = true) { fill(value); }
|
||||
|
||||
/// Fills all predicates with a given value
|
||||
CUTLASS_HOST_DEVICE void fill(bool value = true) {
|
||||
Storage item = (value ? ~Storage(0) : Storage(0));
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kWordCount; ++i) {
|
||||
storage(i) = item;
|
||||
}
|
||||
}
|
||||
|
||||
/// Accesses a bit within the predicate vector.
|
||||
CUTLASS_HOST_DEVICE bool operator[](int idx) const { return at(idx); }
|
||||
|
||||
/// Accesses a bit within the predicate vector.
|
||||
CUTLASS_HOST_DEVICE bool at(int idx) const {
|
||||
int bit, word;
|
||||
computeStorageOffset(word, bit, idx);
|
||||
|
||||
return ((storage(word) >> bit) & 1);
|
||||
}
|
||||
|
||||
/// Set a bit within the predicate vector.
|
||||
CUTLASS_HOST_DEVICE void set(int idx, bool value = true) {
|
||||
int bit, word;
|
||||
computeStorageOffset(word, bit, idx);
|
||||
|
||||
Storage disable_mask = (~(Storage(1) << bit));
|
||||
Storage enable_mask = (Storage(value) << bit);
|
||||
|
||||
storage(word) = ((storage(word) & disable_mask) | enable_mask);
|
||||
}
|
||||
|
||||
/// Computes the intersection of two identical predicate vectors.
|
||||
CUTLASS_HOST_DEVICE PredicateVector &operator&=(PredicateVector const &predicates) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kWordCount; ++i) {
|
||||
storage(i) = (storage(i) & predicates.storage(i));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Computes the union of two identical predicate vectors.
|
||||
CUTLASS_HOST_DEVICE PredicateVector &operator|=(PredicateVector const &predicates) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kWordCount; ++i) {
|
||||
storage(i) = (storage(i) | predicates.storage(i));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Returns true if entire predicate array is zero.
|
||||
CUTLASS_HOST_DEVICE bool is_zero() const {
|
||||
Storage mask(0);
|
||||
for (int byte = 0; byte < sizeof(Storage); ++byte) {
|
||||
Storage byte_mask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart);
|
||||
mask |= (byte_mask << (byte * 8));
|
||||
}
|
||||
uint32_t result = 0;
|
||||
for (int word = 0; word < kWordCount; ++word) {
|
||||
result |= storage(word);
|
||||
}
|
||||
return result == 0;
|
||||
}
|
||||
|
||||
/// Returns an iterator to the start of the bit vector
|
||||
CUTLASS_DEVICE
|
||||
Iterator begin() { return Iterator(*this); }
|
||||
|
||||
/// Returns an iterator
|
||||
CUTLASS_DEVICE
|
||||
Iterator end() { return Iterator(*this, kPredicates); }
|
||||
|
||||
/// Returns a ConstIterator
|
||||
CUTLASS_DEVICE
|
||||
ConstIterator const_begin() const { return ConstIterator(*this); }
|
||||
|
||||
/// Returns a ConstIterator
|
||||
CUTLASS_DEVICE
|
||||
ConstIterator const_end() const { return ConstIterator(*this, kPredicates); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Always returns true predicate.
|
||||
struct TrivialPredicateTileAdapter {
|
||||
/// Ctor.
|
||||
CUTLASS_HOST_DEVICE TrivialPredicateTileAdapter() {}
|
||||
|
||||
/// The value at location (d, h, w, c).
|
||||
CUTLASS_HOST_DEVICE bool at(int, int, int, int) const { return true; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Adapter to enable random access to predicates via logical coordinate within a tile.
|
||||
template <typename PredicateVector_, typename Iterations_>
|
||||
struct PredicateTileAdapter {
|
||||
/// The vector of predicates.
|
||||
typedef PredicateVector_ PredicateVector;
|
||||
/// The iterations.
|
||||
typedef Iterations_ Iterations;
|
||||
|
||||
private:
|
||||
/// The predicates.
|
||||
PredicateVector &predicates;
|
||||
|
||||
public:
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE PredicateTileAdapter(PredicateVector &predicates_) : predicates(predicates_) {}
|
||||
|
||||
/// Get the value at location (d, h, w, c).
|
||||
CUTLASS_DEVICE bool at(int d, int h, int w, int c) const {
|
||||
int const bit = ComputeOffsetFromShape<Iterations>::get(d, h, w, c);
|
||||
return predicates.at(bit);
|
||||
}
|
||||
|
||||
/// Set the value at location (d, h, w, c).
|
||||
CUTLASS_DEVICE void set(int d, int h, int w, int c, bool value) {
|
||||
int const bit = ComputeOffsetFromShape<Iterations>::get(d, h, w, c);
|
||||
predicates.set(bit, value);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Adapter to enable random access to predicates via logical coordinate within a tile.
|
||||
template <typename PredicateVector_, typename Iterations_>
|
||||
struct ConstPredicateTileAdapter {
|
||||
/// The vector of predicates.
|
||||
typedef PredicateVector_ PredicateVector;
|
||||
/// The iterations.
|
||||
typedef Iterations_ Iterations;
|
||||
|
||||
private:
|
||||
/// The predicates.
|
||||
PredicateVector const &predicates;
|
||||
|
||||
public:
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ConstPredicateTileAdapter(PredicateVector const &predicates_)
|
||||
: predicates(predicates_) {}
|
||||
|
||||
/// Get the value at location (d, h, w, c).
|
||||
CUTLASS_DEVICE bool at(int d, int h, int w, int c) const {
|
||||
int const bit = ComputeOffsetFromShape<Iterations>::get(d, h, w, c);
|
||||
return predicates.at(bit);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
58
cutlass/reshape_tile.h
Normal file
58
cutlass/reshape_tile.h
Normal file
@ -0,0 +1,58 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines a type for restructuring a tile.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/shape.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// The following functor reshapes a tile of data. The goal is to have at least kAccessSize in
|
||||
// the inner-most dimension. If the user respects that constraint, there is nothing to be done. If
|
||||
// that's not the case, this functor will correct that and "extract" the right number of elements
|
||||
// from the next dimension.
|
||||
|
||||
template <typename Tile_, int kAccessSize_, bool = (Tile_::kC < kAccessSize_)>
|
||||
struct ReshapeTile {
|
||||
typedef Tile_ Tile;
|
||||
};
|
||||
|
||||
template <typename Tile_, int kAccessSize_>
|
||||
struct ReshapeTile<Tile_, kAccessSize_, true> {
|
||||
// Make sure the W dimension of the tile is large enough.
|
||||
static_assert(Tile_::kW >= kAccessSize_, "The W dimension is too small");
|
||||
// Make sure the dimension can be divided by the number of scalars.
|
||||
static_assert(Tile_::kW % kAccessSize_ == 0, "Not supported");
|
||||
// Collapse the W dimension.
|
||||
typedef Shape<Tile_::kD, Tile_::kH, Tile_::kW / kAccessSize_, kAccessSize_> Tile;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
305
cutlass/shape.h
Normal file
305
cutlass/shape.h
Normal file
@ -0,0 +1,305 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines Shape implementing the Layout concept for representing a 4D hypercube of objects.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup layout_concept Layout Concept
|
||||
* @{
|
||||
* @par Implementations of \ref layout_concept are used to describe a cube with DxHxW elements and C
|
||||
scalars per element.
|
||||
A HxW slice of a cube is called an image and a cube consists of D images.
|
||||
*
|
||||
* @par Notations
|
||||
* Let Layout be an implementation of the \ref layout_concept.
|
||||
*
|
||||
* @par Valid Expressions
|
||||
* - <b>Layout::D</b> specifies the depth of a cube
|
||||
* - <b>Layout::H</b> specifies the height of a cube
|
||||
* - <b>Layout::W</b> specifies the height of a cube
|
||||
* - <b>Layout::C</b> specifies the number of channels of each element in a cube
|
||||
* - <b>Layout::W_c</b> specifies the number of scalars of each row in one image of a cube.
|
||||
* - <b>Layout::H_w</b> specifies the number of elements in an image slice.
|
||||
* - <b>Layout::H_w_c</b>_specifies the number of scalars in an image slice.
|
||||
* - <b>Layout::D_h_w</b> specifies the number of elements in a cube.
|
||||
* - <b>Layout::D_h_w_c</b> specifies the number of scalars in a cube.
|
||||
* - <b>Layout::Strides</b> is a \ref layout_concept specifying the strides.
|
||||
* @}
|
||||
*/
|
||||
|
||||
/**
|
||||
* @brief A Shape implementing \ref layout_concept describing the dimensions of a cube.
|
||||
* @concept{layout_concept}
|
||||
*/
|
||||
template <int kD_ = 1, int kH_ = 1, int kW_ = 1, int kC_ = 1>
|
||||
struct Shape {
|
||||
/// The depth of the cube.
|
||||
static int const kD = kD_;
|
||||
/// The height of the cube.
|
||||
static int const kH = kH_;
|
||||
/// The width of the cube.
|
||||
static int const kW = kW_;
|
||||
/// The number of scalars per element.
|
||||
static int const kC = kC_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Compute derived counted of a \ref layout_concept based class
|
||||
*/
|
||||
template <typename Shape>
|
||||
struct ShapeCount {
|
||||
/// The number of elements per row.
|
||||
static int const kWc = Shape::kW * Shape::kC;
|
||||
/// The number of pixels per image.
|
||||
static int const kHw = Shape::kH * Shape::kW;
|
||||
/// The number of elements per image.
|
||||
static int const kHwc = Shape::kH * kWc;
|
||||
/// The number of pixels per cube.
|
||||
static int const kDhw = Shape::kD * kHw;
|
||||
/// The number of elements in the 4D space.
|
||||
static int const kDhwc = Shape::kD * kHwc;
|
||||
/// The number of elements in the 4D space.
|
||||
static int const kCount = kDhwc;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename A_, int kScale_>
|
||||
struct ShapeScale {
|
||||
typedef Shape<A_::kD * kScale_, A_::kH * kScale_, A_::kW * kScale_, A_::kC * kScale_> Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename A_, typename B_>
|
||||
struct ShapeAdd {
|
||||
typedef Shape<A_::kD + B_::kD, A_::kH + B_::kH, A_::kW + B_::kW, A_::kC + B_::kC> Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename A_, typename B_>
|
||||
struct ShapeSub {
|
||||
typedef Shape<A_::kD - B_::kD, A_::kH - B_::kH, A_::kW - B_::kW, A_::kC - B_::kC> Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename A_, typename B_>
|
||||
struct ShapeMul {
|
||||
typedef Shape<A_::kD * B_::kD, A_::kH * B_::kH, A_::kW * B_::kW, A_::kC * B_::kC> Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename A_, typename B_>
|
||||
struct ShapeDiv {
|
||||
typedef Shape<A_::kD / B_::kD, A_::kH / B_::kH, A_::kW / B_::kW, A_::kC / B_::kC> Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename A_, typename B_>
|
||||
struct ShapeMax {
|
||||
typedef Shape<(A_::kD > B_::kD ? A_::kD : B_::kD),
|
||||
(A_::kH > B_::kH ? A_::kH : B_::kH),
|
||||
(A_::kW > B_::kW ? A_::kW : B_::kW),
|
||||
(A_::kC > B_::kC ? A_::kC : B_::kC)>
|
||||
Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename A_, typename B_>
|
||||
struct ShapeMin {
|
||||
typedef Shape<(A_::kD < B_::kD ? A_::kD : B_::kD),
|
||||
(A_::kH < B_::kH ? A_::kH : B_::kH),
|
||||
(A_::kW < B_::kW ? A_::kW : B_::kW),
|
||||
(A_::kC < B_::kC ? A_::kC : B_::kC)>
|
||||
Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Shape_, int kElementsPerAccess>
|
||||
struct ShapeStrides {
|
||||
typedef Shape<Shape_::kH * Shape_::kW * Shape_::kC,
|
||||
Shape_::kW * Shape_::kC,
|
||||
Shape_::kC,
|
||||
kElementsPerAccess>
|
||||
Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Compute the offset for the given coordinates in a cube
|
||||
* @tparam A \ref layout_concept where each dimension of the cube specifies the corresponding stride.
|
||||
*/
|
||||
template <typename Shape_>
|
||||
struct ComputeOffsetFromShape {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
|
||||
// clang-format off
|
||||
return d * Shape_::kH * Shape_::kW * Shape_::kC +
|
||||
h * Shape_::kW * Shape_::kC +
|
||||
w * Shape_::kC +
|
||||
c;
|
||||
// clang-format on
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Compute the offset for the given coordinates in a cube with a depth of 1
|
||||
* @tparam kSh Elements in the H dimension
|
||||
* @tparam kSw Elements in the W dimension
|
||||
* @tparam kSc Separation between two elements in "elements"
|
||||
*/
|
||||
template <int kSh_, int kSw_, int kSc_>
|
||||
struct ComputeOffsetFromShape<Shape<1, kSh_, kSw_, kSc_> > {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
|
||||
return h * kSw_ * kSc_ + w * kSc_ + c;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Compute the offset for the given coordinates in a cube with one channel and a depth of 1
|
||||
* @tparam kSh Elements in the H dimension
|
||||
* @tparam kSw Elements in the W dimension
|
||||
*/
|
||||
template <int kSh_, int kSw_>
|
||||
struct ComputeOffsetFromShape<Shape<1, kSh_, kSw_, 1> > {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) { return h * kSw_ + w; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Compute the offset for the given coordinates in a cube
|
||||
* @tparam A \ref layout_concept where each dimension of the cube specifies the corresponding stride.
|
||||
*/
|
||||
template <typename Strides_>
|
||||
struct ComputeOffsetFromStrides {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
|
||||
return d * Strides_::kD + h * Strides_::kH + w * Strides_::kW + c * Strides_::kC;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Compute the offset for the given coordinates in a cube with a depth of 1
|
||||
* @tparam S_h Stride in the H dimension in scalars
|
||||
* @tparam S_w Stride in the W dimension in scalars
|
||||
* @tparam S_c Stride between two scalars.
|
||||
*/
|
||||
template <int S_h_, int S_w_, int S_c_>
|
||||
struct ComputeOffsetFromStrides<Shape<1, S_h_, S_w_, S_c_> > {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
|
||||
return h * S_h_ + w * S_w_ + c * S_c_;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Compute the offset for the given coordinates in a cube with one channel and a depth of 1
|
||||
* @tparam S_h Stride in the H dimension in scalars
|
||||
* @tparam S_w Stride in the W dimension in scalars
|
||||
*/
|
||||
template <int S_h_, int S_w_>
|
||||
struct ComputeOffsetFromStrides<Shape<1, S_h_, S_w_, 1> > {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) { return h * S_h_ + w * S_w_; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Decompose threadId.x into coordinate of a cube whose dimensions are specified by Threads_.
|
||||
* Afterwards compute the offset of those coordinates using Strides_
|
||||
* @tparam Threads_ The dimension of the cube the threadIdx.x value is mapped on
|
||||
* @tparam Strides_ The strides to use when compute the offsets based on the coordinates of the cube.
|
||||
*/
|
||||
template <typename Threads_, typename Strides_>
|
||||
struct ComputeThreadOffsetFromStrides {
|
||||
static CUTLASS_DEVICE int get() {
|
||||
// Decompose the thread index.
|
||||
int c = threadIdx.x % Threads_::kC;
|
||||
int w = threadIdx.x / Threads_::kC % Threads_::kW;
|
||||
int h = threadIdx.x / Threads_::kC / Threads_::kW % Threads_::kH;
|
||||
int d = threadIdx.x / Threads_::kC / Threads_::kW / Threads_::kH;
|
||||
|
||||
// Compute the offset.
|
||||
return d * Strides_::kD + h * Strides_::kH + w * Strides_::kW + c * Strides_::kC;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/**
|
||||
*@brief Specialization for D=1
|
||||
*/
|
||||
template <int T_h_, int T_w_, int T_c_, int S_h_, int S_w_, int S_c_>
|
||||
struct ComputeThreadOffsetFromStrides<Shape<1, T_h_, T_w_, T_c_>, Shape<1, S_h_, S_w_, S_c_> > {
|
||||
static CUTLASS_DEVICE int get() {
|
||||
// Decompose the thread index.
|
||||
int c = threadIdx.x % T_c_;
|
||||
int w = threadIdx.x / T_c_ % T_w_;
|
||||
int h = threadIdx.x / T_c_ / T_w_ % T_h_;
|
||||
|
||||
// Compute the offset.
|
||||
return h * S_h_ + w * S_w_ + c * S_c_;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
*@brief Specialization for D=1 and C=1
|
||||
*/
|
||||
template <int T_h_, int T_w_, int S_h_, int S_w_>
|
||||
struct ComputeThreadOffsetFromStrides<Shape<1, T_h_, T_w_, 1>, Shape<1, S_h_, S_w_, 1> > {
|
||||
static CUTLASS_DEVICE int get() {
|
||||
// Decompose the thread index.
|
||||
int w = threadIdx.x % T_w_;
|
||||
int h = threadIdx.x / T_w_;
|
||||
|
||||
// Compute the offset.
|
||||
return h * S_h_ + w * S_w_;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
151
cutlass/tensor_ref.h
Normal file
151
cutlass/tensor_ref.h
Normal file
@ -0,0 +1,151 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines a structure containing strides, bounds, and a pointer to tensor data.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <typeinfo>
|
||||
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/vector.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure modeling a pointer and stride into a tensor
|
||||
template <typename Storage_, int Rank_>
|
||||
class TensorRef {
|
||||
public:
|
||||
/// Data type of individual access
|
||||
typedef Storage_ Storage;
|
||||
|
||||
/// Rank of tensor
|
||||
static int const Rank = Rank_;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Pointer to storage element
|
||||
Storage* ptr_;
|
||||
|
||||
/// Stride information
|
||||
Coord<Rank> stride_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef() : ptr_(nullptr) {}
|
||||
|
||||
/// Constructs from a pointer, size, and stride
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef(Storage* ptr, Coord<Rank> stride) : ptr_(ptr), stride_(stride) {}
|
||||
|
||||
/// Updates the pointer, stride, and location within a TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
void reset(Storage* ptr = nullptr, Coord<Rank> stride = Coord<Rank>(0)) {
|
||||
ptr_ = ptr;
|
||||
stride_ = stride;
|
||||
}
|
||||
|
||||
/// Conversion function
|
||||
template <typename T>
|
||||
TensorRef<T, Rank> convert() {
|
||||
Coord<Rank> converted_stride;
|
||||
for (int i = 0; i < Rank - 1; ++i) {
|
||||
converted_stride[i] = stride_[i] * Extent<Storage>::kValue / Extent<T>::kValue;
|
||||
}
|
||||
converted_stride[Rank - 1] = stride_[Rank - 1];
|
||||
|
||||
return TensorRef<T, Rank>(reinterpret_cast<T*>(ptr_), converted_stride);
|
||||
}
|
||||
|
||||
/// Returns true if the TensorRef may be safely accessed
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool good() const { return ptr_ != nullptr; }
|
||||
|
||||
/// Returns the pointer to referenced data
|
||||
CUTLASS_HOST_DEVICE
|
||||
Storage* data() const { return ptr_; }
|
||||
|
||||
/// Returns the stride of the tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<Rank> const& stride() const { return stride_; }
|
||||
|
||||
/// Returns the stride of the tensor in the given dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
int const& stride(int dim) const { return stride_.at(dim); }
|
||||
|
||||
/// Returns the maximum stride element as the 'leading dimension'
|
||||
CUTLASS_HOST_DEVICE
|
||||
int leading_dim() const { return __NV_STD_MAX(stride_[1], stride_[2]); }
|
||||
|
||||
/// Computes the offset of an index from the origin of the tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
long long offset(Coord<Rank> const& coord) const {
|
||||
return stride_.template dot<long long>(coord);
|
||||
}
|
||||
|
||||
/// Returns a reference to the element at a given Coord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Storage& at(Coord<Rank> const& coord) const { return ptr_[offset(coord)]; }
|
||||
|
||||
/// Element-wise accessor
|
||||
Storage& operator[](Coord<Rank> const& coord) const { return at(coord); }
|
||||
|
||||
/// Returns a reference to the element at a given Coord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Storage& at(int idx) const { return ptr_[idx]; }
|
||||
|
||||
/// Element-wise accessor
|
||||
Storage& operator[](int idx) const { return at(idx); }
|
||||
|
||||
/// Adds an offset to the pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef& advance(Coord<Rank> const& b) {
|
||||
ptr_ += offset(b);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Returns a TensorRef offset by a given amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef operator+(Coord<Rank> const& b) const { return TensorRef(ptr_ + offset(b), stride_); }
|
||||
|
||||
/// Returns a TensorRef offset by a given amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef operator-(Coord<Rank> const& b) const { return TensorRef(ptr_ - offset(b), stride_); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
172
cutlass/tensor_view.h
Normal file
172
cutlass/tensor_view.h
Normal file
@ -0,0 +1,172 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines a structure containing strides and a pointer to tensor data.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/tensor_ref.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Host-side reference implementation of tensor operations
|
||||
template <typename T>
|
||||
class TensorView : public TensorRef<T, 4> {
|
||||
public:
|
||||
/// Reference and stride
|
||||
typedef TensorRef<T, 4> Base;
|
||||
|
||||
/// Reference and stride
|
||||
typedef Base TensorRef_t;
|
||||
|
||||
/// Reference to constant type
|
||||
typedef TensorRef<T const, 4> ConstTensorRef_t;
|
||||
|
||||
/// Rank of tensor
|
||||
static int const Rank = TensorRef_t::Rank;
|
||||
|
||||
/// Type used to compute the offset of an element to the base of a tensor
|
||||
typedef int Offset_t;
|
||||
|
||||
/// Coordinate into tensor
|
||||
typedef Coord<Rank> Coord_t;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Pointer to pitch-linear memory
|
||||
TensorRef_t ref_;
|
||||
|
||||
/// Dimensions of coordinate (independent of stride)
|
||||
Coord_t size_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Device and Host Methods
|
||||
//
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorView() {}
|
||||
|
||||
/// Constructs a Tensor_view from a TensorRef and size
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorView(TensorRef_t const& _ref, Coord_t const& _size) : Base(_ref), size_(_size) {}
|
||||
|
||||
/// Returns true if the Tensor_view is bound to some memory
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool good() const { return ref().good(); }
|
||||
|
||||
/// Returns a pointer to data
|
||||
CUTLASS_HOST_DEVICE
|
||||
T* data() const { return ref().data(); }
|
||||
|
||||
/// Updates the reference and size of a Tensor_view object
|
||||
CUTLASS_HOST_DEVICE
|
||||
void reset(TensorRef_t const& _ref = TensorRef_t(0), Coord_t const& _size = Coord_t()) {
|
||||
Base::operator=(_ref);
|
||||
size_ = _size;
|
||||
}
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef_t& ref() { return *this; }
|
||||
|
||||
///
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstTensorRef_t const_ref() { return ConstTensorRef_t(data(), stride()); }
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef_t const& ref() const { return *this; }
|
||||
|
||||
/// Accesses the size
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord_t const& size() const { return size_; }
|
||||
|
||||
/// Accesses the size
|
||||
CUTLASS_HOST_DEVICE
|
||||
int size(int dim) const { return size_.at(dim); }
|
||||
|
||||
/// Accesses the stride
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord_t const& stride() const { return ref().stride(); }
|
||||
|
||||
/// Accesses the stride
|
||||
CUTLASS_HOST_DEVICE
|
||||
int const& stride(int dim) const { return ref().stride(dim); }
|
||||
|
||||
/// Assigns the Tensor_view
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorView& operator=(TensorView const& _tensor) {
|
||||
Base::operator=(_tensor._ref);
|
||||
size_ = _tensor.size_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Returns the index of an element
|
||||
CUTLASS_HOST_DEVICE
|
||||
Offset_t offset(Coord_t const& coord) const { return ref().offset(coord); }
|
||||
|
||||
/// Determines whether a location is within a tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool contains(Coord_t const& coord) const {
|
||||
for (int dim = 0; dim < Rank; ++dim) {
|
||||
if (coord.at(dim) >= size_.at(dim)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Element-wise accessor
|
||||
CUTLASS_HOST_DEVICE
|
||||
T& at(Coord_t const& coord) const { return ref().at(coord); }
|
||||
|
||||
/// Element-wise accessor
|
||||
T& operator[](Coord<Rank> const& coord) const { return at(coord); }
|
||||
|
||||
/// Element-wise accessor
|
||||
CUTLASS_HOST_DEVICE
|
||||
T& at(Offset_t idx) const { return ref().at(idx); }
|
||||
|
||||
/// Returns a Tensor_view given location and size quantities
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorView<T> subview(Coord_t const& location, Coord_t size) const {
|
||||
return TensorView<T>(ref() + location, size.clamp(size_ - location));
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
899
cutlass/tile_iterator.h
Normal file
899
cutlass/tile_iterator.h
Normal file
@ -0,0 +1,899 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines the Tile Traits concept and iterators for loading and storing to tiles
|
||||
efficiently.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
#include <cutlass/load_store.h>
|
||||
#include <cutlass/predicate_vector.h>
|
||||
#include <cutlass/vector.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup tile_traits_concept Tile Traits Concept
|
||||
@{
|
||||
|
||||
\ref tile_traits_concept is a type definining the shape of a tile and the distribution of accesses
|
||||
by individual entities, either threads or other.
|
||||
|
||||
@par Tile Traits Concept
|
||||
Types satisfying \ref tile_traits_concept define the following members
|
||||
- <b>Tile</b> - a type satisfying \ref layout_concept describing the dimensions of the tile
|
||||
- <b>Delta</b> - a type satisfying \ref layout_concept describing the increments between accesses
|
||||
along each dimension
|
||||
- <b>Iterations</b> - a type satisfying \ref layout_concept describing the number of accesses
|
||||
along each dimension
|
||||
- <b>Offset</b> - the type of a <i>functor</i> computing the offset of each participating entity
|
||||
as a Coord<4>.
|
||||
@}
|
||||
*/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specifies dimension in which post-increment accesses advance
|
||||
struct IteratorAdvance {
|
||||
enum Kind { kD, kH, kW };
|
||||
};
|
||||
|
||||
/// Specifies whether iterator storage fragment consists of Scalar values or WMMA matrix
|
||||
struct IteratorFragment {
|
||||
enum Kind { kScalar, kWmmaMatrix };
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief A template defining \ref tile_traits_concept
|
||||
* @concept{tile_traits_concept}
|
||||
*/
|
||||
template <typename Tile_,
|
||||
typename Delta_,
|
||||
typename Iterations_,
|
||||
typename ThreadOffset_,
|
||||
int kAccessSize>
|
||||
struct TileTraits {
|
||||
/// Shape of the tile
|
||||
typedef Tile_ Tile;
|
||||
|
||||
/// Number of steps between accesses along each dimension
|
||||
typedef Delta_ Delta;
|
||||
|
||||
/// Number of accesses performed
|
||||
typedef Iterations_ Iterations;
|
||||
|
||||
/// Functor that returns the logical coordinate of each entity's initial offset in the tile
|
||||
typedef ThreadOffset_ ThreadOffset;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Iterator for accessing a stripmined tile in memory
|
||||
template <typename Traits_,
|
||||
typename Scalar_,
|
||||
IteratorAdvance::Kind Advance_ = IteratorAdvance::kH,
|
||||
MemorySpace::Kind MemorySpace = MemorySpace::kGeneric,
|
||||
typename Index_ = int,
|
||||
typename FragmentElement_ = Scalar_,
|
||||
IteratorFragment::Kind IteratorFragment_ = IteratorFragment::kScalar,
|
||||
typename Skew_ = Shape<0, 0, 0, 0> >
|
||||
struct TileIteratorBase {
|
||||
/// concept TileTraits
|
||||
typedef Traits_ Traits;
|
||||
|
||||
/// Scalar element
|
||||
typedef Scalar_ Scalar;
|
||||
|
||||
/// Fragment element
|
||||
typedef FragmentElement_ FragmentElement;
|
||||
|
||||
/// Specifies dimension in which post-increment accesses advance.
|
||||
static IteratorAdvance::Kind const kAdvance = Advance_;
|
||||
|
||||
/// Specifies iterator storage fragment type (Scalar or WmmaMatrix)
|
||||
static IteratorFragment::Kind const kIteratorFragment = IteratorFragment_;
|
||||
|
||||
/// Source or destination memory space
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace;
|
||||
|
||||
/// Index type
|
||||
typedef Index_ Index;
|
||||
|
||||
/// Skew quantity
|
||||
typedef Skew_ Skew;
|
||||
|
||||
/// Tile shape
|
||||
typedef typename Traits::Tile Tile;
|
||||
|
||||
/// Distance along each dimension
|
||||
typedef typename Traits::Delta Delta;
|
||||
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef typename Traits::ImmediateOffsetStrides ImmediateOffsetStrides;
|
||||
|
||||
/// Iterations
|
||||
typedef typename Traits::Iterations Iterations;
|
||||
|
||||
/// Thread offset
|
||||
typedef typename Traits::ThreadOffset ThreadOffset;
|
||||
|
||||
/// The number of scalars accessed per load/store.
|
||||
static int const kAccessSize = Tile::kC;
|
||||
|
||||
/// The elements loaded/store by one instruction.
|
||||
typedef typename Vectorize<FragmentElement, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The size of storage needed per fragment
|
||||
static int const kFragmentSize =
|
||||
(kIteratorFragment == IteratorFragment::kWmmaMatrix ? 16 : sizeof(AccessType));
|
||||
/// The storage.
|
||||
typedef Fragment<Scalar, ShapeCount<Tile>::kCount, kFragmentSize> Storage;
|
||||
/// The fragment.
|
||||
typedef Fragment<FragmentElement, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
|
||||
/// The fragment iterator.
|
||||
typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
|
||||
/// The fragment const iterator.
|
||||
typedef FragmentConstIterator<Fragment, Iterations, AccessType> FragmentConstIterator;
|
||||
/// The shape of the fragment.
|
||||
typedef typename FragmentIterator::FragmentShape FragmentShape;
|
||||
|
||||
/// Default predicate mask type
|
||||
typedef PredicateVector<ShapeCount<Iterations>::kCount> PredicateVector;
|
||||
|
||||
//
|
||||
// Params struct
|
||||
//
|
||||
|
||||
/// Parameters to the iterator
|
||||
struct Params {
|
||||
Index stride_d;
|
||||
Index stride_h;
|
||||
Index stride_w;
|
||||
|
||||
Index inc_d;
|
||||
Index inc_h;
|
||||
Index inc_w;
|
||||
|
||||
Index inc_advance;
|
||||
|
||||
/// Initializes params
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Index _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
Index _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
Index _inc_advance) {
|
||||
stride_d = _stride_d;
|
||||
stride_h = _stride_h;
|
||||
stride_w = _stride_w;
|
||||
|
||||
inc_d = _inc_d;
|
||||
inc_h = _inc_h;
|
||||
inc_w = _inc_w;
|
||||
inc_advance = _inc_advance;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Index _stride_d, Index _stride_h, Index _stride_w) {
|
||||
stride_d = _stride_d;
|
||||
stride_h = _stride_h;
|
||||
stride_w = _stride_w;
|
||||
|
||||
inc_w = stride_w * Delta::kW;
|
||||
inc_h = stride_h * Delta::kH - stride_w * Delta::kW * (Iterations::kW - 1);
|
||||
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
// Advance in the H dimension.
|
||||
inc_d = 0;
|
||||
} else if (kAdvance == IteratorAdvance::kW) {
|
||||
// Advance in the W dimension.
|
||||
inc_d = stride_w * Tile::kW - stride_h * Tile::kH;
|
||||
} else {
|
||||
// Advance in the D dimension.
|
||||
inc_d = stride_d;
|
||||
}
|
||||
|
||||
inc_advance = 0;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE int initialize() {
|
||||
stride_d = 0;
|
||||
stride_h = 0;
|
||||
stride_w = 1;
|
||||
|
||||
inc_d = inc_h = inc_w = inc_advance = 0;
|
||||
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
/// Is the iterator valid?
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
|
||||
|
||||
//
|
||||
// Static function members
|
||||
//
|
||||
|
||||
/// Initializes a predicate vector
|
||||
template <typename PredicateIterator>
|
||||
CUTLASS_DEVICE static void initialize_predicates(PredicateIterator predicate_it,
|
||||
Coord<3> const &bounds,
|
||||
Coord<3> const &offset = make_Coord(0, 0, 0)) {
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
bool enable_d = (d * Delta::kD + offset[0] < bounds[0]);
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
bool enable_h = (h * Delta::kH + offset[1] < bounds[1]);
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
bool enable_w = (w * Tile::kC * Delta::kW + offset[2] < bounds[2]);
|
||||
predicate_it.set(d, h, w, 0, enable_d && enable_h && enable_w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup tile_load_iterator_concept Tile Load Iterator Concept
|
||||
@{
|
||||
|
||||
\ref tile_load_iterator_concept enables loading a tile from addressable memory into a fragment
|
||||
|
||||
@par Tile Load Iterator Concept
|
||||
Types satisfying \ref tile_load_iterator_concept define the following members
|
||||
- <b>PredicateVector</b> - a \ref predicate_vector_concept with sufficient predicate storage for
|
||||
each access implied by the tile traits
|
||||
- <b>Fragment</b> - the destination fragment type satisfying \ref fragment_concept
|
||||
- <b>initialize_predicates(pred_it, bounds, block_offset)</b> - function initializing a predicate
|
||||
vector according to externally specified bounds
|
||||
- <b>load_post_increment(fragment, pred_it)</b> - a method that loads a fragment and increments
|
||||
the iterator to the next tile, guarded by a \ref predicate_iterator_concept
|
||||
- <b>load_post_increment(fragment)</b> - a method that loads a fragment and increments the
|
||||
iterator to the next tile
|
||||
- <b>load(fragment, pred_it)</b> - a const method that loads a fragment, guarded by a \ref
|
||||
predicate_iterator_concept
|
||||
- <b>load(fragment)</b> - a method that loads a fragment
|
||||
|
||||
@}
|
||||
*/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief An iterator implementing \ref tile_load_iterator_concept for loading a tile from memory
|
||||
* @concept{tile_load_iterator_concept}
|
||||
*/
|
||||
template <typename Traits_,
|
||||
typename Scalar_,
|
||||
IteratorAdvance::Kind Advance_ = IteratorAdvance::kH,
|
||||
MemorySpace::Kind MemorySpace = MemorySpace::kGeneric,
|
||||
typename Index_ = int,
|
||||
typename FragmentElement_ = Scalar_,
|
||||
IteratorFragment::Kind IteratorFragment_ = IteratorFragment::kScalar,
|
||||
typename Skew_ = Shape<0, 0, 0, 0> >
|
||||
struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
Scalar_,
|
||||
Advance_,
|
||||
MemorySpace,
|
||||
Index_,
|
||||
FragmentElement_,
|
||||
IteratorFragment_,
|
||||
Skew_> {
|
||||
/// Base class
|
||||
typedef TileIteratorBase<Traits_,
|
||||
Scalar_,
|
||||
Advance_,
|
||||
MemorySpace,
|
||||
Index_,
|
||||
FragmentElement_,
|
||||
IteratorFragment_,
|
||||
Skew_>
|
||||
Base;
|
||||
|
||||
/// concept TileTraits
|
||||
typedef typename Base::Traits Traits;
|
||||
|
||||
/// Scalar element
|
||||
typedef typename Base::Scalar Scalar;
|
||||
|
||||
/// Fragment element
|
||||
typedef typename Base::FragmentElement FragmentElement;
|
||||
|
||||
/// Specifies in which dimension post-increment accesses advance.
|
||||
static IteratorAdvance::Kind const kAdvance = Base::kAdvance;
|
||||
|
||||
/// Specifies type of iterator fragment storage (Salar or WmmaMatrix)
|
||||
static IteratorFragment::Kind const kIteratorFragment = Base::kIteratorFragment;
|
||||
|
||||
/// Source or destination memory space
|
||||
static MemorySpace::Kind const kMemorySpace = Base::kMemorySpace;
|
||||
|
||||
/// Index type
|
||||
typedef typename Base::Index Index;
|
||||
|
||||
/// Skew quantity
|
||||
typedef typename Base::Skew Skew;
|
||||
|
||||
/// Tile shape
|
||||
typedef typename Base::Tile Tile;
|
||||
|
||||
/// Delta
|
||||
typedef typename Base::Delta Delta;
|
||||
|
||||
/// Iterations
|
||||
typedef typename Base::Iterations Iterations;
|
||||
|
||||
/// ThreadOffset functor
|
||||
typedef typename Base::ThreadOffset ThreadOffset;
|
||||
|
||||
/// Fragment type
|
||||
typedef typename Base::FragmentShape FragmentShape;
|
||||
|
||||
/// Memory access type
|
||||
typedef typename Base::AccessType AccessType;
|
||||
|
||||
/// Fragment definition
|
||||
typedef typename Base::Fragment Fragment;
|
||||
|
||||
/// Fragment iterator definition
|
||||
typedef typename Base::FragmentIterator FragmentIterator;
|
||||
|
||||
/// Fragment const iterator definition
|
||||
typedef typename Base::FragmentConstIterator FragmentConstIterator;
|
||||
|
||||
/// Default predicate mask type
|
||||
typedef typename Base::PredicateVector PredicateVector;
|
||||
|
||||
/// Storage object that may be loaded from
|
||||
typedef typename Base::Storage SharedStorage;
|
||||
|
||||
/// IteratorBase parameters
|
||||
typedef typename Base::Params BaseParams;
|
||||
|
||||
/// Do we require a fence?
|
||||
enum { kRequiresLoadFence = Tile::kD == 1 };
|
||||
|
||||
/// The pointer type
|
||||
typedef Scalar const *Pointer;
|
||||
|
||||
/// Parameters
|
||||
struct Params : public BaseParams {
|
||||
/// Pointer to memory
|
||||
Scalar const *pointer;
|
||||
|
||||
/// Initialize params to access storage object
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(SharedStorage const &storage) {
|
||||
pointer = &storage[0];
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initializes params to access a raw pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar const *ptr, Index stride_d, Index stride_h, Index stride_w) {
|
||||
Base::Params::initialize(stride_d, stride_h, stride_w);
|
||||
pointer = ptr;
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initializes params
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar const *ptr,
|
||||
Index _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
Index _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
Index _inc_advance) {
|
||||
pointer = ptr;
|
||||
Base::Params::initialize(
|
||||
_stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Initializes params to default values
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize() { return Base::Params::initialize(); }
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
Params params;
|
||||
|
||||
/// Offset of an individual lane from the start of the tile
|
||||
Coord<4> thread_offset;
|
||||
|
||||
/// Stage argument enables wrapping after some number of tiles have been loaded.
|
||||
int stage;
|
||||
|
||||
//
|
||||
// Static member functions
|
||||
//
|
||||
|
||||
/// Initializes a predicate vector
|
||||
template <typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
|
||||
Coord<3> const &bounds,
|
||||
Coord<3> const &block_offset = make_Coord(0,
|
||||
0,
|
||||
0)) {
|
||||
Base::initialize_predicates(
|
||||
predicate_it,
|
||||
bounds,
|
||||
block_offset + make_Coord(0, thread_offset[1], thread_offset[2] * Tile::kC));
|
||||
}
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileLoadIterator() {}
|
||||
|
||||
/// Constructs a tile load iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileLoadIterator(Params const &_params,
|
||||
Coord<3> const &block_offset = make_Coord(0, 0, 0),
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(_params), stage(0) {
|
||||
thread_offset = thread_offset_func();
|
||||
|
||||
Index block_offset_h = 0;
|
||||
Index block_offset_w = 0;
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
block_offset_h = block_offset[1];
|
||||
block_offset_w = block_offset[2];
|
||||
} else {
|
||||
block_offset_h = block_offset[2];
|
||||
block_offset_w = block_offset[1];
|
||||
}
|
||||
|
||||
params.pointer += block_offset[0] * params.stride_d +
|
||||
(block_offset_h + thread_offset[1]) * params.stride_h +
|
||||
(block_offset_w + thread_offset[2] * Tile::kC) / Tile::kC * params.stride_w;
|
||||
}
|
||||
|
||||
/// Constructs a tile load iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileLoadIterator(Params const &,
|
||||
SharedStorage &shared_storage,
|
||||
Coord<3> const &block_offset = make_Coord(0, 0, 0),
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: stage(0) {
|
||||
int const offset = thread_offset_func()[2];
|
||||
params.pointer = &shared_storage[offset];
|
||||
}
|
||||
|
||||
/// Returns the current pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
Scalar const *data() const { return params.pointer; }
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(AccessType &value, int d, int h, int w, int c) const {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
|
||||
Load<Scalar, Base::kAccessSize, kMemorySpace>::load(value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Increment in the D dimension
|
||||
CUTLASS_HOST_DEVICE void inc_d() { params.pointer += params.inc_d; }
|
||||
|
||||
/// Increment in the H dimension
|
||||
CUTLASS_HOST_DEVICE void inc_h() { params.pointer += params.inc_h; }
|
||||
|
||||
/// Increment in the W dimension
|
||||
CUTLASS_HOST_DEVICE void inc_w() { params.pointer += params.inc_w; }
|
||||
|
||||
/// Increment in the next dimension
|
||||
CUTLASS_HOST_DEVICE void inc_advance() { params.pointer += params.inc_advance; }
|
||||
|
||||
/// Increment the stage.
|
||||
CUTLASS_DEVICE void inc_stage() {
|
||||
if (Tile::kD > 1) {
|
||||
int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
|
||||
if (stage == Tile::kD - 1) {
|
||||
params.pointer -= (Tile::kD - 1) * kStageSize;
|
||||
stage = 0;
|
||||
} else {
|
||||
params.pointer += kStageSize;
|
||||
stage = stage + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
/// Loads a fragment and advances the iterator to the next tile.
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
|
||||
FragmentIterator frag_iterator(fragment);
|
||||
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
|
||||
if (*pred_it) {
|
||||
Load<typename Fragment::Element, Tile::kC, kMemorySpace>::load(
|
||||
reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)), data(), 0);
|
||||
}
|
||||
|
||||
if (w < Iterations::kW - 1) {
|
||||
inc_w();
|
||||
}
|
||||
}
|
||||
if (h < Iterations::kH - 1) {
|
||||
inc_h();
|
||||
}
|
||||
}
|
||||
if (d < Iterations::kD - 1) {
|
||||
inc_d();
|
||||
}
|
||||
}
|
||||
inc_advance();
|
||||
}
|
||||
|
||||
/// Loads a fragment and advances the iterator to the next tile.
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment) {
|
||||
typename PredicateVector::TrivialIterator pred_it;
|
||||
load_post_increment(fragment, pred_it);
|
||||
}
|
||||
|
||||
/// Loads a fragment without advancing the iterator..
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const {
|
||||
TileLoadIterator _load_it(*this);
|
||||
_load_it.load_post_increment(fragment, pred_it);
|
||||
}
|
||||
|
||||
/// Loads a fragment without advancing the iterator..
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void load(Fragment &fragment) const {
|
||||
typename PredicateVector::TrivialIterator pred_it;
|
||||
load(fragment, pred_it);
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup tile_store_iterator_concept Tile Store Iterator Concept
|
||||
@{
|
||||
|
||||
\ref tile_store_iterator_concept enables storing a tile to addressable memory
|
||||
|
||||
@par Tile Store Iterator Concept
|
||||
Types satisfying \ref tile_load_iterator_concept define the following members
|
||||
- <b>PredicateVector</b> - a \ref predicate_vector_concept with sufficient predicate storage for
|
||||
each access implied by the tile traits
|
||||
- <b>Fragment</b> - the destination fragment type satisfying \ref fragment_concept
|
||||
- <b>initialize_predicates(pred_it, bounds, block_offset)</b> - function initializing a predicate
|
||||
vector according to externally specified bounds
|
||||
- <b>store_post_increment(fragment, pred_it)</b> - a method that stores a fragment and increments
|
||||
the iterator to the next tile, guarded by a \ref predicate_iterator_concept
|
||||
- <b>store_post_increment(fragment)</b> - a method that stores a fragment and increments the
|
||||
iterator to the next tile
|
||||
- <b>store(fragment, pred_it)</b> - a const method that stores a fragment, guarded by a \ref
|
||||
predicate_iterator_concept
|
||||
- <b>store(fragment)</b> - a method that loads a fragment
|
||||
|
||||
@}
|
||||
*/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief An iterator implementing \ref tile_store_iterator_concept for storing a tile to memory
|
||||
* @concept{tile_store_iterator_concept}
|
||||
*/
|
||||
template <typename Traits_,
|
||||
typename Scalar_,
|
||||
IteratorAdvance::Kind Advance_ = IteratorAdvance::kH,
|
||||
MemorySpace::Kind MemorySpace = MemorySpace::kGeneric,
|
||||
typename Index_ = int,
|
||||
typename FragmentElement_ = Scalar_,
|
||||
IteratorFragment::Kind IteratorFragment_ = IteratorFragment::kScalar,
|
||||
typename Skew_ = Shape<0, 0, 0, 0> >
|
||||
struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
Scalar_,
|
||||
Advance_,
|
||||
MemorySpace,
|
||||
Index_,
|
||||
FragmentElement_,
|
||||
IteratorFragment_,
|
||||
Skew_> {
|
||||
/// Base class
|
||||
typedef TileIteratorBase<Traits_,
|
||||
Scalar_,
|
||||
Advance_,
|
||||
MemorySpace,
|
||||
Index_,
|
||||
FragmentElement_,
|
||||
IteratorFragment_,
|
||||
Skew_>
|
||||
Base;
|
||||
|
||||
/// concept TileTraits
|
||||
typedef typename Base::Traits Traits;
|
||||
|
||||
/// Scalar element
|
||||
typedef typename Base::Scalar Scalar;
|
||||
|
||||
/// Fragment element
|
||||
typedef typename Base::FragmentElement FragmentElement;
|
||||
|
||||
/// Specifies in which dimension post-increment accesses advance.
|
||||
static IteratorAdvance::Kind const kAdvance = Base::kAdvance;
|
||||
|
||||
/// Specifies type of iterator fragment storage (Salar or WmmaMatrix)
|
||||
static IteratorFragment::Kind const kIteratorFragment = Base::kIteratorFragment;
|
||||
|
||||
/// Source or destination memory space
|
||||
static MemorySpace::Kind const kMemorySpace = Base::kMemorySpace;
|
||||
|
||||
/// Index type
|
||||
typedef typename Base::Index Index;
|
||||
|
||||
/// Skew quantity
|
||||
typedef typename Base::Skew Skew;
|
||||
|
||||
/// Tile shape
|
||||
typedef typename Base::Tile Tile;
|
||||
|
||||
/// Delta
|
||||
typedef typename Base::Delta Delta;
|
||||
|
||||
/// Iterations
|
||||
typedef typename Base::Iterations Iterations;
|
||||
|
||||
/// ThreadOffset functor
|
||||
typedef typename Base::ThreadOffset ThreadOffset;
|
||||
|
||||
/// Fragment type
|
||||
typedef typename Base::FragmentShape FragmentShape;
|
||||
|
||||
/// Memory access type
|
||||
typedef typename Base::AccessType AccessType;
|
||||
|
||||
/// Fragment definition
|
||||
typedef typename Base::Fragment Fragment;
|
||||
|
||||
/// Fragment iterator definition
|
||||
typedef typename Base::FragmentIterator FragmentIterator;
|
||||
|
||||
/// Fragment const iterator definition
|
||||
typedef typename Base::FragmentConstIterator FragmentConstIterator;
|
||||
|
||||
/// Default predicate mask type
|
||||
typedef typename Base::PredicateVector PredicateVector;
|
||||
|
||||
/// Storage object which may be stored to
|
||||
typedef typename Base::Storage SharedStorage;
|
||||
|
||||
/// IteratorBase parameters
|
||||
typedef typename Base::Params BaseParams;
|
||||
|
||||
/// Parameters
|
||||
struct Params : public BaseParams {
|
||||
/// Pointer to memory
|
||||
Scalar *pointer;
|
||||
|
||||
/// Initialize params to access storage object
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(SharedStorage &storage) {
|
||||
pointer = &storage[0];
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initializes params to access a raw pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar *ptr, Index stride_d, Index stride_h, Index stride_w) {
|
||||
Base::Params::initialize(stride_d, stride_h, stride_w);
|
||||
pointer = ptr;
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initializes params
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar *ptr,
|
||||
Index _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
Index _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
Index _inc_advance) {
|
||||
pointer = ptr;
|
||||
Base::Params::initialize(
|
||||
_stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initializes params to default values
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize() { return Base::Params::initialize(); }
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters structure
|
||||
Params params;
|
||||
|
||||
/// Offset of an individual lane from the start of the tile
|
||||
Coord<4> thread_offset;
|
||||
|
||||
/// The stage.
|
||||
int stage;
|
||||
|
||||
//
|
||||
// Static member functions
|
||||
//
|
||||
|
||||
/// Initializes a predicate vector
|
||||
template <typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
|
||||
Coord<3> const &bounds,
|
||||
Coord<3> const &block_offset = make_Coord(0,
|
||||
0,
|
||||
0)) {
|
||||
Base::initialize_predicates(
|
||||
predicate_it,
|
||||
bounds,
|
||||
block_offset + make_Coord(0, thread_offset[1], thread_offset[2] * Tile::kC));
|
||||
}
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileStoreIterator() {}
|
||||
|
||||
/// Constructs a tile store iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileStoreIterator(Params const &_params,
|
||||
Coord<3> const &block_offset = make_Coord(0, 0, 0),
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(_params), stage(0) {
|
||||
thread_offset = thread_offset_func();
|
||||
|
||||
params.pointer += block_offset[0] * params.stride_d +
|
||||
(block_offset[1] + thread_offset[1]) * params.stride_h +
|
||||
(block_offset[2] + thread_offset[2] * Tile::kC) / Tile::kC * params.stride_w;
|
||||
}
|
||||
|
||||
/// Constructs a tile store iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileStoreIterator(Params const &,
|
||||
SharedStorage &shared_storage,
|
||||
Coord<3> const &block_offset = make_Coord(0, 0, 0),
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: stage(0) {
|
||||
int const offset = thread_offset_func()[2];
|
||||
params.pointer = &shared_storage[offset];
|
||||
}
|
||||
|
||||
/// Returns the current pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
Scalar *data() const { return params.pointer; }
|
||||
|
||||
/// Increment in the D dimension
|
||||
CUTLASS_HOST_DEVICE void inc_d() { params.pointer += params.inc_d; }
|
||||
|
||||
/// Increment in the H dimension
|
||||
CUTLASS_HOST_DEVICE void inc_h() { params.pointer += params.inc_h; }
|
||||
|
||||
/// Increment in the W dimension
|
||||
CUTLASS_HOST_DEVICE void inc_w() { params.pointer += params.inc_w; }
|
||||
|
||||
/// Increment in the next dimension
|
||||
CUTLASS_HOST_DEVICE void inc_advance() {}
|
||||
|
||||
/// Increment the stage.
|
||||
CUTLASS_DEVICE void inc_stage() {
|
||||
if (Tile::kD > 1) {
|
||||
int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
|
||||
if (stage == Tile::kD - 1) {
|
||||
params.pointer -= (Tile::kD - 1) * kStageSize;
|
||||
stage = 0;
|
||||
} else {
|
||||
params.pointer += kStageSize;
|
||||
stage = stage + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void set(AccessType const &value, int d, int h, int w, int c) {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
|
||||
Store<Scalar, Base::kAccessSize, kMemorySpace>::store(value, params.pointer, imm);
|
||||
}
|
||||
|
||||
public:
|
||||
/// Stores a fragment and advances to the next tile.
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment, PredicateIterator pred_it) {
|
||||
FragmentIterator frag_iterator(fragment);
|
||||
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
|
||||
if (*pred_it) {
|
||||
Store<typename Fragment::Element, Tile::kC, kMemorySpace>::store(
|
||||
reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)), data(), 0);
|
||||
}
|
||||
if (w < Iterations::kW - 1) {
|
||||
inc_w();
|
||||
}
|
||||
}
|
||||
if (h < Iterations::kH - 1) {
|
||||
inc_h();
|
||||
}
|
||||
}
|
||||
if (d < Iterations::kD - 1) {
|
||||
inc_d();
|
||||
}
|
||||
}
|
||||
inc_advance();
|
||||
}
|
||||
|
||||
/// Stores a fragment and advances to the next tile.
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment) {
|
||||
typename PredicateVector::TrivialIterator pred_it;
|
||||
store_post_increment(fragment, pred_it);
|
||||
}
|
||||
|
||||
/// Stores a fragment without advancing the iterator.
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void store(Fragment &fragment, PredicateIterator pred_it) const {
|
||||
TileStoreIterator _store_it(*this);
|
||||
_store_it.store_post_increment(fragment, pred_it);
|
||||
}
|
||||
|
||||
/// Stores a fragment without advancing the iterator.
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void store(Fragment &fragment) const {
|
||||
typename PredicateVector::TrivialIterator pred_it;
|
||||
store(fragment, pred_it);
|
||||
}
|
||||
};
|
||||
}
|
||||
238
cutlass/tile_traits_standard.h
Normal file
238
cutlass/tile_traits_standard.h
Normal file
@ -0,0 +1,238 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines tile traits for several tile partitioning arrangements of threads expected to
|
||||
achieve efficient streaming performance.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/tile_iterator.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Basic thread offset function computed from a thread shape
|
||||
template <typename ThreadShape>
|
||||
struct TiledThreadOffset {
|
||||
/// Computes the logical coordinate from thread shape
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
Coord<4> thread_offset;
|
||||
|
||||
int index = threadIdx.x;
|
||||
|
||||
thread_offset[3] = (index % ThreadShape::kC);
|
||||
index = (index / ThreadShape::kC);
|
||||
|
||||
thread_offset[2] = (index % ThreadShape::kW);
|
||||
index = (index / ThreadShape::kW);
|
||||
|
||||
thread_offset[1] = (index % ThreadShape::kH);
|
||||
index = (index / ThreadShape::kH);
|
||||
|
||||
thread_offset[0] = index;
|
||||
|
||||
return thread_offset;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tiling in which the number of threads is greater than the
|
||||
/// contiguous dimension of the tile.
|
||||
template <typename Tile_, int Threads>
|
||||
struct TileTraitsStrideMajor {
|
||||
/// Shape of tile
|
||||
typedef Tile_ Tile;
|
||||
|
||||
/// Number of participating threads
|
||||
static int const kThreads = Threads;
|
||||
|
||||
// Static assertions
|
||||
static_assert(!(ShapeCount<Tile>::kDhw % kThreads),
|
||||
"Tiling undefined if elements not divisible by threads.");
|
||||
|
||||
static_assert(Tile::kW <= kThreads,
|
||||
"This specialization assumes there are more threads than the contiguous dimension "
|
||||
"of the tile.");
|
||||
|
||||
/// Shape of threads
|
||||
typedef Shape<1, kThreads / Tile::kW, Tile::kW, 1> ThreadShape;
|
||||
|
||||
/// Delta along each dimension
|
||||
typedef Shape<1, ThreadShape::kH, 1, 1> Delta;
|
||||
|
||||
/// Number of iterations
|
||||
typedef Shape<1, Tile::kH / ThreadShape::kH, 1, 1> Iterations;
|
||||
|
||||
/// Computes the initial offset
|
||||
typedef TiledThreadOffset<ThreadShape> ThreadOffset;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tiling in which the number of threads is fewer than the tile size
|
||||
/// in the contiguous dimension.
|
||||
template <typename Tile_, int Threads>
|
||||
struct TileTraitsContiguousMajor {
|
||||
/// Shape of tile
|
||||
typedef Tile_ Tile;
|
||||
|
||||
/// Number of participating threads
|
||||
static int const kThreads = Threads;
|
||||
|
||||
// Static assertions
|
||||
static_assert(Tile::kW >= kThreads,
|
||||
"This specialization assumes there are more threads than the contiguous dimension "
|
||||
"of the tile.");
|
||||
|
||||
static_assert(!(ShapeCount<Tile>::kDhw % kThreads),
|
||||
"Tiling undefined if elements not divisible by threads.");
|
||||
|
||||
static_assert(!(Tile::kW % kThreads),
|
||||
"The contiguous size of the tile must be divisible by the number of threads.");
|
||||
|
||||
/// Thread shape
|
||||
typedef Shape<1, 1, kThreads> ThreadShape;
|
||||
|
||||
/// Delta between each thread's access
|
||||
typedef Shape<1, 1, kThreads> Delta;
|
||||
|
||||
/// Number of iterations
|
||||
typedef Shape<1, Tile::kH, Tile::kW / kThreads> Iterations;
|
||||
|
||||
/// Computes the initial offset
|
||||
typedef TiledThreadOffset<ThreadShape> ThreadOffset;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tiling in which warps rake across the contiguous dimension
|
||||
template <typename Tile_, int Threads>
|
||||
struct TileTraitsWarpRake {
|
||||
/// Shape of tile
|
||||
typedef Tile_ Tile;
|
||||
|
||||
/// Number of participating threads
|
||||
static int const kThreads = Threads;
|
||||
|
||||
/// Hard-coded warp size
|
||||
static int const kWarpSize = 32;
|
||||
|
||||
/// Number of participating warps
|
||||
static int const kWarpCount = kThreads / kWarpSize;
|
||||
|
||||
// Static assertions
|
||||
static_assert(!(ShapeCount<Tile>::kDhw % kThreads),
|
||||
"Tiling undefined if elements not divisible by threads.");
|
||||
|
||||
static_assert(!(kThreads % kWarpSize), "Number of threads must be divisible by the warp size.");
|
||||
|
||||
static_assert(!(Tile::kW % kWarpSize), "Contiguous dimension must be divisible by the warp size");
|
||||
|
||||
/// Warps strip-mined across strided dimension
|
||||
static int const kWarpsStrided = __NV_STD_MIN(kWarpCount, Tile::kH);
|
||||
|
||||
/// Warps stripmined contiguous dimension
|
||||
static int const kWarpsContiguous = kWarpCount / kWarpsStrided;
|
||||
|
||||
/// Arrangement of threads
|
||||
typedef Shape<1, kWarpsStrided, kWarpsContiguous * kWarpSize> ThreadShape;
|
||||
|
||||
/// The same warp rakes along the contiguous dimension
|
||||
typedef Shape<1, kWarpsStrided, kWarpSize> Delta;
|
||||
|
||||
/// Number of iterations
|
||||
typedef Shape<1, Tile::kH / Delta::kH, Tile::kW / ThreadShape::kW> Iterations;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
/// Basic thread offset function computed from a thread shape
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<4> operator()() const {
|
||||
int tid = threadIdx.x;
|
||||
int warp = (tid / kWarpSize);
|
||||
int lane = (tid % kWarpSize);
|
||||
|
||||
static int const kWarpSpanContiguous = kWarpSize * Iterations::kW;
|
||||
|
||||
int warp_w = (warp % kWarpsContiguous);
|
||||
int warp_h = (warp / kWarpsContiguous);
|
||||
|
||||
return make_Coord(0, warp_h, lane + kWarpSpanContiguous * warp_w, 0);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Chooses 'best' shape to enable warp raking along contiguous dimension if possible.
|
||||
template <typename Tile_, int Threads>
|
||||
struct TileTraitsStandard {
|
||||
/// Shape of tile
|
||||
typedef Tile_ Tile;
|
||||
|
||||
/// Number of participating threads
|
||||
static int const kThreads = Threads;
|
||||
|
||||
/// Hard-coded warp size
|
||||
static int const kWarpSize = 32;
|
||||
|
||||
/// Number of participating warps
|
||||
static int const kWarpCount = kThreads / kWarpSize;
|
||||
|
||||
// Static assertions
|
||||
static_assert(!(ShapeCount<Tile>::kDhw % kThreads),
|
||||
"Tiling undefined if elements not divisible by threads.");
|
||||
|
||||
/// Choose the stride-major contiguous tiling if the contiguous dimension is
|
||||
/// smaller than the warp size. Otherwise, if it is divisible by the warp size,
|
||||
/// choose the warp rake arrangement.
|
||||
typedef typename platform::conditional <
|
||||
Tile::kW<kWarpSize,
|
||||
TileTraitsStrideMajor<Tile, Threads>,
|
||||
typename platform::conditional<!(Tile::kW % kWarpSize),
|
||||
TileTraitsWarpRake<Tile, Threads>,
|
||||
TileTraitsContiguousMajor<Tile, Threads> >::type>::
|
||||
type Traits;
|
||||
|
||||
/// Delta between accesses
|
||||
typedef typename Traits::Delta Delta;
|
||||
|
||||
/// Delta between each thread's access
|
||||
/// TODO MTA this is wrong for sure, but Delta is used for stride computation at the moment
|
||||
typedef Delta ImmediateOffsetStrides;
|
||||
|
||||
/// Number of accesses
|
||||
typedef typename Traits::Iterations Iterations;
|
||||
|
||||
/// Thread offset functor
|
||||
typedef typename Traits::ThreadOffset ThreadOffset;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
131
cutlass/util/cutlass_math.h
Normal file
131
cutlass/util/cutlass_math.h
Normal file
@ -0,0 +1,131 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* \brief Math utilities
|
||||
*/
|
||||
|
||||
#include <cutlass/util/platform.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/******************************************************************************
|
||||
* Static math utilities
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Statically determine if N is a power-of-two
|
||||
*/
|
||||
template <int N>
|
||||
struct is_pow2 : platform::integral_constant<bool, (N & (N - 1)) == 0> {};
|
||||
|
||||
/**
|
||||
* Statically determine log2(N), rounded down
|
||||
*/
|
||||
template <int N, int CurrentVal = N, int Count = 0>
|
||||
struct log2_down {
|
||||
/// Static logarithm value
|
||||
enum { value = log2_down<N, (CurrentVal >> 1), Count + 1>::value };
|
||||
};
|
||||
|
||||
// Base case
|
||||
template <int N, int Count>
|
||||
struct log2_down<N, 1, Count> {
|
||||
enum { value = Count };
|
||||
};
|
||||
|
||||
/**
|
||||
* Statically determine log2(N), rounded up
|
||||
*/
|
||||
template <int N, int CurrentVal = N, int Count = 0>
|
||||
struct log2_up {
|
||||
/// Static logarithm value
|
||||
enum { value = log2_up<N, (CurrentVal >> 1), Count + 1>::value };
|
||||
};
|
||||
|
||||
// Base case
|
||||
template <int N, int Count>
|
||||
struct log2_up<N, 1, Count> {
|
||||
enum { value = ((1 << Count) < N) ? Count + 1 : Count };
|
||||
};
|
||||
|
||||
/**
|
||||
* Statically estimate sqrt(N) to the nearest power-of-two
|
||||
*/
|
||||
template <int N>
|
||||
struct sqrt_est {
|
||||
enum { value = 1 << (log2_up<N>::value / 2) };
|
||||
};
|
||||
|
||||
/**
|
||||
* For performing a constant-division with a compile-time assertion that the
|
||||
* Divisor evenly-divides the Dividend.
|
||||
*/
|
||||
template <int Dividend, int Divisor>
|
||||
struct divide_assert {
|
||||
enum { value = Dividend / Divisor };
|
||||
|
||||
static_assert((Dividend % Divisor == 0), "Not an even multiple");
|
||||
};
|
||||
|
||||
/******************************************************************************
|
||||
* Rounding
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Round dividend up to the nearest multiple of divisor
|
||||
*/
|
||||
template <typename dividend_t, typename divisor_t>
|
||||
CUTLASS_HOST_DEVICE dividend_t round_nearest(dividend_t dividend, divisor_t divisor) {
|
||||
return ((dividend + divisor - 1) / divisor) * divisor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Greatest common divisor
|
||||
*/
|
||||
template <typename value_t>
|
||||
CUTLASS_HOST_DEVICE value_t gcd(value_t a, value_t b) {
|
||||
for (;;) {
|
||||
if (a == 0) return b;
|
||||
b %= a;
|
||||
if (b == 0) return a;
|
||||
a %= b;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Least common multiple
|
||||
*/
|
||||
template <typename value_t>
|
||||
CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b) {
|
||||
value_t temp = gcd(a, b);
|
||||
|
||||
return temp ? (a / temp * b) : 0;
|
||||
}
|
||||
|
||||
} // namespace cutlass
|
||||
@ -1,29 +1,27 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
@ -44,87 +42,81 @@ namespace cutlass {
|
||||
* Formats and prints the given message to stdout
|
||||
*/
|
||||
#if !defined(CUDA_LOG)
|
||||
#if !defined(__CUDA_ARCH__)
|
||||
#define CUDA_LOG(format, ...) printf(format, __VA_ARGS__)
|
||||
#else
|
||||
inline __host__ __device__ unsigned get_threadidx_x() { return threadIdx.x; }
|
||||
inline __host__ __device__ unsigned get_threadidx_y() { return threadIdx.y; }
|
||||
inline __host__ __device__ unsigned get_threadidx_z() { return threadIdx.z; }
|
||||
inline __host__ __device__ unsigned get_blockidx_x() { return blockIdx.x; }
|
||||
inline __host__ __device__ unsigned get_blockidx_y() { return blockIdx.y; }
|
||||
inline __host__ __device__ unsigned get_blockidx_z() { return blockIdx.z; }
|
||||
#define CUDA_LOG(format, ...) \
|
||||
printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \
|
||||
get_blockidx_x(), get_blockidx_y(), get_blockidx_z(), \
|
||||
get_threadidx_x(), get_threadidx_y(), get_threadidx_z(), \
|
||||
__VA_ARGS__);
|
||||
#endif
|
||||
#if !defined(__CUDA_ARCH__)
|
||||
#define CUDA_LOG(format, ...) printf(format, __VA_ARGS__)
|
||||
#else
|
||||
#define CUDA_LOG(format, ...) \
|
||||
printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \
|
||||
blockIdx.x, \
|
||||
blockIdx.y, \
|
||||
blockIdx.z, \
|
||||
threadIdx.x, \
|
||||
threadIdx.y, \
|
||||
threadIdx.z, \
|
||||
__VA_ARGS__);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
||||
/**
|
||||
* Formats and prints the given message to stdout only if DEBUG is defined
|
||||
*/
|
||||
#if !defined(CUDA_LOG_DEBUG)
|
||||
#ifdef DEBUG
|
||||
#define CUDA_LOG_DEBUG(format, ...) CUDA_LOG(format, __VA_ARGS__)
|
||||
#else
|
||||
#define CUDA_LOG_DEBUG(format, ...)
|
||||
#endif
|
||||
#ifdef DEBUG
|
||||
#define CUDA_LOG_DEBUG(format, ...) CUDA_LOG(format, __VA_ARGS__)
|
||||
#else
|
||||
#define CUDA_LOG_DEBUG(format, ...)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
||||
/**
|
||||
* \brief The corresponding error message is printed to \p stderr (or \p stdout in device code) along with the supplied source context.
|
||||
* \brief The corresponding error message is printed to \p stderr (or \p stdout in device code)
|
||||
* along with the supplied source context.
|
||||
*
|
||||
* \return The CUDA error.
|
||||
*/
|
||||
__host__ __device__ inline cudaError_t cuda_perror_impl(
|
||||
cudaError_t error,
|
||||
const char* filename,
|
||||
int line)
|
||||
{
|
||||
(void)filename;
|
||||
(void)line;
|
||||
if (error)
|
||||
{
|
||||
__host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl(cudaError_t error,
|
||||
const char* filename,
|
||||
int line) {
|
||||
(void)filename;
|
||||
(void)line;
|
||||
if (error) {
|
||||
#if !defined(__CUDA_ARCH__)
|
||||
fprintf(stderr, "CUDA error %d [%s, %d]: %s\n", error, filename, line, cudaGetErrorString(error));
|
||||
fflush(stderr);
|
||||
fprintf(
|
||||
stderr, "CUDA error %d [%s, %d]: %s\n", error, filename, line, cudaGetErrorString(error));
|
||||
fflush(stderr);
|
||||
#else
|
||||
printf("CUDA error %d [%s, %d]\n", error, filename, line);
|
||||
printf("CUDA error %d [%s, %d]\n", error, filename, line);
|
||||
#endif
|
||||
}
|
||||
return error;
|
||||
}
|
||||
return error;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* \brief Perror macro
|
||||
*/
|
||||
#ifndef CUDA_PERROR
|
||||
#define CUDA_PERROR(e) cuda_perror_impl((cudaError_t) (e), __FILE__, __LINE__)
|
||||
#define CUDA_PERROR(e) cuda_perror_impl((cudaError_t)(e), __FILE__, __LINE__)
|
||||
#endif
|
||||
|
||||
|
||||
/**
|
||||
* \brief Perror macro with exit
|
||||
*/
|
||||
#ifndef CUDA_PERROR_EXIT
|
||||
#define CUDA_PERROR_EXIT(e) if (cuda_perror_impl((cudaError_t) (e), __FILE__, __LINE__)) { exit(1); }
|
||||
#define CUDA_PERROR_EXIT(e) \
|
||||
if (cuda_perror_impl((cudaError_t)(e), __FILE__, __LINE__)) { \
|
||||
exit(1); \
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
/**
|
||||
* \brief Perror macro only if DEBUG is defined
|
||||
*/
|
||||
#ifndef CUDA_PERROR_DEBUG
|
||||
#ifdef DEBUG
|
||||
#define CUDA_PERROR_DEBUG(e) CUDA_PERROR(e)
|
||||
#else
|
||||
#define CUDA_PERROR_DEBUG(e) (e)
|
||||
#endif
|
||||
#ifdef DEBUG
|
||||
#define CUDA_PERROR_DEBUG(e) CUDA_PERROR(e)
|
||||
#else
|
||||
#define CUDA_PERROR_DEBUG(e) (e)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
||||
} // namespace cutlass
|
||||
} // namespace cutlass
|
||||
|
||||
@ -1,224 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* \brief Utilities for device introspection
|
||||
*/
|
||||
|
||||
#include "debug.h"
|
||||
#include "nv_std.h"
|
||||
#include "printable.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* math_operation_class_t
|
||||
*
|
||||
* Enumeration to select the appropriate math operation
|
||||
*
|
||||
* The assumption is multiple math operations may be used to compute GEMM
|
||||
* for a given selection of operand and accumulator types.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
/// Math operation
|
||||
enum class math_operation_class_t
|
||||
{
|
||||
scalar, // scalar (and vector) multiply-accumulate operations
|
||||
matrix // Volta tensor operations
|
||||
};
|
||||
|
||||
/******************************************************************************
|
||||
* arch_family_t
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief Enumeration of NVIDIA GPU architectural families
|
||||
*/
|
||||
struct arch_family_t
|
||||
{
|
||||
/// \brief Enumerants
|
||||
enum kind_t
|
||||
{
|
||||
Unsupported = 0,
|
||||
Kepler = 3,
|
||||
Maxwell = 5,
|
||||
Volta = 7,
|
||||
};
|
||||
|
||||
/// Enumerant value
|
||||
kind_t kind;
|
||||
|
||||
/// Default constructor
|
||||
arch_family_t() : kind(Unsupported) {}
|
||||
|
||||
/// Copy constructor
|
||||
arch_family_t(const kind_t &other_kind) : kind(other_kind) {}
|
||||
|
||||
/// Cast to kind_t
|
||||
operator kind_t() const { return kind; }
|
||||
|
||||
/// Returns the instance as a string
|
||||
__host__ __device__ inline
|
||||
char const* to_string() const
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case Kepler: return "Kepler";
|
||||
case Maxwell: return "Maxwell";
|
||||
case Volta: return "Volta";
|
||||
case Unsupported:
|
||||
default: return "Unsupported";
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert the formatted instance into the output stream
|
||||
void print(std::ostream& out) const { out << to_string(); }
|
||||
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* Macro for architecture targeted by the current compiler pass
|
||||
*/
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#define CUTLASS_ARCH __CUDA_ARCH__
|
||||
#else
|
||||
#define CUTLASS_ARCH 0
|
||||
#endif
|
||||
|
||||
|
||||
/**
|
||||
* Macro for architecture family targeted by the current compiler pass
|
||||
*/
|
||||
#define CUTLASS_ARCH_FAMILY \
|
||||
( \
|
||||
(CUTLASS_ARCH < 300) ? \
|
||||
arch_family_t::Unsupported : \
|
||||
(CUTLASS_ARCH < 500) ? \
|
||||
arch_family_t::Kepler : \
|
||||
(CUTLASS_ARCH < 700) ? \
|
||||
arch_family_t::Maxwell : \
|
||||
arch_family_t::Volta \
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Device introspection
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Empty kernel for querying PTX manifest metadata (e.g., version) for the current device
|
||||
*/
|
||||
template <typename T>
|
||||
__global__ void empty_kernel(void) { }
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* \brief Retrieves the PTX version that will be used on the current device (major * 100 + minor * 10)
|
||||
*/
|
||||
cudaError_t ptx_version(int &version)
|
||||
{
|
||||
struct Dummy
|
||||
{
|
||||
/// Type definition of the empty_kernel kernel entry point
|
||||
typedef void (*EmptyKernelPtr)();
|
||||
|
||||
/// Force empty_kernel<void> to be generated if this class is used
|
||||
EmptyKernelPtr Empty()
|
||||
{
|
||||
return empty_kernel<void>;
|
||||
}
|
||||
};
|
||||
|
||||
cudaError_t error = cudaSuccess;
|
||||
do
|
||||
{
|
||||
cudaFuncAttributes empty_kernel_attrs;
|
||||
if (CUDA_PERROR_DEBUG(error = cudaFuncGetAttributes(&empty_kernel_attrs, empty_kernel<void>))) break;
|
||||
version = empty_kernel_attrs.ptxVersion * 10;
|
||||
}
|
||||
while (0);
|
||||
|
||||
return error;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* \brief Retrieves the SM version (major * 100 + minor * 10) for the current device
|
||||
*/
|
||||
cudaError_t get_sm_version(int &sm_version)
|
||||
{
|
||||
cudaError_t error = cudaSuccess;
|
||||
|
||||
// Get device ordinal
|
||||
int device_ordinal;
|
||||
if (CUDA_PERROR_DEBUG(error = cudaGetDevice(&device_ordinal)))
|
||||
return error;
|
||||
|
||||
// Fill in SM version
|
||||
int major, minor;
|
||||
if (CUDA_PERROR_DEBUG(error = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_ordinal)))
|
||||
return error;
|
||||
if (CUDA_PERROR_DEBUG(error = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_ordinal)))
|
||||
return error;
|
||||
sm_version = major * 100 + minor * 10;
|
||||
|
||||
return error;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* \brief Retrieves the count for the current device
|
||||
*/
|
||||
cudaError_t get_sm_count(int &sm_count)
|
||||
{
|
||||
cudaError_t error = cudaSuccess;
|
||||
|
||||
// Get device ordinal
|
||||
int device_ordinal;
|
||||
if (CUDA_PERROR_DEBUG(error = cudaGetDevice(&device_ordinal)))
|
||||
return error;
|
||||
|
||||
// Get SM count
|
||||
if (CUDA_PERROR_DEBUG(error = cudaDeviceGetAttribute (&sm_count, cudaDevAttrMultiProcessorCount, device_ordinal)))
|
||||
return error;
|
||||
|
||||
return error;
|
||||
}
|
||||
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
|
||||
@ -1,492 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* \brief I/O device intrinsics
|
||||
*/
|
||||
|
||||
#include <stdint.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "nv_std.h"
|
||||
#include "math.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* io_vector
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Base aligned storage for IO vector
|
||||
*/
|
||||
template <typename value_t, int VectorItems, int AlignBytes> struct io_vector_base;
|
||||
template <typename value_t, int VectorItems> struct __align__(1) io_vector_base<value_t, VectorItems, 1> { value_t buff[VectorItems]; };
|
||||
template <typename value_t, int VectorItems> struct __align__(2) io_vector_base<value_t, VectorItems, 2> { value_t buff[VectorItems]; };
|
||||
template <typename value_t, int VectorItems> struct __align__(4) io_vector_base<value_t, VectorItems, 4> { value_t buff[VectorItems]; };
|
||||
template <typename value_t, int VectorItems> struct __align__(8) io_vector_base<value_t, VectorItems, 8> { value_t buff[VectorItems]; };
|
||||
template <typename value_t, int VectorItems> struct __align__(16) io_vector_base<value_t, VectorItems, 16> { value_t buff[VectorItems]; };
|
||||
|
||||
|
||||
/**
|
||||
* \brief Aligned vector type for coarsening data movement instructions
|
||||
*
|
||||
* Exposes the member constant \p VectorItems, the actual number of component
|
||||
* values comprising the io_vector
|
||||
*/
|
||||
template <
|
||||
typename value_t, ///< Component value type
|
||||
int MaxVectorItems, ///< Maximum allowable component values
|
||||
int MaxAlignBytes ///< Maximum allowable alignment
|
||||
= __NV_STD_MIN(16, MaxVectorItems * sizeof(value_t)),
|
||||
int AlignBytes ///< Actual alignment
|
||||
= __NV_STD_MIN(sizeof(value_t) * MaxVectorItems, MaxAlignBytes),
|
||||
int VectorItems ///< Actual number of component values
|
||||
= divide_assert<AlignBytes, sizeof(value_t)>::value,
|
||||
bool MustAlias ///< Whether we need to alias during loads/stores
|
||||
= (VectorItems > 4)>
|
||||
struct io_vector;
|
||||
|
||||
|
||||
/**
|
||||
* IO vector (specialization for VectorItems <= 4)
|
||||
*/
|
||||
template <
|
||||
typename value_t,
|
||||
int MaxVectorItems,
|
||||
int MaxAlignBytes,
|
||||
int _AlignBytes,
|
||||
int _VectorItems>
|
||||
struct io_vector <
|
||||
value_t,
|
||||
MaxVectorItems,
|
||||
MaxAlignBytes,
|
||||
_AlignBytes,
|
||||
_VectorItems,
|
||||
false>
|
||||
:
|
||||
io_vector_base<value_t, _VectorItems, _AlignBytes>
|
||||
{
|
||||
enum
|
||||
{
|
||||
VectorItems = _VectorItems,
|
||||
AlignBytes = _AlignBytes
|
||||
};
|
||||
|
||||
static_assert(is_pow2<AlignBytes>::value, "I/O vector alignment must be a power-of-two.");
|
||||
static_assert((AlignBytes <= 16), "I/O vector alignment must <= 16B.");
|
||||
|
||||
inline __device__
|
||||
void load(const io_vector *ptr)
|
||||
{
|
||||
*this = *ptr;
|
||||
}
|
||||
|
||||
inline __device__
|
||||
void load(const value_t *ptr)
|
||||
{
|
||||
*this = *reinterpret_cast<const io_vector*>(ptr);
|
||||
}
|
||||
|
||||
|
||||
inline __device__
|
||||
void store(io_vector *ptr) const
|
||||
{
|
||||
*ptr = *this;
|
||||
}
|
||||
|
||||
inline __device__
|
||||
void store(value_t *ptr) const
|
||||
{
|
||||
*reinterpret_cast<io_vector*>(ptr) = *this;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* IO vector (specialization for VectorItems > 4)
|
||||
*
|
||||
* NB: Workaround for NVCC not generating 128-bit loads/stores for aligned
|
||||
* structures having component types < 32b
|
||||
*/
|
||||
template <
|
||||
typename value_t,
|
||||
int MaxVectorItems,
|
||||
int MaxAlignBytes,
|
||||
int _AlignBytes,
|
||||
int _VectorItems>
|
||||
struct io_vector <
|
||||
value_t,
|
||||
MaxVectorItems,
|
||||
MaxAlignBytes,
|
||||
_AlignBytes,
|
||||
_VectorItems,
|
||||
true>
|
||||
:
|
||||
io_vector_base<value_t, _VectorItems, _AlignBytes>
|
||||
{
|
||||
enum
|
||||
{
|
||||
VectorItems = _VectorItems,
|
||||
AlignBytes = _AlignBytes
|
||||
};
|
||||
|
||||
static_assert(is_pow2<AlignBytes>::value, "I/O vector alignment must be a power-of-two.");
|
||||
static_assert((AlignBytes <= 16), "I/O vector alignment must <= 16B.");
|
||||
|
||||
typedef typename nv_std::conditional<(AlignBytes == 8),
|
||||
uint2, // Use 8B load
|
||||
uint4> // Use 16B load
|
||||
::type align_t;
|
||||
|
||||
inline __device__
|
||||
void load(const io_vector *ptr)
|
||||
{
|
||||
*reinterpret_cast<align_t*>(this) = *reinterpret_cast<const align_t*>(ptr);
|
||||
}
|
||||
|
||||
inline __device__
|
||||
void load(const value_t *ptr)
|
||||
{
|
||||
*reinterpret_cast<align_t*>(this) = *reinterpret_cast<const align_t*>(ptr);
|
||||
}
|
||||
|
||||
|
||||
inline __device__
|
||||
void store(io_vector *ptr) const
|
||||
{
|
||||
*reinterpret_cast<align_t*>(ptr) = *reinterpret_cast<const align_t*>(this);
|
||||
}
|
||||
|
||||
inline __device__
|
||||
void store(value_t *ptr) const
|
||||
{
|
||||
*reinterpret_cast<align_t*>(ptr) = *reinterpret_cast<const align_t*>(this);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Macro expansions for vector loads
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Define vector-4 LD specialization for the given load modifier
|
||||
*/
|
||||
#define CUTLASS_LD_V4(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
|
||||
template <typename ptr_t> \
|
||||
inline __device__ \
|
||||
void f_name( \
|
||||
value_t (&dest)[4], \
|
||||
ptr_t ptr) \
|
||||
{ \
|
||||
asm volatile ("ld."#load_modifier".v4."#ptx_type" {%0, %1, %2, %3}, [%4];\n" \
|
||||
: \
|
||||
"="#val_constraint(dest[0]), \
|
||||
"="#val_constraint(dest[1]), \
|
||||
"="#val_constraint(dest[2]), \
|
||||
"="#val_constraint(dest[3]) \
|
||||
: \
|
||||
#ptr_constraint(ptr)); \
|
||||
}
|
||||
|
||||
/**
|
||||
* Define vector-2 LD specialization for the given load modifier
|
||||
*/
|
||||
#define CUTLASS_LD_V2(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
|
||||
template <typename ptr_t> \
|
||||
inline __device__ \
|
||||
void f_name( \
|
||||
value_t (&dest)[2], \
|
||||
ptr_t ptr) \
|
||||
{ \
|
||||
asm volatile ("ld."#load_modifier".v2."#ptx_type" {%0, %1}, [%2];\n" \
|
||||
: \
|
||||
"="#val_constraint(dest[0]), \
|
||||
"="#val_constraint(dest[1]) \
|
||||
: \
|
||||
#ptr_constraint(ptr)); \
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Define vector-1 LD specialization for the given load modifier
|
||||
*/
|
||||
#define CUTLASS_LD_V1(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
|
||||
template <typename ptr_t> \
|
||||
inline __device__ \
|
||||
void f_name( \
|
||||
value_t (&dest)[1], \
|
||||
ptr_t ptr) \
|
||||
{ \
|
||||
asm volatile ("ld."#load_modifier"."#ptx_type" %0, [%1];\n" \
|
||||
: \
|
||||
"="#val_constraint(dest[0]) \
|
||||
: \
|
||||
#ptr_constraint(ptr)); \
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Define powers-of-two vector LD specializations
|
||||
*/
|
||||
#define CUTLASS_LD_ALL(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
|
||||
CUTLASS_LD_V4(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
|
||||
CUTLASS_LD_V2(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
|
||||
CUTLASS_LD_V1(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint)
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Macro expansions for vector stores
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Define vector-4 ST specialization for the given load modifier
|
||||
*/
|
||||
#define CUTLASS_ST_V4(f_name, value_t, store_modifier, ptx_type, val_constraint, ptr_constraint) \
|
||||
template <typename ptr_t> \
|
||||
inline __device__ \
|
||||
void f_name( \
|
||||
ptr_t ptr, \
|
||||
const value_t (&src)[4]) \
|
||||
{ \
|
||||
asm volatile ("st."#store_modifier".v4."#ptx_type" [%0], {%1, %2, %3, %4};\n" \
|
||||
: : \
|
||||
#ptr_constraint(ptr), \
|
||||
#val_constraint(src[0]), \
|
||||
#val_constraint(src[1]), \
|
||||
#val_constraint(src[2]), \
|
||||
#val_constraint(src[3])); \
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Define vector-2 ST specialization for the given load modifier
|
||||
*/
|
||||
#define CUTLASS_ST_V2(f_name, value_t, store_modifier, ptx_type, val_constraint, ptr_constraint) \
|
||||
template <typename ptr_t> \
|
||||
inline __device__ \
|
||||
void f_name( \
|
||||
ptr_t ptr, \
|
||||
const value_t (&src)[2]) \
|
||||
{ \
|
||||
asm volatile ("st."#store_modifier".v2."#ptx_type" [%0], {%1, %2};\n" \
|
||||
: : \
|
||||
#ptr_constraint(ptr), \
|
||||
#val_constraint(src[0]), \
|
||||
#val_constraint(src[1])); \
|
||||
}
|
||||
|
||||
/**
|
||||
* Define vector-1 ST specialization for the given load modifier
|
||||
*/
|
||||
#define CUTLASS_ST_V1(f_name, value_t, store_modifier, ptx_type, val_constraint, ptr_constraint) \
|
||||
template <typename ptr_t> \
|
||||
inline __device__ \
|
||||
void f_name( \
|
||||
ptr_t ptr, \
|
||||
const value_t (&src)[1]) \
|
||||
{ \
|
||||
asm volatile ("st."#store_modifier"."#ptx_type" [%0], %1;\n" \
|
||||
: : \
|
||||
#ptr_constraint(ptr), \
|
||||
#val_constraint(src[0])); \
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Define powers-of-two vector LD specializations
|
||||
*/
|
||||
#define CUTLASS_ST_ALL(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
|
||||
CUTLASS_ST_V4(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
|
||||
CUTLASS_ST_V2(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
|
||||
CUTLASS_ST_V1(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint)
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Macro expansions for vector IO
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Define global and shared LD specializations
|
||||
*/
|
||||
#define CUTLASS_IO(value_t, ptx_type, val_constraint) \
|
||||
CUTLASS_LD_ALL(ldg_cg_internal, value_t, global.cg, ptx_type, val_constraint, l) \
|
||||
CUTLASS_ST_ALL(stg_cg_internal, value_t, global.cg, ptx_type, val_constraint, l)
|
||||
|
||||
|
||||
// Define IO for useful types
|
||||
CUTLASS_IO(double, f64, d)
|
||||
CUTLASS_IO(float, f32, f)
|
||||
CUTLASS_IO(int64_t, b64, l)
|
||||
CUTLASS_IO(int32_t, b32, r)
|
||||
CUTLASS_IO(int16_t, b16, h)
|
||||
|
||||
|
||||
// Macro cleanup
|
||||
#undef CUTLASS_IO
|
||||
#undef CUTLASS_LD_ALL
|
||||
#undef CUTLASS_LD_V4
|
||||
#undef CUTLASS_LD_V2
|
||||
#undef CUTLASS_LD_V1
|
||||
#undef CUTLASS_ST_ALL
|
||||
#undef CUTLASS_ST_V4
|
||||
#undef CUTLASS_ST_V2
|
||||
#undef CUTLASS_ST_V1
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* I/O cast types
|
||||
******************************************************************************/
|
||||
|
||||
/// Provides the type for which to reinterpret-cast a given vector
|
||||
template <
|
||||
typename value_t,
|
||||
int IoVecDim,
|
||||
int ValueBytes = sizeof(value_t)>
|
||||
struct io_cast
|
||||
{
|
||||
typedef value_t type[IoVecDim];
|
||||
};
|
||||
|
||||
|
||||
/// Provides the type for which to reinterpret-cast a vector of 1B types
|
||||
template <
|
||||
typename value_t,
|
||||
int IoVecDim>
|
||||
struct io_cast<value_t, IoVecDim, 1>
|
||||
{
|
||||
typedef typename nv_std::conditional<
|
||||
(IoVecDim < 2),
|
||||
int8_t[1], // Use 8b load
|
||||
typename nv_std::conditional<
|
||||
(IoVecDim < 4),
|
||||
int16_t[1], // Use 16b load
|
||||
int32_t[IoVecDim / 4]>::type>::type // Use up to 128b load
|
||||
type;
|
||||
};
|
||||
|
||||
|
||||
/// Provides the type for which to reinterpret-cast a vector of 2B types
|
||||
template <
|
||||
typename value_t,
|
||||
int IoVecDim>
|
||||
struct io_cast<value_t, IoVecDim, 2>
|
||||
{
|
||||
typedef typename nv_std::conditional<
|
||||
(IoVecDim < 2),
|
||||
int16_t[1], // Use 16b load
|
||||
int32_t[IoVecDim / 2]>::type // Use up to 128b load
|
||||
type;
|
||||
};
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* ldg_cg intrinsics
|
||||
******************************************************************************/
|
||||
|
||||
/// Load from global (cache-global modifier)
|
||||
template <typename value_t, typename ptr_t>
|
||||
inline __device__
|
||||
void ldg_cg(
|
||||
value_t &dest,
|
||||
ptr_t d_in)
|
||||
{
|
||||
// Cast dest to a different array type if necessary
|
||||
ldg_cg_internal(
|
||||
reinterpret_cast<typename io_cast<value_t, 1>::type &>(dest),
|
||||
d_in);
|
||||
}
|
||||
|
||||
/// Load from global (cache-global modifier)
|
||||
template <typename value_t, int IoVecDim, typename ptr_t>
|
||||
inline __device__
|
||||
void ldg_cg(
|
||||
value_t (&dest)[IoVecDim],
|
||||
ptr_t d_in)
|
||||
{
|
||||
static_assert(is_pow2<IoVecDim>::value, "I/O vectors must be a power-of-two.");
|
||||
|
||||
// Cast dest to a different array type if necessary
|
||||
ldg_cg_internal(
|
||||
reinterpret_cast<typename io_cast<value_t, IoVecDim>::type &>(dest),
|
||||
d_in);
|
||||
}
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* stg_cg intrinsics
|
||||
******************************************************************************/
|
||||
|
||||
/// Store to global (cache-global modifier)
|
||||
template <typename ptr_t, typename value_t>
|
||||
inline __device__
|
||||
void stg_cg(
|
||||
ptr_t dest,
|
||||
const value_t &src)
|
||||
{
|
||||
// Cast src to a different array type if necessary
|
||||
stg_cg_internal(
|
||||
dest,
|
||||
reinterpret_cast<const typename io_cast<value_t, 1>::type &>(src));
|
||||
}
|
||||
|
||||
/// Store to global (cache-global modifier)
|
||||
template <typename ptr_t, int IoVecDim, typename value_t>
|
||||
inline __device__
|
||||
void stg_cg(
|
||||
ptr_t dest,
|
||||
const value_t (&src)[IoVecDim])
|
||||
{
|
||||
static_assert(is_pow2<IoVecDim>::value, "I/O vectors must be a power-of-two.");
|
||||
|
||||
// Cast src to a different array type if necessary
|
||||
stg_cg_internal(
|
||||
dest,
|
||||
reinterpret_cast<const typename io_cast<value_t, IoVecDim>::type &>(src));
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
@ -1,167 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* \brief Math utilities
|
||||
*/
|
||||
|
||||
#include "nv_std.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Static math utilities
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Statically determine if N is a power-of-two
|
||||
*/
|
||||
template <int N>
|
||||
struct is_pow2 : nv_std::integral_constant<bool, (N & (N - 1)) == 0>
|
||||
{};
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Statically determine log2(N), rounded down
|
||||
*/
|
||||
template <int N, int CurrentVal = N, int Count = 0>
|
||||
struct log2_down
|
||||
{
|
||||
/// Static logarithm value
|
||||
enum { value = log2_down<N, (CurrentVal >> 1), Count + 1>::value };
|
||||
};
|
||||
|
||||
// Base case
|
||||
template <int N, int Count>
|
||||
struct log2_down<N, 1, Count>
|
||||
{
|
||||
enum { value = Count };
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Statically determine log2(N), rounded up
|
||||
*/
|
||||
template <int N, int CurrentVal = N, int Count = 0>
|
||||
struct log2_up
|
||||
{
|
||||
/// Static logarithm value
|
||||
enum { value = log2_up<N, (CurrentVal >> 1), Count + 1>::value };
|
||||
};
|
||||
|
||||
// Base case
|
||||
template <int N, int Count>
|
||||
struct log2_up<N, 1, Count>
|
||||
{
|
||||
enum { value = ((1 << Count) < N) ? Count + 1 : Count };
|
||||
};
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Statically estimate sqrt(N) to the nearest power-of-two
|
||||
*/
|
||||
template <int N>
|
||||
struct sqrt_est
|
||||
{
|
||||
enum { value = 1 << (log2_up<N>::value / 2) };
|
||||
};
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* For performing a constant-division with a compile-time assertion that the
|
||||
* Divisor evenly-divides the Dividend.
|
||||
*/
|
||||
template <int Dividend, int Divisor>
|
||||
struct divide_assert
|
||||
{
|
||||
enum { value = Dividend / Divisor};
|
||||
|
||||
static_assert((Dividend % Divisor == 0), "Not an even multiple");
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Rounding
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Round dividend up to the nearest multiple of divisor
|
||||
*/
|
||||
template <typename dividend_t, typename divisor_t>
|
||||
inline __host__ __device__
|
||||
dividend_t round_nearest(dividend_t dividend, divisor_t divisor)
|
||||
{
|
||||
return ((dividend + divisor - 1) / divisor) * divisor;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Greatest common divisor
|
||||
*/
|
||||
template <typename value_t>
|
||||
inline __host__ __device__
|
||||
value_t gcd(value_t a, value_t b)
|
||||
{
|
||||
for (;;)
|
||||
{
|
||||
if (a == 0) return b;
|
||||
b %= a;
|
||||
if (b == 0) return a;
|
||||
a %= b;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Least common multiple
|
||||
*/
|
||||
template <typename value_t>
|
||||
inline __host__ __device__
|
||||
value_t lcm(value_t a, value_t b)
|
||||
{
|
||||
value_t temp = gcd(a, b);
|
||||
|
||||
return temp ? (a / temp * b) : 0;
|
||||
}
|
||||
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
@ -1,102 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* \brief Enumeration of dense matrix view transformations
|
||||
*/
|
||||
|
||||
#include "printable.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* matrix_transform_t
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief Enumeration of dense matrix view transformations
|
||||
*
|
||||
* These enumerators (and corresponding tag types) describe which view
|
||||
* transformation needs to be applied prior to operation upon a given dense
|
||||
* matrix. Its values correspond to Fortran characters 'n' (non-transpose),
|
||||
* 't'(transpose) and 'c'(conjugate transpose) that are often
|
||||
* used as parameters to legacy BLAS implementations
|
||||
*/
|
||||
struct matrix_transform_t : printable_t
|
||||
{
|
||||
/// \brief Enumerants (same as CUBLAS)
|
||||
enum kind_t
|
||||
{
|
||||
/// Invalid view
|
||||
Invalid = -1,
|
||||
|
||||
/// Non-transpose view
|
||||
NonTranspose = 0,
|
||||
|
||||
/// Transpose view
|
||||
Transpose = 1,
|
||||
|
||||
/// Conjugate transpose view
|
||||
ConjugateTranpose = 2,
|
||||
};
|
||||
|
||||
/// Enumerant value
|
||||
kind_t kind;
|
||||
|
||||
/// Default constructor
|
||||
matrix_transform_t() : kind(Invalid) {}
|
||||
|
||||
/// Copy constructor
|
||||
matrix_transform_t(const kind_t &other_kind) : kind(other_kind) {}
|
||||
|
||||
/// Cast to kind_t
|
||||
operator kind_t() const { return kind; }
|
||||
|
||||
/// Returns the instance as a string
|
||||
__host__ __device__ inline
|
||||
char const* to_string() const
|
||||
{
|
||||
switch (kind)
|
||||
{
|
||||
case NonTranspose: return "NonTranspose";
|
||||
case Transpose: return "Transpose";
|
||||
case ConjugateTranpose: return "ConjugateTranpose";
|
||||
default: return "Invalid";
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert the formatted instance into the output stream
|
||||
void print(std::ostream& out) const { out << to_string(); }
|
||||
|
||||
};
|
||||
|
||||
|
||||
} // namespace cutlass
|
||||
@ -1,705 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* \brief C++ features that may be otherwise unimplemented for CUDA device functions.
|
||||
*
|
||||
* This file has three components:
|
||||
*
|
||||
* (1) Macros:
|
||||
* - Empty macro defines for C++ keywords not supported by the current
|
||||
* version of C++. These simply allow compilation to proceed (but do
|
||||
* not provide the added semantics).
|
||||
* - \p noexcept
|
||||
* - \p constexpr
|
||||
* - \p nullptr
|
||||
* - \p static_assert
|
||||
*
|
||||
* - Macro functions that we need in constant expressions because the
|
||||
* C++ equivalents require constexpr compiler support. These are
|
||||
* prefixed with \p __NV_STD_*
|
||||
* - \p __NV_STD_MAX
|
||||
* - \p __NV_STD_MIN
|
||||
*
|
||||
* (2) Re-implementations of STL functions and types:
|
||||
* - C++ features that need the \p __device__ annotation. These are
|
||||
* placed into the \p nv_std namespace.
|
||||
* - \p plus
|
||||
* - \p less
|
||||
* - \p greater
|
||||
* - \p min
|
||||
* - \p max
|
||||
* - \p methods on std::pair (==, !=, <, <=, >, >=, and make_pair())
|
||||
*
|
||||
* (3) Stop-gap implementations of unsupported STL functions and types:
|
||||
* - STL functions and types defined by C++ 11/14/17/etc. that are not
|
||||
* provided by the current version of C++. These are placed into the
|
||||
* \p nv_std namespace
|
||||
* - \p integral_constant
|
||||
* - \p nullptr_t
|
||||
* - \p true_type
|
||||
* - \p false_type
|
||||
* - \p bool_constant
|
||||
* - \p enable_if
|
||||
* - \p conditional
|
||||
* - \p is_same
|
||||
* - \p is_base_of
|
||||
* - \p remove_const
|
||||
* - \p remove_volatile
|
||||
* - \p remove_cv
|
||||
* - \p is_volatile
|
||||
* - \p is_pointer
|
||||
* - \p is_void
|
||||
* - \p is_integral
|
||||
* - \p is_floating_point
|
||||
* - \p is_arithmetic
|
||||
* - \p is_fundamental
|
||||
* - \p is_trivially_copyable
|
||||
* - \p alignment_of
|
||||
* - \p aligned_storage
|
||||
*
|
||||
* (4) Functions and types that are STL-like (but aren't in the STL):
|
||||
* - \p TODO: min and max functors?
|
||||
*
|
||||
* The idea is that, as we drop support for older compilers, we can simply #define
|
||||
* the \p __NV_STD_XYZ macros and \p nv_std namespace to alias their C++
|
||||
* counterparts (or trivially find-and-replace their occurrences in code text).
|
||||
*/
|
||||
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Include STL files that nv_std provides functionality for
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#include <cstddef> // nullptr_t
|
||||
#include <algorithm> // Minimum/maximum operations
|
||||
#include <functional> // Arithmetic operations
|
||||
#include <utility> // For methods on std::pair
|
||||
#if (!defined(_MSC_VER) && (__cplusplus >= 201103L)) || (defined(_MSC_VER) && (_MS_VER >= 1500))
|
||||
#include <type_traits> // For integral constants, conditional metaprogramming, and type traits
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Macros
|
||||
******************************************************************************/
|
||||
//-----------------------------------------------------------------------------
|
||||
// Keywords
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
/// noexcept, constexpr
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1900))
|
||||
#ifndef noexcept
|
||||
#define noexcept
|
||||
#endif
|
||||
#ifndef constexpr
|
||||
#define constexpr
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/// nullptr
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1310 ))
|
||||
#ifndef nullptr
|
||||
#define nullptr 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/// static_assert
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1600 ))
|
||||
#ifndef static_assert
|
||||
#define __nv_std_cat_(a, b) a ## b
|
||||
#define __nv_std_cat(a, b) __nv_std_cat_(a, b)
|
||||
#define static_assert(__e, __m) typedef int __nv_std_cat(AsSeRt, __LINE__)[(__e) ? 1 : -1]
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Functions
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
/// Select maximum(a, b)
|
||||
#ifndef __NV_STD_MAX
|
||||
#define __NV_STD_MAX(a, b) (((b) > (a)) ? (b) : (a))
|
||||
#endif
|
||||
|
||||
/// Select minimum(a, b)
|
||||
#ifndef __NV_STD_MIN
|
||||
#define __NV_STD_MIN(a, b) (((b) < (a)) ? (b) : (a))
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Re-implementations
|
||||
******************************************************************************/
|
||||
|
||||
namespace nv_std {
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Arithmetic operations, comparisons <functional>
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
/// nv_std::plus
|
||||
template <typename T>
|
||||
struct plus
|
||||
{
|
||||
inline __host__ __device__
|
||||
constexpr T operator()(const T &lhs, const T &rhs) const
|
||||
{
|
||||
return lhs + rhs;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// std::less
|
||||
template <typename T>
|
||||
struct less
|
||||
{
|
||||
inline __host__ __device__
|
||||
constexpr bool operator()(const T &lhs, const T &rhs) const
|
||||
{
|
||||
return lhs < rhs;
|
||||
}
|
||||
};
|
||||
|
||||
/// std::greater
|
||||
template <typename T>
|
||||
struct greater
|
||||
{
|
||||
inline __host__ __device__
|
||||
constexpr bool operator()(const T &lhs, const T &rhs) const
|
||||
{
|
||||
return lhs > rhs;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Minimum/maximum operations <algorithm>
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
/// std::min
|
||||
template <typename T>
|
||||
inline __host__ __device__
|
||||
constexpr const T& min(
|
||||
const T& a,
|
||||
const T& b)
|
||||
{
|
||||
return (b < a) ? b : a;
|
||||
}
|
||||
|
||||
/// std::max
|
||||
template <typename T>
|
||||
inline __host__ __device__
|
||||
constexpr const T& max(
|
||||
const T& a,
|
||||
const T& b)
|
||||
{
|
||||
return (a < b) ? b : a;
|
||||
}
|
||||
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Methods on std::pair
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
using std::pair;
|
||||
|
||||
template< class T1, class T2 >
|
||||
inline __host__ __device__
|
||||
constexpr bool operator==( const pair<T1,T2>& lhs, const pair<T1,T2>& rhs )
|
||||
{
|
||||
return (lhs.first == rhs.first) && (lhs.second == rhs.second);
|
||||
}
|
||||
|
||||
template< class T1, class T2 >
|
||||
inline __host__ __device__
|
||||
constexpr bool operator!=( const pair<T1,T2>& lhs, const pair<T1,T2>& rhs )
|
||||
{
|
||||
return (lhs.first != rhs.first) && (lhs.second != rhs.second);
|
||||
}
|
||||
|
||||
template< class T1, class T2 >
|
||||
inline __host__ __device__
|
||||
constexpr bool operator<( const pair<T1,T2>& lhs, const pair<T1,T2>& rhs )
|
||||
{
|
||||
return (lhs.first < rhs.first) ?
|
||||
true :
|
||||
(rhs.first < lhs.first) ?
|
||||
false :
|
||||
(lhs.second < rhs.second);
|
||||
}
|
||||
|
||||
template< class T1, class T2 >
|
||||
inline __host__ __device__
|
||||
constexpr bool operator<=( const pair<T1,T2>& lhs, const pair<T1,T2>& rhs )
|
||||
{
|
||||
return !(rhs < lhs);
|
||||
}
|
||||
|
||||
template< class T1, class T2 >
|
||||
inline __host__ __device__
|
||||
constexpr bool operator>( const pair<T1,T2>& lhs, const pair<T1,T2>& rhs )
|
||||
{
|
||||
return (rhs < lhs);
|
||||
}
|
||||
|
||||
template< class T1, class T2 >
|
||||
inline __host__ __device__
|
||||
constexpr bool operator>=( const pair<T1,T2>& lhs, const pair<T1,T2>& rhs )
|
||||
{
|
||||
return !(lhs < rhs);
|
||||
}
|
||||
|
||||
template< class T1, class T2 >
|
||||
inline __host__ __device__
|
||||
std::pair<T1,T2> make_pair( T1 t, T2 u )
|
||||
{
|
||||
std::pair<T1,T2> retval;
|
||||
retval.first = t;
|
||||
retval.second = u;
|
||||
return retval;
|
||||
}
|
||||
|
||||
} // namespace nv_std
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Implementations of C++ 11/14/17/... STL features
|
||||
******************************************************************************/
|
||||
|
||||
namespace nv_std {
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Integral constant helper types <type_traits>
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500))
|
||||
|
||||
/// std::integral_constant
|
||||
template <typename value_t, value_t V>
|
||||
struct integral_constant;
|
||||
|
||||
/// std::integral_constant
|
||||
template <typename value_t, value_t V>
|
||||
struct integral_constant
|
||||
{
|
||||
static const value_t value = V;
|
||||
|
||||
typedef value_t value_type;
|
||||
typedef integral_constant<value_t, V> type;
|
||||
|
||||
inline __host__ __device__ operator value_type() const
|
||||
{
|
||||
return value;
|
||||
}
|
||||
|
||||
inline __host__ __device__ const value_type operator()() const
|
||||
{
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
#else
|
||||
|
||||
using std::integral_constant;
|
||||
using std::pair;
|
||||
|
||||
#endif
|
||||
|
||||
/// The type used as a compile-time boolean with true value.
|
||||
typedef integral_constant<bool, true> true_type;
|
||||
|
||||
/// The type used as a compile-time boolean with false value.
|
||||
typedef integral_constant<bool, false> false_type;
|
||||
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201402L)) || (defined(_MSC_VER) && (_MSC_VER < 1900))
|
||||
|
||||
/// std::bool_constant
|
||||
template <bool V>
|
||||
struct bool_constant : nv_std::integral_constant<bool, V>
|
||||
{};
|
||||
|
||||
#else
|
||||
|
||||
using std::bool_constant;
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1700))
|
||||
|
||||
/// std::nullptr_t
|
||||
struct nullptr_t {};
|
||||
|
||||
#else
|
||||
|
||||
using std::nullptr_t;
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Conditional metaprogramming <type_traits>
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1600))
|
||||
|
||||
/// std::enable_if (true specialization)
|
||||
template<bool C, typename T = void>
|
||||
struct enable_if {
|
||||
typedef T type;
|
||||
};
|
||||
|
||||
/// std::enable_if (false specialization)
|
||||
template<typename T>
|
||||
struct enable_if<false, T> { };
|
||||
|
||||
|
||||
/// std::conditional (true specialization)
|
||||
template<bool B, class T, class F>
|
||||
struct conditional { typedef T type; };
|
||||
|
||||
/// std::conditional (false specialization)
|
||||
template<class T, class F>
|
||||
struct conditional<false, T, F> { typedef F type; };
|
||||
|
||||
#else
|
||||
|
||||
using std::enable_if;
|
||||
using std::conditional;
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Const/volatility specifiers <type_traits>
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500))
|
||||
|
||||
/// std::remove_const (non-const specialization)
|
||||
template <typename T> struct remove_const { typedef T type; };
|
||||
|
||||
/// std::remove_const (const specialization)
|
||||
template <typename T> struct remove_const<const T> { typedef T type; };
|
||||
|
||||
|
||||
|
||||
/// std::remove_volatile (non-volatile specialization)
|
||||
template <typename T> struct remove_volatile { typedef T type; };
|
||||
|
||||
/// std::remove_volatile (volatile specialization)
|
||||
template <typename T> struct remove_volatile<volatile T> { typedef T type; };
|
||||
|
||||
|
||||
|
||||
/// std::remove_cv
|
||||
template <typename T>
|
||||
struct remove_cv {
|
||||
typedef typename remove_volatile<typename remove_const<T>::type>::type type;
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
using std::remove_const;
|
||||
using std::remove_volatile;
|
||||
using std::remove_cv;
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Type relationships <type_traits>
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500))
|
||||
|
||||
/// std::is_same (false specialization)
|
||||
template <typename A, typename B>
|
||||
struct is_same : false_type
|
||||
{};
|
||||
|
||||
/// std::is_same (true specialization)
|
||||
template <typename A>
|
||||
struct is_same<A, A> : true_type
|
||||
{};
|
||||
|
||||
|
||||
/// Helper for std::is_base_of
|
||||
template<typename BaseT, typename DerivedT>
|
||||
struct is_base_of_helper
|
||||
{
|
||||
typedef char (&yes)[1];
|
||||
typedef char (&no)[2];
|
||||
|
||||
template<typename B, typename D>
|
||||
struct dummy
|
||||
{
|
||||
operator B*() const;
|
||||
operator D*();
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
static yes check(DerivedT*, T);
|
||||
|
||||
static no check(BaseT*, int);
|
||||
|
||||
static const bool value = sizeof(check(dummy<BaseT, DerivedT>(), int())) == sizeof(yes);
|
||||
};
|
||||
|
||||
/// std::is_base_of
|
||||
template <typename BaseT, typename DerivedT>
|
||||
struct is_base_of : integral_constant<
|
||||
bool,
|
||||
(is_base_of_helper<typename remove_cv<BaseT>::type, typename remove_cv<DerivedT>::type>::value) ||
|
||||
(is_same<typename remove_cv<BaseT>::type, typename remove_cv<DerivedT>::type>::value)>
|
||||
{};
|
||||
|
||||
|
||||
#else
|
||||
|
||||
using std::is_same;
|
||||
using std::is_base_of;
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Type properties <type_traits>
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500))
|
||||
|
||||
/// std::is_volatile
|
||||
template <typename T> struct is_volatile : false_type {};
|
||||
template <typename T> struct is_volatile<volatile T> : true_type {};
|
||||
|
||||
|
||||
/// Helper for std::is_pointer (false specialization)
|
||||
template <typename T> struct is_pointer_helper : false_type {};
|
||||
|
||||
/// Helper for std::is_pointer (true specialization)
|
||||
template <typename T> struct is_pointer_helper<T*> : true_type {};
|
||||
|
||||
/// std::is_pointer
|
||||
template <typename T> struct is_pointer : is_pointer_helper<typename remove_cv<T>::type> {};
|
||||
|
||||
|
||||
|
||||
/// std::is_void
|
||||
template <typename T>
|
||||
struct is_void : is_same<void, typename remove_cv<T>::type>
|
||||
{};
|
||||
|
||||
|
||||
|
||||
/// std::is_integral
|
||||
template <typename T> struct is_integral : false_type {};
|
||||
template <> struct is_integral<char> : true_type {};
|
||||
template <> struct is_integral<signed char> : true_type {};
|
||||
template <> struct is_integral<unsigned char> : true_type {};
|
||||
template <> struct is_integral<short> : true_type {};
|
||||
template <> struct is_integral<unsigned short> : true_type {};
|
||||
template <> struct is_integral<int> : true_type {};
|
||||
template <> struct is_integral<unsigned int> : true_type {};
|
||||
template <> struct is_integral<long> : true_type {};
|
||||
template <> struct is_integral<unsigned long> : true_type {};
|
||||
template <> struct is_integral<long long> : true_type {};
|
||||
template <> struct is_integral<unsigned long long> : true_type {};
|
||||
template <typename T> struct is_integral<volatile T> : is_integral<T> {};
|
||||
template <typename T> struct is_integral<const T> : is_integral<T> {};
|
||||
template <typename T> struct is_integral<const volatile T> : is_integral<T> {};
|
||||
|
||||
|
||||
|
||||
/// std::is_floating_point
|
||||
template <typename T>
|
||||
struct is_floating_point : integral_constant<
|
||||
bool,
|
||||
(is_same<float, typename remove_cv<T>::type>::value ||
|
||||
is_same<double, typename remove_cv<T>::type>::value)>
|
||||
{};
|
||||
|
||||
|
||||
|
||||
/// std::is_arithmetic
|
||||
template <typename T>
|
||||
struct is_arithmetic :
|
||||
integral_constant<bool, (is_integral<T>::value || is_floating_point<T>::value)>
|
||||
{};
|
||||
|
||||
|
||||
/// std::is_fundamental
|
||||
template <typename T>
|
||||
struct is_fundamental : integral_constant<
|
||||
bool, (is_arithmetic<T>::value ||
|
||||
is_void<T>::value ||
|
||||
is_same<nullptr_t, typename remove_cv<T>::type>::value)>
|
||||
{};
|
||||
|
||||
|
||||
|
||||
|
||||
#else
|
||||
|
||||
using std::is_volatile;
|
||||
using std::is_pointer;
|
||||
using std::is_void;
|
||||
using std::is_integral;
|
||||
using std::is_floating_point;
|
||||
using std::is_arithmetic;
|
||||
using std::is_fundamental;
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || \
|
||||
(defined(_MSC_VER) && (_MSC_VER < 1800)) || \
|
||||
(defined(__GNUG__) && (__GNUC__ < 5))
|
||||
|
||||
/**
|
||||
* std::is_trivially_copyable
|
||||
*
|
||||
* This implementation only evaluates true if T is fundamental or pointer
|
||||
*
|
||||
* Without help from partial template specializations provided by the user for
|
||||
* a specific class or struct, this trait will never report that the specified
|
||||
* class or struct is trivially-copyable ; this is always safe,
|
||||
* if possibly sub-optimal.
|
||||
*/
|
||||
template <typename T>
|
||||
struct is_trivially_copyable :
|
||||
integral_constant<bool, (is_fundamental<T>::value || is_pointer<T>::value)>
|
||||
{};
|
||||
|
||||
#else
|
||||
|
||||
using std::is_trivially_copyable;
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Alignment and layout utilities
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500))
|
||||
|
||||
|
||||
/// std::alignment_of
|
||||
template <typename value_t>
|
||||
struct alignment_of
|
||||
{
|
||||
struct pad
|
||||
{
|
||||
value_t val;
|
||||
char byte;
|
||||
};
|
||||
|
||||
enum
|
||||
{
|
||||
value = sizeof(pad) - sizeof(value_t)
|
||||
};
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
template <typename value_t>
|
||||
struct alignment_of : std::alignment_of<value_t> {};
|
||||
|
||||
#endif
|
||||
|
||||
/* 16B specializations where 32-bit Win32 host compiler disagrees with device compiler */
|
||||
template <> struct alignment_of<int4> { enum { value = 16 }; };
|
||||
template <> struct alignment_of<uint4> { enum { value = 16 }; };
|
||||
template <> struct alignment_of<float4> { enum { value = 16 }; };
|
||||
template <> struct alignment_of<long4> { enum { value = 16 }; };
|
||||
template <> struct alignment_of<ulong4> { enum { value = 16 }; };
|
||||
template <> struct alignment_of<longlong2> { enum { value = 16 }; };
|
||||
template <> struct alignment_of<ulonglong2> { enum { value = 16 }; };
|
||||
template <> struct alignment_of<double2> { enum { value = 16 }; };
|
||||
template <> struct alignment_of<longlong4> { enum { value = 16 }; };
|
||||
template <> struct alignment_of<ulonglong4> { enum { value = 16 }; };
|
||||
template <> struct alignment_of<double4> { enum { value = 16 }; };
|
||||
|
||||
// Specializations for volatile/const qualified types
|
||||
template <typename value_t> struct alignment_of<volatile value_t> : alignment_of<value_t> {};
|
||||
template <typename value_t> struct alignment_of<const value_t> : alignment_of<value_t> {};
|
||||
template <typename value_t> struct alignment_of<const volatile value_t> : alignment_of<value_t> {};
|
||||
|
||||
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1800))
|
||||
|
||||
template<size_t Align> struct aligned_chunk;
|
||||
template<> struct __align__(1) aligned_chunk<1> { uint8_t buff; };
|
||||
template<> struct __align__(2) aligned_chunk<2> { uint16_t buff; };
|
||||
template<> struct __align__(4) aligned_chunk<4> { uint32_t buff; };
|
||||
template<> struct __align__(8) aligned_chunk<8> { uint32_t buff[2]; };
|
||||
template<> struct __align__(16) aligned_chunk<16> { uint32_t buff[4]; };
|
||||
template<> struct __align__(32) aligned_chunk<32> { uint32_t buff[8]; };
|
||||
template<> struct __align__(64) aligned_chunk<64> { uint32_t buff[16]; };
|
||||
template<> struct __align__(128) aligned_chunk<128> { uint32_t buff[32]; };
|
||||
template<> struct __align__(256) aligned_chunk<256> { uint32_t buff[64]; };
|
||||
template<> struct __align__(512) aligned_chunk<512> { uint32_t buff[128]; };
|
||||
template<> struct __align__(1024) aligned_chunk<1024> { uint32_t buff[256]; };
|
||||
template<> struct __align__(2048) aligned_chunk<2048> { uint32_t buff[512]; };
|
||||
template<> struct __align__(4096) aligned_chunk<4096> { uint32_t buff[1024]; };
|
||||
|
||||
/// std::aligned_storage
|
||||
template <size_t Len, size_t Align>
|
||||
struct aligned_storage
|
||||
{
|
||||
typedef aligned_chunk<Align> type[Len / sizeof(aligned_chunk<Align>)];
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
using std::aligned_storage;
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
|
||||
}; // namespace nv_std
|
||||
|
||||
801
cutlass/util/platform.h
Normal file
801
cutlass/util/platform.h
Normal file
@ -0,0 +1,801 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* \brief C++ features that may be otherwise unimplemented for CUDA device functions.
|
||||
*
|
||||
* This file has three components:
|
||||
*
|
||||
* (1) Macros:
|
||||
* - Empty macro defines for C++ keywords not supported by the current
|
||||
* version of C++. These simply allow compilation to proceed (but do
|
||||
* not provide the added semantics).
|
||||
* - \p noexcept
|
||||
* - \p constexpr
|
||||
* - \p nullptr
|
||||
* - \p static_assert
|
||||
*
|
||||
* - Macro functions that we need in constant expressions because the
|
||||
* C++ equivalents require constexpr compiler support. These are
|
||||
* prefixed with \p __NV_STD_*
|
||||
* - \p __NV_STD_MAX
|
||||
* - \p __NV_STD_MIN
|
||||
*
|
||||
* (2) Re-implementations of STL functions and types:
|
||||
* - C++ features that need the \p __device__ annotation. These are
|
||||
* placed into the \p platform namespace.
|
||||
* - \p plus
|
||||
* - \p less
|
||||
* - \p greater
|
||||
* - \p min
|
||||
* - \p max
|
||||
* - \p methods on std::pair (==, !=, <, <=, >, >=, and make_pair())
|
||||
*
|
||||
* (3) Stop-gap implementations of unsupported STL functions and types:
|
||||
* - STL functions and types defined by C++ 11/14/17/etc. that are not
|
||||
* provided by the current version of C++. These are placed into the
|
||||
* \p platform namespace
|
||||
* - \p integral_constant
|
||||
* - \p nullptr_t
|
||||
* - \p true_type
|
||||
* - \p false_type
|
||||
* - \p bool_constant
|
||||
* - \p enable_if
|
||||
* - \p conditional
|
||||
* - \p is_same
|
||||
* - \p is_base_of
|
||||
* - \p remove_const
|
||||
* - \p remove_volatile
|
||||
* - \p remove_cv
|
||||
* - \p is_volatile
|
||||
* - \p is_pointer
|
||||
* - \p is_void
|
||||
* - \p is_integral
|
||||
* - \p is_floating_point
|
||||
* - \p is_arithmetic
|
||||
* - \p is_fundamental
|
||||
* - \p is_trivially_copyable
|
||||
* - \p alignment_of
|
||||
* - \p aligned_storage
|
||||
*
|
||||
* (4) Functions and types that are STL-like (but aren't in the STL):
|
||||
* - \p TODO: min and max functors?
|
||||
*
|
||||
* The idea is that, as we drop support for older compilers, we can simply #define
|
||||
* the \p __NV_STD_XYZ macros and \p platform namespace to alias their C++
|
||||
* counterparts (or trivially find-and-replace their occurrences in code text).
|
||||
*/
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Dependencies
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
//-----------------------------------------------------------------------------
|
||||
// Include STL files that platform provides functionality for
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#include <algorithm> // Minimum/maximum operations
|
||||
#include <cstddef> // nullptr_t
|
||||
#include <functional> // Arithmetic operations
|
||||
#include <utility> // For methods on std::pair
|
||||
#if (!defined(_MSC_VER) && (__cplusplus >= 201103L)) || (defined(_MSC_VER) && (_MS_VER >= 1500))
|
||||
#include <type_traits> // For integral constants, conditional metaprogramming, and type traits
|
||||
#endif
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
|
||||
#endif
|
||||
/******************************************************************************
|
||||
* Macros
|
||||
******************************************************************************/
|
||||
//-----------------------------------------------------------------------------
|
||||
// Keywords
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
/// noexcept, constexpr
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1900))
|
||||
#ifndef noexcept
|
||||
#define noexcept
|
||||
#endif
|
||||
#ifndef constexpr
|
||||
#define constexpr
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/// nullptr
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1310))
|
||||
#ifndef nullptr
|
||||
#define nullptr 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/// static_assert
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1600))
|
||||
#ifndef static_assert
|
||||
#define __platform_cat_(a, b) a##b
|
||||
#define __platform_cat(a, b) __platform_cat_(a, b)
|
||||
#define static_assert(__e, __m) typedef int __platform_cat(AsSeRt, __LINE__)[(__e) ? 1 : -1]
|
||||
#endif
|
||||
#endif
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Functions
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
/// Select maximum(a, b)
|
||||
#ifndef __NV_STD_MAX
|
||||
#define __NV_STD_MAX(a, b) (((b) > (a)) ? (b) : (a))
|
||||
#endif
|
||||
|
||||
/// Select minimum(a, b)
|
||||
#ifndef __NV_STD_MIN
|
||||
#define __NV_STD_MIN(a, b) (((b) < (a)) ? (b) : (a))
|
||||
#endif
|
||||
|
||||
/******************************************************************************
|
||||
* Re-implementations
|
||||
******************************************************************************/
|
||||
namespace cutlass {
|
||||
namespace platform {
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Arithmetic operations, comparisons <functional>
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
/// platform::plus
|
||||
template <typename T>
|
||||
struct plus {
|
||||
CUTLASS_HOST_DEVICE constexpr T operator()(const T& lhs, const T& rhs) const { return lhs + rhs; }
|
||||
};
|
||||
|
||||
/// std::less
|
||||
template <typename T>
|
||||
struct less {
|
||||
CUTLASS_HOST_DEVICE constexpr bool operator()(const T& lhs, const T& rhs) const {
|
||||
return lhs < rhs;
|
||||
}
|
||||
};
|
||||
|
||||
/// std::greater
|
||||
template <typename T>
|
||||
struct greater {
|
||||
CUTLASS_HOST_DEVICE constexpr bool operator()(const T& lhs, const T& rhs) const {
|
||||
return lhs > rhs;
|
||||
}
|
||||
};
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Minimum/maximum operations <algorithm>
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
/// std::min
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE constexpr const T& min(const T& a, const T& b) {
|
||||
return (b < a) ? b : a;
|
||||
}
|
||||
|
||||
/// std::max
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE constexpr const T& max(const T& a, const T& b) {
|
||||
return (a < b) ? b : a;
|
||||
}
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
//-----------------------------------------------------------------------------
|
||||
// Methods on std::pair
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
using std::pair;
|
||||
|
||||
template <class T1, class T2>
|
||||
CUTLASS_HOST_DEVICE constexpr bool operator==(const pair<T1, T2>& lhs, const pair<T1, T2>& rhs) {
|
||||
return (lhs.first == rhs.first) && (lhs.second == rhs.second);
|
||||
}
|
||||
|
||||
template <class T1, class T2>
|
||||
CUTLASS_HOST_DEVICE constexpr bool operator!=(const pair<T1, T2>& lhs, const pair<T1, T2>& rhs) {
|
||||
return (lhs.first != rhs.first) && (lhs.second != rhs.second);
|
||||
}
|
||||
|
||||
template <class T1, class T2>
|
||||
CUTLASS_HOST_DEVICE constexpr bool operator<(const pair<T1, T2>& lhs, const pair<T1, T2>& rhs) {
|
||||
return (lhs.first < rhs.first) ? true : (rhs.first < lhs.first) ? false
|
||||
: (lhs.second < rhs.second);
|
||||
}
|
||||
|
||||
template <class T1, class T2>
|
||||
CUTLASS_HOST_DEVICE constexpr bool operator<=(const pair<T1, T2>& lhs, const pair<T1, T2>& rhs) {
|
||||
return !(rhs < lhs);
|
||||
}
|
||||
|
||||
template <class T1, class T2>
|
||||
CUTLASS_HOST_DEVICE constexpr bool operator>(const pair<T1, T2>& lhs, const pair<T1, T2>& rhs) {
|
||||
return (rhs < lhs);
|
||||
}
|
||||
|
||||
template <class T1, class T2>
|
||||
CUTLASS_HOST_DEVICE constexpr bool operator>=(const pair<T1, T2>& lhs, const pair<T1, T2>& rhs) {
|
||||
return !(lhs < rhs);
|
||||
}
|
||||
|
||||
template <class T1, class T2>
|
||||
CUTLASS_HOST_DEVICE std::pair<T1, T2> make_pair(T1 t, T2 u) {
|
||||
std::pair<T1, T2> retval;
|
||||
retval.first = t;
|
||||
retval.second = u;
|
||||
return retval;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace platform
|
||||
|
||||
/******************************************************************************
|
||||
* Implementations of C++ 11/14/17/... STL features
|
||||
******************************************************************************/
|
||||
|
||||
namespace platform {
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Integral constant helper types <type_traits>
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500))
|
||||
|
||||
/// std::integral_constant
|
||||
template <typename value_t, value_t V>
|
||||
struct integral_constant;
|
||||
|
||||
/// std::integral_constant
|
||||
template <typename value_t, value_t V>
|
||||
struct integral_constant {
|
||||
static const value_t value = V;
|
||||
|
||||
typedef value_t value_type;
|
||||
typedef integral_constant<value_t, V> type;
|
||||
|
||||
CUTLASS_HOST_DEVICE operator value_type() const { return value; }
|
||||
|
||||
CUTLASS_HOST_DEVICE const value_type operator()() const { return value; }
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
using std::integral_constant;
|
||||
using std::pair;
|
||||
|
||||
#endif
|
||||
|
||||
/// The type used as a compile-time boolean with true value.
|
||||
typedef integral_constant<bool, true> true_type;
|
||||
|
||||
/// The type used as a compile-time boolean with false value.
|
||||
typedef integral_constant<bool, false> false_type;
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus <= 201402L)) || (defined(_MSC_VER) && (_MSC_VER < 1900))
|
||||
|
||||
/// std::bool_constant
|
||||
template <bool V>
|
||||
struct bool_constant : platform::integral_constant<bool, V> {};
|
||||
|
||||
#else
|
||||
|
||||
using std::bool_constant;
|
||||
|
||||
#endif
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1700))
|
||||
|
||||
/// std::nullptr_t
|
||||
struct nullptr_t {};
|
||||
|
||||
#else
|
||||
|
||||
using std::nullptr_t;
|
||||
|
||||
#endif
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Conditional metaprogramming <type_traits>
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1600))
|
||||
|
||||
/// std::enable_if (true specialization)
|
||||
template <bool C, typename T = void>
|
||||
struct enable_if {
|
||||
typedef T type;
|
||||
};
|
||||
|
||||
/// std::enable_if (false specialization)
|
||||
template <typename T>
|
||||
struct enable_if<false, T> {};
|
||||
|
||||
/// std::conditional (true specialization)
|
||||
template <bool B, class T, class F>
|
||||
struct conditional {
|
||||
typedef T type;
|
||||
};
|
||||
|
||||
/// std::conditional (false specialization)
|
||||
template <class T, class F>
|
||||
struct conditional<false, T, F> {
|
||||
typedef F type;
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
using std::enable_if;
|
||||
using std::conditional;
|
||||
|
||||
#endif
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Const/volatility specifiers <type_traits>
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500))
|
||||
|
||||
/// std::remove_const (non-const specialization)
|
||||
template <typename T>
|
||||
struct remove_const {
|
||||
typedef T type;
|
||||
};
|
||||
|
||||
/// std::remove_const (const specialization)
|
||||
template <typename T>
|
||||
struct remove_const<const T> {
|
||||
typedef T type;
|
||||
};
|
||||
|
||||
/// std::remove_volatile (non-volatile specialization)
|
||||
template <typename T>
|
||||
struct remove_volatile {
|
||||
typedef T type;
|
||||
};
|
||||
|
||||
/// std::remove_volatile (volatile specialization)
|
||||
template <typename T>
|
||||
struct remove_volatile<volatile T> {
|
||||
typedef T type;
|
||||
};
|
||||
|
||||
/// std::remove_cv
|
||||
template <typename T>
|
||||
struct remove_cv {
|
||||
typedef typename remove_volatile<typename remove_const<T>::type>::type type;
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
using std::remove_const;
|
||||
using std::remove_volatile;
|
||||
using std::remove_cv;
|
||||
|
||||
#endif
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Type relationships <type_traits>
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500))
|
||||
|
||||
/// std::is_same (false specialization)
|
||||
template <typename A, typename B>
|
||||
struct is_same : false_type {};
|
||||
|
||||
/// std::is_same (true specialization)
|
||||
template <typename A>
|
||||
struct is_same<A, A> : true_type {};
|
||||
|
||||
/// Helper for std::is_base_of
|
||||
template <typename BaseT, typename DerivedT>
|
||||
struct is_base_of_helper {
|
||||
typedef char (&yes)[1];
|
||||
typedef char (&no)[2];
|
||||
|
||||
template <typename B, typename D>
|
||||
struct dummy {
|
||||
CUTLASS_HOST_DEVICE operator B*() const;
|
||||
CUTLASS_HOST_DEVICE operator D*();
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE static yes check(DerivedT*, T);
|
||||
|
||||
CUTLASS_HOST_DEVICE static no check(BaseT*, int);
|
||||
|
||||
static const bool value = sizeof(check(dummy<BaseT, DerivedT>(), int())) == sizeof(yes);
|
||||
};
|
||||
|
||||
/// std::is_base_of
|
||||
template <typename BaseT, typename DerivedT>
|
||||
struct is_base_of
|
||||
: integral_constant<bool,
|
||||
(is_base_of_helper<typename remove_cv<BaseT>::type,
|
||||
typename remove_cv<DerivedT>::type>::value) ||
|
||||
(is_same<typename remove_cv<BaseT>::type,
|
||||
typename remove_cv<DerivedT>::type>::value)> {};
|
||||
|
||||
#else
|
||||
|
||||
using std::is_same;
|
||||
using std::is_base_of;
|
||||
|
||||
#endif
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Type properties <type_traits>
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500))
|
||||
|
||||
/// std::is_volatile
|
||||
template <typename T>
|
||||
struct is_volatile : false_type {};
|
||||
template <typename T>
|
||||
struct is_volatile<volatile T> : true_type {};
|
||||
|
||||
/// Helper for std::is_pointer (false specialization)
|
||||
template <typename T>
|
||||
struct is_pointer_helper : false_type {};
|
||||
|
||||
/// Helper for std::is_pointer (true specialization)
|
||||
template <typename T>
|
||||
struct is_pointer_helper<T*> : true_type {};
|
||||
|
||||
/// std::is_pointer
|
||||
template <typename T>
|
||||
struct is_pointer : is_pointer_helper<typename remove_cv<T>::type> {};
|
||||
|
||||
/// std::is_void
|
||||
template <typename T>
|
||||
struct is_void : is_same<void, typename remove_cv<T>::type> {};
|
||||
|
||||
/// std::is_integral
|
||||
template <typename T>
|
||||
struct is_integral : false_type {};
|
||||
template <>
|
||||
struct is_integral<char> : true_type {};
|
||||
template <>
|
||||
struct is_integral<signed char> : true_type {};
|
||||
template <>
|
||||
struct is_integral<unsigned char> : true_type {};
|
||||
template <>
|
||||
struct is_integral<short> : true_type {};
|
||||
template <>
|
||||
struct is_integral<unsigned short> : true_type {};
|
||||
template <>
|
||||
struct is_integral<int> : true_type {};
|
||||
template <>
|
||||
struct is_integral<unsigned int> : true_type {};
|
||||
template <>
|
||||
struct is_integral<long> : true_type {};
|
||||
template <>
|
||||
struct is_integral<unsigned long> : true_type {};
|
||||
template <>
|
||||
struct is_integral<long long> : true_type {};
|
||||
template <>
|
||||
struct is_integral<unsigned long long> : true_type {};
|
||||
template <typename T>
|
||||
struct is_integral<volatile T> : is_integral<T> {};
|
||||
template <typename T>
|
||||
struct is_integral<const T> : is_integral<T> {};
|
||||
template <typename T>
|
||||
struct is_integral<const volatile T> : is_integral<T> {};
|
||||
|
||||
/// std::is_floating_point
|
||||
template <typename T>
|
||||
struct is_floating_point
|
||||
: integral_constant<bool,
|
||||
(is_same<float, typename remove_cv<T>::type>::value ||
|
||||
is_same<double, typename remove_cv<T>::type>::value)> {};
|
||||
|
||||
/// std::is_arithmetic
|
||||
template <typename T>
|
||||
struct is_arithmetic
|
||||
: integral_constant<bool, (is_integral<T>::value || is_floating_point<T>::value)> {};
|
||||
|
||||
/// std::is_fundamental
|
||||
template <typename T>
|
||||
struct is_fundamental
|
||||
: integral_constant<bool,
|
||||
(is_arithmetic<T>::value || is_void<T>::value ||
|
||||
is_same<nullptr_t, typename remove_cv<T>::type>::value)> {};
|
||||
|
||||
#else
|
||||
|
||||
using std::is_volatile;
|
||||
using std::is_pointer;
|
||||
using std::is_void;
|
||||
using std::is_integral;
|
||||
using std::is_floating_point;
|
||||
using std::is_arithmetic;
|
||||
using std::is_fundamental;
|
||||
|
||||
#endif
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1800)) || \
|
||||
(defined(__GNUG__) && (__GNUC__ < 5))
|
||||
|
||||
/**
|
||||
* std::is_trivially_copyable
|
||||
*
|
||||
* This implementation only evaluates true if T is fundamental or pointer
|
||||
*
|
||||
* Without help from partial template specializations provided by the user for
|
||||
* a specific class or struct, this trait will never report that the specified
|
||||
* class or struct is trivially-copyable ; this is always safe,
|
||||
* if possibly sub-optimal.
|
||||
*/
|
||||
template <typename T>
|
||||
struct is_trivially_copyable
|
||||
: integral_constant<bool, (is_fundamental<T>::value || is_pointer<T>::value)> {};
|
||||
|
||||
#else
|
||||
|
||||
using std::is_trivially_copyable;
|
||||
|
||||
#endif
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Alignment and layout utilities
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500))
|
||||
|
||||
/// std::alignment_of
|
||||
template <typename value_t>
|
||||
struct alignment_of {
|
||||
struct pad {
|
||||
value_t val;
|
||||
char byte;
|
||||
};
|
||||
|
||||
enum { value = sizeof(pad) - sizeof(value_t) };
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
template <typename value_t>
|
||||
struct alignment_of : std::alignment_of<value_t> {};
|
||||
|
||||
#endif
|
||||
|
||||
/* 16B specializations where 32-bit Win32 host compiler disagrees with device compiler */
|
||||
template <>
|
||||
struct alignment_of<int4> {
|
||||
enum { value = 16 };
|
||||
};
|
||||
template <>
|
||||
struct alignment_of<uint4> {
|
||||
enum { value = 16 };
|
||||
};
|
||||
template <>
|
||||
struct alignment_of<float4> {
|
||||
enum { value = 16 };
|
||||
};
|
||||
template <>
|
||||
struct alignment_of<long4> {
|
||||
enum { value = 16 };
|
||||
};
|
||||
template <>
|
||||
struct alignment_of<ulong4> {
|
||||
enum { value = 16 };
|
||||
};
|
||||
template <>
|
||||
struct alignment_of<longlong2> {
|
||||
enum { value = 16 };
|
||||
};
|
||||
template <>
|
||||
struct alignment_of<ulonglong2> {
|
||||
enum { value = 16 };
|
||||
};
|
||||
template <>
|
||||
struct alignment_of<double2> {
|
||||
enum { value = 16 };
|
||||
};
|
||||
template <>
|
||||
struct alignment_of<longlong4> {
|
||||
enum { value = 16 };
|
||||
};
|
||||
template <>
|
||||
struct alignment_of<ulonglong4> {
|
||||
enum { value = 16 };
|
||||
};
|
||||
template <>
|
||||
struct alignment_of<double4> {
|
||||
enum { value = 16 };
|
||||
};
|
||||
|
||||
// Specializations for volatile/const qualified types
|
||||
template <typename value_t>
|
||||
struct alignment_of<volatile value_t> : alignment_of<value_t> {};
|
||||
template <typename value_t>
|
||||
struct alignment_of<const value_t> : alignment_of<value_t> {};
|
||||
template <typename value_t>
|
||||
struct alignment_of<const volatile value_t> : alignment_of<value_t> {};
|
||||
|
||||
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1800))
|
||||
|
||||
template <size_t Align>
|
||||
struct aligned_chunk;
|
||||
template <>
|
||||
struct __align__(1) aligned_chunk<1> {
|
||||
uint8_t buff;
|
||||
};
|
||||
template <>
|
||||
struct __align__(2) aligned_chunk<2> {
|
||||
uint16_t buff;
|
||||
};
|
||||
template <>
|
||||
struct __align__(4) aligned_chunk<4> {
|
||||
uint32_t buff;
|
||||
};
|
||||
template <>
|
||||
struct __align__(8) aligned_chunk<8> {
|
||||
uint32_t buff[2];
|
||||
};
|
||||
template <>
|
||||
struct __align__(16) aligned_chunk<16> {
|
||||
uint32_t buff[4];
|
||||
};
|
||||
template <>
|
||||
struct __align__(32) aligned_chunk<32> {
|
||||
uint32_t buff[8];
|
||||
};
|
||||
template <>
|
||||
struct __align__(64) aligned_chunk<64> {
|
||||
uint32_t buff[16];
|
||||
};
|
||||
template <>
|
||||
struct __align__(128) aligned_chunk<128> {
|
||||
uint32_t buff[32];
|
||||
};
|
||||
template <>
|
||||
struct __align__(256) aligned_chunk<256> {
|
||||
uint32_t buff[64];
|
||||
};
|
||||
template <>
|
||||
struct __align__(512) aligned_chunk<512> {
|
||||
uint32_t buff[128];
|
||||
};
|
||||
template <>
|
||||
struct __align__(1024) aligned_chunk<1024> {
|
||||
uint32_t buff[256];
|
||||
};
|
||||
template <>
|
||||
struct __align__(2048) aligned_chunk<2048> {
|
||||
uint32_t buff[512];
|
||||
};
|
||||
template <>
|
||||
struct __align__(4096) aligned_chunk<4096> {
|
||||
uint32_t buff[1024];
|
||||
};
|
||||
|
||||
/// std::aligned_storage
|
||||
template <size_t Len, size_t Align>
|
||||
struct aligned_storage {
|
||||
typedef aligned_chunk<Align> type[Len / sizeof(aligned_chunk<Align>)];
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
using std::aligned_storage;
|
||||
|
||||
#endif
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
/// Default deleter
|
||||
template <typename T>
|
||||
struct default_delete {
|
||||
void operator()(T* ptr) const { delete ptr; }
|
||||
};
|
||||
|
||||
/// Partial specialization for deleting array types
|
||||
template <typename T>
|
||||
struct default_delete<T[]> {
|
||||
void operator()(T* ptr) const { delete[] ptr; }
|
||||
};
|
||||
|
||||
/// std::unique_ptr
|
||||
template <class T, class Deleter = default_delete<T> >
|
||||
class unique_ptr {
|
||||
public:
|
||||
typedef T* pointer;
|
||||
typedef T element_type;
|
||||
typedef Deleter deleter_type;
|
||||
|
||||
private:
|
||||
/// Pointer to memory
|
||||
pointer _ptr;
|
||||
|
||||
/// Deleter
|
||||
deleter_type _deleter;
|
||||
|
||||
public:
|
||||
unique_ptr() : _ptr(nullptr) {}
|
||||
unique_ptr(pointer p) : _ptr(p) {}
|
||||
|
||||
~unique_ptr() {
|
||||
if (_ptr) {
|
||||
_deleter(_ptr);
|
||||
}
|
||||
}
|
||||
/// Returns a pointer to the managed object or nullptr if no object is owned.
|
||||
pointer get() const noexcept { return _ptr; }
|
||||
|
||||
/// Releases ownership of the managed object, if any
|
||||
pointer release() noexcept {
|
||||
pointer p(_ptr);
|
||||
_ptr = nullptr;
|
||||
return p;
|
||||
}
|
||||
|
||||
/// Replaces the managed object, deleting the old object.
|
||||
void reset(pointer p = pointer()) noexcept {
|
||||
pointer old_ptr = _ptr;
|
||||
_ptr = p;
|
||||
if (old_ptr != nullptr) {
|
||||
get_deleter()(old_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
/// Swaps the managed objects with *this and another unique_ptr
|
||||
void swap(unique_ptr& other) noexcept { std::swap(_ptr, other._ptr); }
|
||||
|
||||
/// Returns the deleter object
|
||||
Deleter& get_deleter() noexcept { return _deleter; }
|
||||
|
||||
/// Returns the deleter object
|
||||
Deleter const& get_deleter() const noexcept { return _deleter; }
|
||||
|
||||
/// Checks whether an object is owned
|
||||
operator bool() const noexcept { return _ptr != nullptr; }
|
||||
|
||||
/// Dereferences the unique_ptr
|
||||
T& operator*() const { return *_ptr; }
|
||||
|
||||
/// Returns a pointer to the managed object
|
||||
pointer operator->() const noexcept { return _ptr; }
|
||||
|
||||
/// Array access to managed object
|
||||
T& operator[](size_t i) const { return _ptr[i]; }
|
||||
};
|
||||
|
||||
/// Specializes the swap algorithm
|
||||
template <typename T, typename Deleter>
|
||||
void swap(unique_ptr<T, Deleter>& lhs, unique_ptr<T, Deleter>& rhs) noexcept {
|
||||
lhs.swap(rhs);
|
||||
}
|
||||
#endif
|
||||
|
||||
}; // namespace platform
|
||||
}; // namespace cutlass
|
||||
@ -1,72 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* \brief Pure virtual base class for printable types
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* printable_t
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Pure virtual base class for printable types
|
||||
*/
|
||||
struct printable_t
|
||||
{
|
||||
/// Returns the instance as a string
|
||||
__host__ __device__ inline
|
||||
virtual char const* to_string() const = 0;
|
||||
|
||||
/// Insert the formatted instance into the output stream
|
||||
virtual void print(std::ostream& out) const = 0;
|
||||
|
||||
/// Destructor
|
||||
virtual ~printable_t() {}
|
||||
};
|
||||
|
||||
|
||||
/// Insert the formatted \p printable into the output stream
|
||||
std::ostream& operator<<(
|
||||
std::ostream& out,
|
||||
printable_t const& printable)
|
||||
{
|
||||
printable.print(out);
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
} // namespace cutlass
|
||||
@ -1,82 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* \brief Umbrella header file for utilities
|
||||
*/
|
||||
|
||||
#include "debug.h"
|
||||
#include "device_introspection.h"
|
||||
#include "io_intrinsics.h"
|
||||
#include "math.h"
|
||||
#include "nv_std.h"
|
||||
#include "printable.h"
|
||||
#include "matrix_transform.h"
|
||||
|
||||
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* int_constant
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Shorthand for nv_std::integral_constant of int32_t type
|
||||
*/
|
||||
template <int V>
|
||||
struct int_constant : nv_std::integral_constant<int32_t, V>
|
||||
{};
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Uninitialized
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \brief A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions
|
||||
*/
|
||||
template <typename T>
|
||||
struct __align__(16) uninitialized
|
||||
{
|
||||
/// Backing storage
|
||||
uint8_t storage[sizeof(T)];
|
||||
|
||||
/// Alias
|
||||
__host__ __device__ __forceinline__ T& alias()
|
||||
{
|
||||
return reinterpret_cast<T&>(*this);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
} // namespace cutlass
|
||||
229
cutlass/vector.h
Normal file
229
cutlass/vector.h
Normal file
@ -0,0 +1,229 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines a 1D vector of elements held in the registers of each thread.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
|
||||
#include <cuda_fp16.h>
|
||||
#endif
|
||||
|
||||
#include <cutlass/util/platform.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <size_t kAlignment_>
|
||||
struct AlignedStruct {};
|
||||
|
||||
template <>
|
||||
struct __align__(1) AlignedStruct<1>{};
|
||||
template <>
|
||||
struct __align__(2) AlignedStruct<2>{};
|
||||
template <>
|
||||
struct __align__(4) AlignedStruct<4>{};
|
||||
template <>
|
||||
struct __align__(8) AlignedStruct<8>{};
|
||||
template <>
|
||||
struct __align__(16) AlignedStruct<16>{};
|
||||
template <>
|
||||
struct __align__(32) AlignedStruct<32>{};
|
||||
template <>
|
||||
struct __align__(64) AlignedStruct<64>{};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int kLanes_>
|
||||
union Vector {
|
||||
/// The scalar type.
|
||||
typedef Scalar_ Scalar;
|
||||
|
||||
/// The number of elements in the vector.
|
||||
enum { kLanes = kLanes_ };
|
||||
/// The size of the vector.
|
||||
enum { kVectorSize = kLanes * (int)sizeof(Scalar) };
|
||||
/// The number of registers needed to store the vector.
|
||||
enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 };
|
||||
|
||||
// Make sure that the vector type makes sense.
|
||||
static_assert(kVectorSize <= 16, "Vector type is too large");
|
||||
|
||||
/// The aligned storage to make sure we have good alignment.
|
||||
AlignedStruct<kVectorSize> aligned_;
|
||||
/// The associated array of scalars.
|
||||
Scalar scalars[kLanes];
|
||||
/// The data in registers.
|
||||
uint32_t registers[kRegisters];
|
||||
|
||||
/// Accessor to the ith lane.
|
||||
CUTLASS_DEVICE Scalar const& operator[](uint32_t i) const { return scalars[i]; }
|
||||
/// Accessor to the ith lane.
|
||||
CUTLASS_DEVICE Scalar& operator[](uint32_t i) { return scalars[i]; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
|
||||
|
||||
template <int kLanes_>
|
||||
union Vector<half, kLanes_> {
|
||||
/// The scalar type.
|
||||
typedef half Scalar;
|
||||
|
||||
/// The number of elements in the vector.
|
||||
enum { kLanes = kLanes_ };
|
||||
/// The size of the vector.
|
||||
enum { kVectorSize = kLanes * (int)sizeof(Scalar) };
|
||||
/// The number of registers needed to store the vector.
|
||||
enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 };
|
||||
|
||||
// Make sure that the vector type makes sense.
|
||||
static_assert(kVectorSize <= size_t(16), "Vector type is too large");
|
||||
|
||||
/// The aligned storage to make sure we have good alignment.
|
||||
AlignedStruct<kVectorSize> aligned_;
|
||||
/// The associated array of scalars.
|
||||
uint16_t scalars[kLanes];
|
||||
/// The data in registers.
|
||||
uint32_t registers[kRegisters];
|
||||
|
||||
/// Accessor to the ith lane.
|
||||
CUTLASS_DEVICE Scalar const& operator[](uint32_t i) const {
|
||||
return reinterpret_cast<Scalar const&>(scalars[i]);
|
||||
}
|
||||
/// Accessor to the ith lane.
|
||||
CUTLASS_DEVICE Scalar& operator[](uint32_t i) { return reinterpret_cast<Scalar&>(scalars[i]); }
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_>
|
||||
CUTLASS_DEVICE void make_zero(Scalar_& x) {
|
||||
x = Scalar_(0);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Element_, int kLanes_ = 1>
|
||||
struct Vectorize {
|
||||
typedef Vector<Element_, kLanes_> Type;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Element_>
|
||||
struct Vectorize<Element_, 1> {
|
||||
typedef Element_ Type;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int kLanes_>
|
||||
CUTLASS_DEVICE void make_zero(Vector<Scalar_, kLanes_>& vec) {
|
||||
for (int i = 0; i < Vector<Scalar_, kLanes_>::kRegisters; ++i) {
|
||||
vec.registers[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// cutlass::Extent similar to std::extent but applicable to CUTLASS types
|
||||
//
|
||||
|
||||
/// Returns the extent of a scalar or vector
|
||||
template <typename T>
|
||||
struct Extent {
|
||||
static size_t const kValue = 1;
|
||||
};
|
||||
|
||||
/// Returns the number of lanes of a vector if need be
|
||||
template <typename T, int Lanes>
|
||||
struct Extent<Vector<T, Lanes> > {
|
||||
static size_t const kValue = Lanes;
|
||||
};
|
||||
|
||||
/// Returns the number of lanes of a vector if need be
|
||||
template <typename T, int Lanes>
|
||||
struct Extent<Vector<T, Lanes> const> {
|
||||
static size_t const kValue = Lanes;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Traits describing properties of vectors and scalar-as-vectors
|
||||
template <typename T>
|
||||
struct VectorTraits {
|
||||
/// Scalar type
|
||||
typedef T Scalar;
|
||||
|
||||
/// Number of lanes of vector
|
||||
static int const kLanes = 1;
|
||||
|
||||
/// True if the type is actually a cutlass::Vector, otherwise false
|
||||
static bool const IsVector = false;
|
||||
|
||||
/// Type that is always a vector
|
||||
typedef Vector<T, 1> Vector;
|
||||
};
|
||||
|
||||
/// Partial specialization for actual cutlass::Vector
|
||||
template <typename T, int Lanes>
|
||||
struct VectorTraits<Vector<T, Lanes> > {
|
||||
/// Scalar type
|
||||
typedef T Scalar;
|
||||
|
||||
/// Number of lanes of vector
|
||||
static int const kLanes = Lanes;
|
||||
|
||||
/// Type is actually a cutlass::Vector
|
||||
static bool const IsVector = true;
|
||||
|
||||
/// Type that is always a Vector
|
||||
typedef Vector<T, Lanes> Vector;
|
||||
};
|
||||
|
||||
/// Partial specialization for actual cutlass::Vector
|
||||
template <typename T, int Lanes>
|
||||
struct VectorTraits<Vector<T, Lanes> const> {
|
||||
/// Scalar type
|
||||
typedef T Scalar;
|
||||
|
||||
/// Number of lanes of vector
|
||||
static int const kLanes = Lanes;
|
||||
|
||||
/// Type is actually a cutlass::Vector
|
||||
static bool const IsVector = true;
|
||||
|
||||
/// Type that is always a Vector
|
||||
typedef Vector<T, Lanes> Vector;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
193
cutlass/wmma_matrix.h
Normal file
193
cutlass/wmma_matrix.h
Normal file
@ -0,0 +1,193 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Abstractions for loading and storing matrices using the CUDA WMMA API.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700)
|
||||
|
||||
// Dependent header files should use the following macro to guard all code using
|
||||
// nvcuda::wmma:: to enable compilation for CUDA Compute Capabilities < sm_70.
|
||||
// Earlier shader models not support Tensor Cores.
|
||||
#define CUTLASS_USE_WMMA_API
|
||||
|
||||
#include "stdio.h"
|
||||
|
||||
#include <crt/mma.h>
|
||||
#include <cutlass/fragment.h>
|
||||
#include <cutlass/load_store.h>
|
||||
#include <cutlass/matrix_traits.h>
|
||||
#include <cutlass/shape.h>
|
||||
#include <cutlass/vector.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Statically maps cutlass::MatrixLayout => nvcuda::wmma layout tags
|
||||
template <MatrixLayout::Kind kLayout_>
|
||||
struct WmmaLayout {
|
||||
typedef nvcuda::wmma::col_major Layout;
|
||||
};
|
||||
|
||||
/// Statically maps cutlass::MatrixLayout => nvcuda::wmma layout tags
|
||||
template <>
|
||||
struct WmmaLayout<MatrixLayout::kRowMajor> {
|
||||
typedef nvcuda::wmma::row_major Layout;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Adapter to nvcuda::wmma fragment load and store operations
|
||||
template <GemmOperand::Kind kOperand_,
|
||||
MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename WmmaShape_>
|
||||
struct WmmaMatrix {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Adapter to nvcuda::wmma fragment accessors for A operand
|
||||
template <MatrixLayout::Kind kLayout_, typename Scalar_, typename WmmaShape_>
|
||||
struct WmmaMatrix<GemmOperand::kA, kLayout_, Scalar_, WmmaShape_>
|
||||
: public nvcuda::wmma::fragment<
|
||||
/// The nvcuda::wmma operand name.
|
||||
nvcuda::wmma::matrix_a,
|
||||
/// The dimensions.
|
||||
WmmaShape_::kW,
|
||||
WmmaShape_::kH,
|
||||
WmmaShape_::kD,
|
||||
/// The scalar.
|
||||
Scalar_,
|
||||
/// The layout.
|
||||
typename WmmaLayout<kLayout_>::Layout> {
|
||||
/// This type.
|
||||
typedef WmmaMatrix<GemmOperand::kA, kLayout_, Scalar_, WmmaShape_> This_;
|
||||
|
||||
/// Fill-in the element.
|
||||
CUTLASS_DEVICE This_& operator=(Scalar_ const& x) {
|
||||
nvcuda::wmma::fill_fragment(*this, x);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Load from memory.
|
||||
CUTLASS_DEVICE void load(Scalar_ const* pointer, int const stride) {
|
||||
nvcuda::wmma::load_matrix_sync(*this, pointer, stride);
|
||||
}
|
||||
|
||||
/// Store to memory.
|
||||
CUTLASS_DEVICE void store(Scalar_* pointer, int const stride) const {
|
||||
nvcuda::wmma::store_matrix_sync(pointer, *this, stride);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Adapter to nvcuda::wmma fragment accessors for B operand
|
||||
template <MatrixLayout::Kind kLayout_, typename Scalar_, typename WmmaShape_>
|
||||
struct WmmaMatrix<GemmOperand::kB, kLayout_, Scalar_, WmmaShape_>
|
||||
: public nvcuda::wmma::fragment<
|
||||
/// The nvcuda::wmma operand name.
|
||||
nvcuda::wmma::matrix_b,
|
||||
/// The dimensions.
|
||||
WmmaShape_::kW,
|
||||
WmmaShape_::kH,
|
||||
WmmaShape_::kD,
|
||||
/// The scalar.
|
||||
Scalar_,
|
||||
/// The layout.
|
||||
typename WmmaLayout<kLayout_>::Layout> {
|
||||
/// This type.
|
||||
typedef WmmaMatrix<GemmOperand::kB, kLayout_, Scalar_, WmmaShape_> This_;
|
||||
|
||||
/// Fill-in the element.
|
||||
CUTLASS_DEVICE This_& operator=(Scalar_ const& x) {
|
||||
nvcuda::wmma::fill_fragment(*this, x);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Load from memory.
|
||||
CUTLASS_DEVICE void load(Scalar_ const* pointer, int const stride) {
|
||||
nvcuda::wmma::load_matrix_sync(*this, pointer, stride);
|
||||
}
|
||||
|
||||
/// Store to memory.
|
||||
CUTLASS_DEVICE void store(Scalar_* pointer, int const stride) const {
|
||||
nvcuda::wmma::store_matrix_sync(pointer, *this, stride);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Adapter to nvcuda::wmma fragment accessors for C operand
|
||||
template <MatrixLayout::Kind kLayout_, typename Scalar_, typename WmmaShape_>
|
||||
struct WmmaMatrix<GemmOperand::kC, kLayout_, Scalar_, WmmaShape_>
|
||||
: public nvcuda::wmma::fragment<
|
||||
/// The nvcuda::wmma operand name.
|
||||
nvcuda::wmma::accumulator,
|
||||
/// The dimensions.
|
||||
WmmaShape_::kW,
|
||||
WmmaShape_::kH,
|
||||
WmmaShape_::kD,
|
||||
/// The scalar.
|
||||
Scalar_> {
|
||||
/// This type.
|
||||
typedef WmmaMatrix<GemmOperand::kC, kLayout_, Scalar_, WmmaShape_> This_;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = kLayout_;
|
||||
|
||||
/// Fill-in the element.
|
||||
CUTLASS_DEVICE This_& operator=(Scalar_ const& x) {
|
||||
nvcuda::wmma::fill_fragment(*this, x);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Load from memory.
|
||||
CUTLASS_DEVICE void load(Scalar_ const* pointer, int const stride) {
|
||||
bool const kIsRowMajor = kLayout == MatrixLayout::kRowMajor;
|
||||
nvcuda::wmma::load_matrix_sync(
|
||||
*this,
|
||||
pointer,
|
||||
stride,
|
||||
kIsRowMajor ? nvcuda::wmma::mem_row_major : nvcuda::wmma::mem_col_major);
|
||||
}
|
||||
|
||||
/// Store to memory.
|
||||
CUTLASS_DEVICE void store(Scalar_* pointer, int const stride) const {
|
||||
bool const kIsRowMajor = kLayout == MatrixLayout::kRowMajor;
|
||||
nvcuda::wmma::store_matrix_sync(
|
||||
pointer,
|
||||
*this,
|
||||
stride,
|
||||
kIsRowMajor ? nvcuda::wmma::mem_row_major : nvcuda::wmma::mem_col_major);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
#endif // defined CUTLASS_USE_WMMA_API
|
||||
7
cutlass_test/.gitignore
vendored
7
cutlass_test/.gitignore
vendored
@ -1,7 +0,0 @@
|
||||
/bin/
|
||||
/gemm-GPU.csv
|
||||
/gemm-REF.csv
|
||||
/a.csv
|
||||
/b.csv
|
||||
/gp100_schmoo/
|
||||
/ignore/
|
||||
@ -1,180 +0,0 @@
|
||||
#/******************************************************************************
|
||||
# * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
# *
|
||||
# * Redistribution and use in source and binary forms, with or without
|
||||
# * modification, are permitted provided that the following conditions are met:
|
||||
# * * Redistributions of source code must retain the above copyright
|
||||
# * notice, this list of conditions and the following disclaimer.
|
||||
# * * Redistributions in binary form must reproduce the above copyright
|
||||
# * notice, this list of conditions and the following disclaimer in the
|
||||
# * documentation and/or other materials provided with the distribution.
|
||||
# * * Neither the name of the NVIDIA CORPORATION nor the
|
||||
# * names of its contributors may be used to endorse or promote products
|
||||
# * derived from this software without specific prior written permission.
|
||||
# *
|
||||
# * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
# * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
# * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
# * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
# * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
# * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
# * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
# * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
# *
|
||||
# ******************************************************************************/
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
#
|
||||
# Makefile usage
|
||||
#
|
||||
# make <target> sm=<XX[,YY,ZZ,..]> [transpose=<nn*|nt|tn|tt>] [verbose=<0*|1>] [keep=<0*|1>]
|
||||
#
|
||||
# * : default
|
||||
#
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
TEST_DIR := $(dir $(lastword $(MAKEFILE_LIST)))
|
||||
|
||||
include ../common.mk
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Commandline Options
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
ifdef transpose
|
||||
TRANSPOSE := $(transpose)
|
||||
else
|
||||
TRANSPOSE := nn
|
||||
endif
|
||||
|
||||
# If defined, GEMMs only compiled with specified alignment restrictions on A and B
|
||||
# matrices. Otherwise, kernels are compiled for all feasible alignment options, and
|
||||
# the appropriate kernel is selected.
|
||||
ifdef alignment
|
||||
DEFINES += -DGEMM_ALIGNMENT=$(alignment)
|
||||
endif
|
||||
|
||||
# If defined as false, ragged handling can be disabled.
|
||||
ifdef ragged
|
||||
DEFINES += -DGEMM_RAGGED=$(ragged)
|
||||
endif
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Include and Library paths
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
INC += -I$(TEST_DIR)
|
||||
INC += -I$(BASE_DIR)
|
||||
|
||||
LIBS += -lcublas
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Preprocessor definitions
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
ifeq (nt, $(TRANSPOSE))
|
||||
DEFINES += -DTRANSPOSE_B
|
||||
else ifeq (tn, $(TRANSPOSE))
|
||||
DEFINES += -DTRANSPOSE_A
|
||||
|
||||
else ifeq (tt, $(TRANSPOSE))
|
||||
DEFINES += -DTRANSPOSE_A
|
||||
DEFINES += -DTRANSPOSE_B
|
||||
endif
|
||||
|
||||
NVCCFLAGS += -std=c++11
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Dependency Lists
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
DEPS := $(call rwildcard, $(BASE_DIR),*.h) \
|
||||
$(call rwildcard, $(BASE_DIR)cgl,*.h) \
|
||||
$(BASE_DIR)common.mk \
|
||||
$(TEST_DIR)Makefile
|
||||
|
||||
|
||||
ALL := sgemm \
|
||||
dgemm \
|
||||
hgemm \
|
||||
igemm
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# make default
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
default:
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# make clean
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
clean :
|
||||
rm -f bin/*
|
||||
rm -f *.i* *.cubin *.cu.c *.cudafe* *.fatbin.c *.ptx *.hash *.cu.cpp *.o *.obj* *dlink.* *.res *.fatbin *.module_id
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# make all
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
all : $(ALL)
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# make sgemm
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
sgemm: bin/sgemm_$(TRANSPOSE)_$(BIN_SUFFIX)
|
||||
|
||||
bin/sgemm_$(TRANSPOSE)_$(BIN_SUFFIX) : gemm.cu $(DEPS)
|
||||
mkdir -p bin
|
||||
$(NVCC) -DTEST_SGEMM $(DEFINES) $(SM_TARGETS) -o $@ gemm.cu $(NVCCFLAGS) $(CPU_ARCH) $(INC) $(LIBINC) $(LIBS)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# make dgemm
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
dgemm: bin/dgemm_$(TRANSPOSE)_$(BIN_SUFFIX)
|
||||
|
||||
bin/dgemm_$(TRANSPOSE)_$(BIN_SUFFIX) : gemm.cu $(DEPS)
|
||||
mkdir -p bin
|
||||
$(NVCC) -DTEST_DGEMM $(DEFINES) $(SM_TARGETS) -o $@ gemm.cu $(NVCCFLAGS) $(CPU_ARCH) $(INC) $(LIBINC) $(LIBS)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# make hgemm
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
hgemm: bin/hgemm_$(TRANSPOSE)_$(BIN_SUFFIX)
|
||||
|
||||
bin/hgemm_$(TRANSPOSE)_$(BIN_SUFFIX) : gemm.cu $(DEPS)
|
||||
mkdir -p bin
|
||||
$(NVCC) -DTEST_HGEMM $(DEFINES) $(SM_TARGETS) -o $@ gemm.cu $(NVCCFLAGS) $(CPU_ARCH) $(INC) $(LIBINC) $(LIBS)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# make igemm
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
igemm: bin/igemm_$(TRANSPOSE)_$(BIN_SUFFIX)
|
||||
|
||||
bin/igemm_$(TRANSPOSE)_$(BIN_SUFFIX) : gemm.cu $(DEPS)
|
||||
mkdir -p bin
|
||||
$(NVCC) -DTEST_IGEMM $(DEFINES) $(SM_TARGETS) -o $@ gemm.cu $(NVCCFLAGS) $(CPU_ARCH) $(INC) $(LIBINC) $(LIBS)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# make wgemm
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
wgemm: bin/wgemm_$(TRANSPOSE)_$(BIN_SUFFIX)
|
||||
|
||||
bin/wgemm_$(TRANSPOSE)_$(BIN_SUFFIX) : gemm.cu $(DEPS)
|
||||
mkdir -p bin
|
||||
$(NVCC) -DTEST_WGEMM -DWMMA $(DEFINES) $(SM_TARGETS) -o $@ gemm.cu $(NVCCFLAGS) $(CPU_ARCH) $(INC) $(LIBINC) $(LIBS)
|
||||
|
||||
@ -1,300 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* C++ interface for dispatching CUBLAS GEMM calls
|
||||
*/
|
||||
|
||||
#include <cublas_v2.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* cuBLAS dispatch entrypoints
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Dispatch cuBLAS igemm
|
||||
*/
|
||||
cublasStatus_t cublas_gemm_dispatch(
|
||||
cublasHandle_t cublas_handle, ///< CUBLAS handle
|
||||
cublasOperation_t transform_a, ///< Transform op(A) that is non- or (conj.) transpose.
|
||||
cublasOperation_t transform_b, ///< Transform op(B) that is non- or (conj.) transpose.
|
||||
int m, ///< Height in rows of op(A) and C
|
||||
int n, ///< Width in columns of op(B) and C
|
||||
int k, ///< Width in columns of op(A) and height in rows of op(B)
|
||||
int32_t alpha, ///< Scalar used for multiplicands
|
||||
int8_t *d_a, ///< Device pointer to matrix A array values
|
||||
int8_t *d_b, ///< Device pointer to matrix B array values
|
||||
int32_t beta, ///< Scalar used for addend
|
||||
int32_t *d_c, ///< Device pointer to matrix C array values
|
||||
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
|
||||
bool debug_synchronous = false) ///< Whether or not to synchronize the stream after every kernel launch to check for errors.
|
||||
{
|
||||
return cublasGemmEx(
|
||||
cublas_handle,
|
||||
transform_a,
|
||||
transform_b,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
(void*) &alpha,
|
||||
(void*) d_a,
|
||||
CUDA_R_8I,
|
||||
(transform_a == CUBLAS_OP_N) ? m : k,
|
||||
(void*) d_b,
|
||||
CUDA_R_8I,
|
||||
(transform_b == CUBLAS_OP_N) ? k : n,
|
||||
(void*) &beta,
|
||||
(void*) d_c,
|
||||
CUDA_R_32I,
|
||||
m,
|
||||
CUDA_R_32I,
|
||||
CUBLAS_GEMM_DFALT);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Dispatch cuBLAS hgemm
|
||||
*/
|
||||
cublasStatus_t cublas_gemm_dispatch(
|
||||
cublasHandle_t cublas_handle, ///< CUBLAS handle
|
||||
cublasOperation_t transform_a, ///< Transform op(A) that is non- or (conj.) transpose.
|
||||
cublasOperation_t transform_b, ///< Transform op(B) that is non- or (conj.) transpose.
|
||||
int m, ///< Height in rows of op(A) and C
|
||||
int n, ///< Width in columns of op(B) and C
|
||||
int k, ///< Width in columns of op(A) and height in rows of op(B)
|
||||
__half alpha, ///< Scalar used for multiplicands
|
||||
__half *d_a, ///< Device pointer to matrix A array values
|
||||
__half *d_b, ///< Device pointer to matrix B array values
|
||||
__half beta, ///< Scalar used for addend
|
||||
__half *d_c, ///< Device pointer to matrix C array values
|
||||
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
|
||||
bool debug_synchronous = false) ///< Whether or not to synchronize the stream after every kernel launch to check for errors.
|
||||
{
|
||||
return cublasHgemm(
|
||||
cublas_handle, transform_a, transform_b,
|
||||
m, n, k,
|
||||
&alpha,
|
||||
d_a,
|
||||
(transform_a == CUBLAS_OP_N) ? m : k,
|
||||
d_b,
|
||||
(transform_b == CUBLAS_OP_N) ? k : n,
|
||||
&beta,
|
||||
d_c,
|
||||
m);
|
||||
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Dispatch cuBLAS sgemm
|
||||
*/
|
||||
cublasStatus_t cublas_gemm_dispatch(
|
||||
cublasHandle_t cublas_handle, ///< CUBLAS handle
|
||||
cublasOperation_t transform_a, ///< Transform op(A) that is non- or (conj.) transpose.
|
||||
cublasOperation_t transform_b, ///< Transform op(B) that is non- or (conj.) transpose.
|
||||
int m, ///< Height in rows of op(A) and C
|
||||
int n, ///< Width in columns of op(B) and C
|
||||
int k, ///< Width in columns of op(A) and height in rows of op(B)
|
||||
float alpha, ///< Scalar used for multiplicands
|
||||
float *d_a, ///< Device pointer to matrix A array values
|
||||
float *d_b, ///< Device pointer to matrix B array values
|
||||
float beta, ///< Scalar used for addend
|
||||
float *d_c, ///< Device pointer to matrix C array values
|
||||
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
|
||||
bool debug_synchronous = false) ///< Whether or not to synchronize the stream after every kernel launch to check for errors.
|
||||
{
|
||||
return cublasSgemm(
|
||||
cublas_handle, transform_a, transform_b,
|
||||
m, n, k,
|
||||
&alpha,
|
||||
d_a,
|
||||
(transform_a == CUBLAS_OP_N) ? m : k,
|
||||
d_b,
|
||||
(transform_b == CUBLAS_OP_N) ? k : n,
|
||||
&beta,
|
||||
d_c,
|
||||
m);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Dispatch cuBLAS dgemm
|
||||
*/
|
||||
cublasStatus_t cublas_gemm_dispatch(
|
||||
cublasHandle_t cublas_handle, ///< CUBLAS handle
|
||||
cublasOperation_t transform_a, ///< Transform op(A) that is non- or (conj.) transpose.
|
||||
cublasOperation_t transform_b, ///< Transform op(B) that is non- or (conj.) transpose.
|
||||
int m, ///< Height in rows of op(A) and C
|
||||
int n, ///< Width in columns of op(B) and C
|
||||
int k, ///< Width in columns of op(A) and height in rows of op(B)
|
||||
double alpha, ///< Scalar used for multiplicands
|
||||
double *d_a, ///< Device pointer to matrix A array values
|
||||
double *d_b, ///< Device pointer to matrix B array values
|
||||
double beta, ///< Scalar used for addend
|
||||
double *d_c, ///< Device pointer to matrix C array values
|
||||
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
|
||||
bool debug_synchronous = false) ///< Whether or not to synchronize the stream after every kernel launch to check for errors.
|
||||
{
|
||||
return cublasDgemm(
|
||||
cublas_handle, transform_a, transform_b,
|
||||
m, n, k,
|
||||
&alpha,
|
||||
d_a, (transform_a == CUBLAS_OP_N) ? m : k,
|
||||
d_b, (transform_b == CUBLAS_OP_N) ? k : n,
|
||||
&beta,
|
||||
d_c, m);
|
||||
}
|
||||
|
||||
/**
|
||||
* Dispatch cuBLAS Tensor Cores GEMM
|
||||
*/
|
||||
cublasStatus_t cublas_gemm_dispatch(
|
||||
cublasHandle_t cublas_handle, ///< CUBLAS handle
|
||||
cublasOperation_t transform_a, ///< Transform op(A) that is non- or (conj.) transpose.
|
||||
cublasOperation_t transform_b, ///< Transform op(B) that is non- or (conj.) transpose.
|
||||
int m, ///< Height in rows of op(A) and C
|
||||
int n, ///< Width in columns of op(B) and C
|
||||
int k, ///< Width in columns of op(A) and height in rows of op(B)
|
||||
float alpha, ///< Scalar used for multiplicands
|
||||
half *d_a, ///< Device pointer to matrix A array values
|
||||
half *d_b, ///< Device pointer to matrix B array values
|
||||
float beta, ///< Scalar used for addend
|
||||
float *d_c, ///< Device pointer to matrix C array values
|
||||
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
|
||||
bool debug_synchronous = false) ///< Whether or not to synchronize the stream after every kernel launch to check for errors.
|
||||
{
|
||||
return cublasGemmEx(
|
||||
cublas_handle,
|
||||
transform_a,
|
||||
transform_b,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
(void*) &alpha,
|
||||
(void*) d_a,
|
||||
CUDA_R_16F,
|
||||
(transform_a == CUBLAS_OP_N) ? m : k,
|
||||
(void*) d_b,
|
||||
CUDA_R_16F,
|
||||
(transform_b == CUBLAS_OP_N) ? k : n,
|
||||
(void*) &beta,
|
||||
(void*) d_c,
|
||||
CUDA_R_32F,
|
||||
m,
|
||||
CUDA_R_32F,
|
||||
CUBLAS_GEMM_DFALT_TENSOR_OP);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Uses cuBLAS to compute gemm on device matrices (unspecialized)
|
||||
*/
|
||||
template <
|
||||
gemm::tiling_strategy::kind_t _TilingStrategy, ///< Tile-sizing classification category
|
||||
math_operation_class_t _math_op,
|
||||
matrix_transform_t::kind_t _TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t _TransformB, ///< Transformation op for matrix B
|
||||
typename _value, ///< Multiplicand value type (matrices A and B)
|
||||
typename _accum ///< Accumulator value type (matrix C and scalars)
|
||||
>
|
||||
struct cublas_gemm
|
||||
{
|
||||
//
|
||||
// Type alias definitions
|
||||
//
|
||||
|
||||
static const gemm::tiling_strategy::kind_t TilingStrategy = _TilingStrategy;
|
||||
static const math_operation_class_t math_op = _math_op;
|
||||
static const matrix_transform_t::kind_t TransformA = _TransformA;
|
||||
static const matrix_transform_t::kind_t TransformB = _TransformB;
|
||||
|
||||
using value_t = _value;
|
||||
using accum_t = _accum;
|
||||
|
||||
/// Launches a GEMM
|
||||
gemm::launch_configuration operator()(
|
||||
cublasHandle_t cublas_handle, ///< CUBLAS handle
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
value_t *A, ///< A matrix
|
||||
value_t *B, ///< B matrix
|
||||
accum_t *C, ///< C matrix
|
||||
accum_t alpha, ///< Scalar used for multiplicands
|
||||
accum_t beta, ///< Scalar used for addend
|
||||
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
|
||||
bool debug_synchronous = false) ///< Whether or not to synchronize the stream after every kernel launch to check for errors.
|
||||
{
|
||||
cublasStatus_t cublas_error = cublas_gemm_dispatch(
|
||||
cublas_handle,
|
||||
(cublasOperation_t) TransformA,
|
||||
(cublasOperation_t) TransformB,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
A,
|
||||
B,
|
||||
beta,
|
||||
C,
|
||||
stream,
|
||||
debug_synchronous);
|
||||
|
||||
cudaError_t error;
|
||||
if (cublas_error != CUBLAS_STATUS_SUCCESS)
|
||||
{
|
||||
if (cublas_error == CUBLAS_STATUS_NOT_SUPPORTED) {
|
||||
return gemm::launch_configuration(cudaErrorInvalidValue);
|
||||
}
|
||||
|
||||
error = cudaGetLastError();
|
||||
if (error == cudaSuccess) {
|
||||
return gemm::launch_configuration(cudaErrorUnknown);
|
||||
}
|
||||
return error;
|
||||
}
|
||||
|
||||
// Check for failure to launch
|
||||
if (CUDA_PERROR_DEBUG(error = cudaPeekAtLastError()))
|
||||
return gemm::launch_configuration(error);
|
||||
|
||||
// Sync the stream if specified to flush runtime errors
|
||||
if (debug_synchronous && (CUDA_PERROR_DEBUG(error = cudaStreamSynchronize(stream))))
|
||||
return gemm::launch_configuration(error);
|
||||
|
||||
return gemm::launch_configuration(error);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace cutlass
|
||||
@ -1,261 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file Dispatch routines for CUTLASS GEMM kernels
|
||||
*/
|
||||
|
||||
// CUDA includes
|
||||
#include <cublas_v2.h>
|
||||
|
||||
// Cutlass GEMM API
|
||||
#include <cutlass/util/util.h>
|
||||
#include <cutlass/gemm/dispatch.h>
|
||||
#include <cutlass/gemm/epilogue_function.h>
|
||||
|
||||
// Test utilities
|
||||
#include "util/type_conversion.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Cutlass dispatch entrypoints
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Compile-time overrides for alignment and ragged handling.
|
||||
//
|
||||
|
||||
// If zero, all feasible alignment options are supported.
|
||||
#ifndef GEMM_ALIGNMENT
|
||||
#define GEMM_ALIGNMENT 0
|
||||
#endif
|
||||
|
||||
// If true, kernels are compiled with ragged handling enabled.
|
||||
#ifndef GEMM_RAGGED
|
||||
#define GEMM_RAGGED true
|
||||
#endif
|
||||
|
||||
//
|
||||
// Dispatch logic given problem size specialization, math operation class, layout
|
||||
// and type of operands, and epilogue operation.
|
||||
//
|
||||
|
||||
/**
|
||||
* Cutlass GEMM dispatch
|
||||
*/
|
||||
template <
|
||||
gemm::tiling_strategy::kind_t _TilingStrategy, ///< Tile-sizing classification category
|
||||
math_operation_class_t _math_op, // Indicates
|
||||
matrix_transform_t::kind_t _TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t _TransformB, ///< Transformation op for matrix B
|
||||
typename _value, ///< Multiplicand value type (matrices A and B)
|
||||
typename _accum, ///< Accumulator value type (matrix C and scalars)
|
||||
typename _epilogue_op_t ///< Epilogue opeartion to update matrix C
|
||||
= gemm::blas_scaled_epilogue<_accum, _accum, _accum>
|
||||
>
|
||||
struct cutlass_gemm_dispatch
|
||||
{
|
||||
//
|
||||
// Type alias definitions
|
||||
//
|
||||
|
||||
static const gemm::tiling_strategy::kind_t TilingStrategy = _TilingStrategy;
|
||||
static const math_operation_class_t math_op = _math_op;
|
||||
static const matrix_transform_t::kind_t TransformA = _TransformA;
|
||||
static const matrix_transform_t::kind_t TransformB = _TransformB;
|
||||
|
||||
using value_t = _value;
|
||||
using accum_t = _accum;
|
||||
using epilogue_op_t = _epilogue_op_t;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Returns leading dimension for A matrix operand
|
||||
int leading_dim_a(int m, int k) const
|
||||
{
|
||||
return (TransformA == matrix_transform_t::NonTranspose ? m : k);
|
||||
}
|
||||
|
||||
/// Returns leading dimension for B matrix operand
|
||||
int leading_dim_b(int k, int n) const
|
||||
{
|
||||
return (TransformB == matrix_transform_t::NonTranspose ? k : n);
|
||||
}
|
||||
|
||||
/// Launches a GEMM
|
||||
template <int operand_alignment, int accumulator_alignment>
|
||||
gemm::launch_configuration launch(
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
epilogue_op_t epilogue_op,
|
||||
value_t *A,
|
||||
value_t *B,
|
||||
accum_t *C,
|
||||
cudaStream_t stream = 0,
|
||||
bool debug_synchronous = false)
|
||||
{
|
||||
return gemm::device_gemm<
|
||||
TilingStrategy,
|
||||
math_op,
|
||||
TransformA,
|
||||
operand_alignment,
|
||||
TransformB,
|
||||
operand_alignment,
|
||||
value_t,
|
||||
accum_t,
|
||||
epilogue_op_t,
|
||||
accumulator_alignment>
|
||||
(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
epilogue_op,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
stream,
|
||||
debug_synchronous);
|
||||
}
|
||||
|
||||
/// Dispatches a CUTLASS GEMM
|
||||
gemm::launch_configuration operator()(
|
||||
cublasHandle_t handle, ///< CUBLAS handle
|
||||
int m, ///< Rows of GEMM problem
|
||||
int n, ///< Columns of GEMM problem
|
||||
int k, ///< Inner dimension of GEMM problem
|
||||
value_t *A, ///< A matrix
|
||||
value_t *B, ///< B matrix
|
||||
accum_t *C, ///< C matrix
|
||||
accum_t alpha, ///< Scalar used for multiplicands
|
||||
accum_t beta, ///< Scalar used for addend
|
||||
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within.
|
||||
bool debug_synchronous = false) ///< Whether or not to synchronize the stream
|
||||
/// after every kernel launch to check for errors.
|
||||
{
|
||||
|
||||
// Forces kernel selection to choose specific alignment (in bytes)
|
||||
int const force_operand_alignment = GEMM_ALIGNMENT;
|
||||
|
||||
// Problem size must be multiple of the smallest vector load size
|
||||
typedef value_t operand_load_t;
|
||||
int const accumulator_alignment = sizeof(accum_t);
|
||||
|
||||
int const lda = leading_dim_a(m, k);
|
||||
int const ldb = leading_dim_b(k, n);
|
||||
|
||||
epilogue_op_t epilogue(alpha, beta);
|
||||
|
||||
// TODO: opportunity for metaprogramming loop
|
||||
|
||||
// Prefer the largest granularity of vector load that is compatible with
|
||||
// problem size and data alignment.
|
||||
if ((!force_operand_alignment || force_operand_alignment == 16) &&
|
||||
!((sizeof(operand_load_t) * lda) % 16) &&
|
||||
!((sizeof(operand_load_t) * ldb) % 16))
|
||||
{
|
||||
#if !(GEMM_ALIGNMENT) || (GEMM_ALIGNMENT == 16)
|
||||
return launch<__NV_STD_MAX(16, sizeof(value_t)), accumulator_alignment>(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
epilogue,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
stream,
|
||||
debug_synchronous);
|
||||
#endif
|
||||
}
|
||||
else if ((!force_operand_alignment || force_operand_alignment == 8) &&
|
||||
!((sizeof(operand_load_t) * lda) % 8) &&
|
||||
!((sizeof(operand_load_t) * ldb) % 8))
|
||||
{
|
||||
#if !(GEMM_ALIGNMENT) || (GEMM_ALIGNMENT == 8)
|
||||
return launch<__NV_STD_MAX(8, sizeof(value_t)), accumulator_alignment>(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
epilogue,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
stream,
|
||||
debug_synchronous);
|
||||
#endif
|
||||
}
|
||||
else if ((!force_operand_alignment || force_operand_alignment == 4) &&
|
||||
!((sizeof(operand_load_t) * lda) % 4) &&
|
||||
!((sizeof(operand_load_t) * ldb) % 4))
|
||||
{
|
||||
#if !(GEMM_ALIGNMENT) || (GEMM_ALIGNMENT == 4)
|
||||
return launch<__NV_STD_MAX(4, sizeof(value_t)), accumulator_alignment>(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
epilogue,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
stream,
|
||||
debug_synchronous);
|
||||
#endif
|
||||
}
|
||||
else if ((!force_operand_alignment || force_operand_alignment == 2) &&
|
||||
!((sizeof(operand_load_t) * lda) % 2) &&
|
||||
!((sizeof(operand_load_t) * ldb) % 2))
|
||||
{
|
||||
// 16-bit alignment only supported for HGEMM
|
||||
#if defined(TEST_HGEMM) || defined(TEST_WGEMM)
|
||||
#if !(GEMM_ALIGNMENT) || (GEMM_ALIGNMENT == 2)
|
||||
return launch<__NV_STD_MAX(2, sizeof(value_t)), accumulator_alignment>(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
epilogue,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
stream,
|
||||
debug_synchronous);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
return gemm::launch_configuration(cudaErrorInvalidValue);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace cutlass
|
||||
@ -1,572 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* \file gemm.cu
|
||||
* GEMM test driver
|
||||
*
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <typeinfo>
|
||||
#include <random>
|
||||
#include <stdint.h>
|
||||
|
||||
// CUBLAS GEMM API
|
||||
#include <cublas_v2.h>
|
||||
|
||||
// Set Cutlass debug macro to enable console printing of library errors
|
||||
#define DEBUG
|
||||
|
||||
#if defined(WMMA)
|
||||
// Conditionally include WMMA headers (CUDA 9 Preview Feature)
|
||||
#include <mma.h>
|
||||
#endif
|
||||
|
||||
// Cutlass GEMM API
|
||||
#include <cutlass/util/util.h>
|
||||
#include <cutlass/gemm/dispatch.h>
|
||||
#include <cutlass/gemm/epilogue_function.h>
|
||||
|
||||
// Test utilities
|
||||
#include "util/command_line.h"
|
||||
#include "util/half.h"
|
||||
#include "util/matrix.h"
|
||||
#include "util/timer.h"
|
||||
#include "util/type_conversion.h"
|
||||
|
||||
// Dispatch routines to CUBLAS and CUTLASS
|
||||
#include "cublas_dispatch.h"
|
||||
#include "cutlass_dispatch.h"
|
||||
|
||||
/******************************************************************************
|
||||
* Globals, constants and typedefs
|
||||
******************************************************************************/
|
||||
|
||||
using namespace cutlass;
|
||||
|
||||
/// CUBLAS handle
|
||||
cublasHandle_t g_cublas_handle;
|
||||
|
||||
/// The device-id of the current device
|
||||
int g_device_id = -1;
|
||||
|
||||
/// The number of timing iterations to invoke
|
||||
int g_timing_iterations = -1;
|
||||
|
||||
/// The number of randomly-sized problems to schmoo
|
||||
int g_schmoo = 0;
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Number generation
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Simple low-integer generator
|
||||
*/
|
||||
struct simple_gen
|
||||
{
|
||||
std::default_random_engine generator;
|
||||
std::uniform_int_distribution<int> distribution;
|
||||
|
||||
/// Constructor
|
||||
simple_gen(int max) : distribution(max * -1, max)
|
||||
{}
|
||||
|
||||
/// Functor
|
||||
int operator()()
|
||||
{
|
||||
return distribution(generator);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Test execution
|
||||
******************************************************************************/
|
||||
|
||||
|
||||
/**
|
||||
* Compute C = (alpha * A * B) + (beta * C)
|
||||
*/
|
||||
template <
|
||||
typename test_func_t, ///< Test function type
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B
|
||||
typename value_t, ///< Multiplicand value type (matrices A and B)
|
||||
typename accum_t> ///< Accumulator value type (matrix C and scalars)
|
||||
bool test(
|
||||
int m, ///< Height of C in rows
|
||||
int n, ///< Width of C in columns
|
||||
int k, ///< Width (height) of A (B)
|
||||
accum_t alpha, ///< Multiplicand scalar
|
||||
accum_t beta) ///< Addend scalar
|
||||
{
|
||||
cudaStream_t stream = 0;
|
||||
|
||||
//
|
||||
// Initialize matrices
|
||||
//
|
||||
|
||||
matrix<value_t> A(
|
||||
(TransformA == matrix_transform_t::NonTranspose) ? m : k,
|
||||
(TransformA == matrix_transform_t::NonTranspose) ? k : m);
|
||||
|
||||
matrix<value_t> B(
|
||||
(TransformB == matrix_transform_t::NonTranspose) ? k : n,
|
||||
(TransformB == matrix_transform_t::NonTranspose) ? n : k);
|
||||
|
||||
matrix<accum_t> C(m, n);
|
||||
|
||||
// initialized matrices with small values precisely representable as integers
|
||||
simple_gen a_gen(3);
|
||||
simple_gen b_gen(5);
|
||||
A.fill_random(a_gen);
|
||||
B.fill_random(b_gen);
|
||||
C.fill_ramp(0,0);
|
||||
|
||||
// // Alternatively, initialize with procedural values to simplify debugging incorrect results
|
||||
// A.fill_ramp(1,2);
|
||||
// B.fill_ramp(1,1);
|
||||
|
||||
// Sync to device
|
||||
A.sync_device();
|
||||
B.sync_device();
|
||||
C.sync_device();
|
||||
|
||||
CUDA_PERROR(cudaPeekAtLastError());
|
||||
CUDA_PERROR(cudaDeviceSynchronize());
|
||||
|
||||
//
|
||||
// Run test once with debug-synchronous enabled and check result
|
||||
//
|
||||
|
||||
if (!g_schmoo) printf("\n");
|
||||
|
||||
test_func_t test_func;
|
||||
|
||||
C.fill_ramp(0, 0);
|
||||
C.sync_device();
|
||||
|
||||
cudaError_t error = test_func(
|
||||
g_cublas_handle,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
A.d_data(),
|
||||
B.d_data(),
|
||||
C.d_data(),
|
||||
alpha,
|
||||
beta,
|
||||
stream,
|
||||
!g_schmoo).result;
|
||||
|
||||
bool not_applicable = (error == cudaErrorInvalidValue);
|
||||
bool is_failed = false;
|
||||
if (not_applicable)
|
||||
{
|
||||
printf(", NA");
|
||||
}
|
||||
else
|
||||
{
|
||||
CUDA_PERROR(error);
|
||||
|
||||
// Compute reference check if wont take too long on CPU
|
||||
if ((!g_schmoo) && (m * n <= 1024 * 1024))
|
||||
{
|
||||
matrix<accum_t> ref_C(m, n);
|
||||
ref_C.fill_ramp(0, 0);
|
||||
ref_C.gemm(TransformA, TransformB, alpha, A, B, beta);
|
||||
C.sync_host();
|
||||
|
||||
is_failed = (C != ref_C);
|
||||
|
||||
if (!g_schmoo)
|
||||
{
|
||||
if (is_failed)
|
||||
{
|
||||
printf("FAIL, ");
|
||||
std::ofstream file_a("a.csv");
|
||||
A.write_matrix(file_a);
|
||||
std::ofstream file_b("b.csv");
|
||||
B.write_matrix(file_b);
|
||||
std::ofstream file_d("gemm-REF.csv");
|
||||
ref_C.write_matrix(file_d);
|
||||
std::ofstream file_c("gemm-GPU.csv");
|
||||
C.write_matrix(file_c);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("PASS, ");
|
||||
}
|
||||
}
|
||||
}
|
||||
fflush(stdout);
|
||||
|
||||
//
|
||||
// Warmup and timing iterations
|
||||
//
|
||||
|
||||
if (g_timing_iterations > 0)
|
||||
{
|
||||
// Warmup for 1/100 of the timing iterations (minimum of 2)
|
||||
for (int i = 0; i < __NV_STD_MAX(2, (g_timing_iterations + 99) / 100); ++i)
|
||||
{
|
||||
CUDA_PERROR(test_func(
|
||||
g_cublas_handle,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
A.d_data(),
|
||||
B.d_data(),
|
||||
C.d_data(),
|
||||
alpha,
|
||||
beta,
|
||||
stream,
|
||||
false).result);
|
||||
}
|
||||
}
|
||||
|
||||
// Conduct timing iterations
|
||||
double elapsed_ms = 0;
|
||||
gpu_timer timer;
|
||||
timer.start();
|
||||
|
||||
for (int i = 0; i < g_timing_iterations; i++)
|
||||
{
|
||||
CUDA_PERROR(test_func(
|
||||
g_cublas_handle,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
A.d_data(),
|
||||
B.d_data(),
|
||||
C.d_data(),
|
||||
alpha,
|
||||
beta,
|
||||
stream,
|
||||
false).result);
|
||||
}
|
||||
|
||||
timer.stop();
|
||||
elapsed_ms += timer.elapsed_millis();
|
||||
double avg_ms = elapsed_ms / g_timing_iterations;
|
||||
|
||||
// Display performance
|
||||
if (g_timing_iterations > 0)
|
||||
{
|
||||
int64_t num_flops = (2 * int64_t(m) * int64_t(n) * int64_t(k)) + (2 * int64_t(m) * int64_t(n));
|
||||
double gflops_per_sec = double(num_flops) / avg_ms / 1.0e6;
|
||||
|
||||
if (g_schmoo)
|
||||
{
|
||||
if (is_failed)
|
||||
printf("F");
|
||||
|
||||
printf(", %.3f", gflops_per_sec);
|
||||
|
||||
// Sleep for a few milliseconds to cool
|
||||
sleep_millis(10);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Avg runtime: %.3f ms, total flops: %lld, GFLOP/s: %.2f\n",
|
||||
avg_ms,
|
||||
num_flops,
|
||||
gflops_per_sec);
|
||||
}
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
|
||||
return is_failed;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute C = (alpha * A * B) + (beta * C)
|
||||
*/
|
||||
template <
|
||||
math_operation_class_t math_op,
|
||||
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
|
||||
matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B
|
||||
typename value_t, ///< Multiplicand value type (matrices A and B)
|
||||
typename accum_t> ///< Accumulator value type (matrix C and scalars)
|
||||
bool test(
|
||||
int m, ///< Height of C in rows
|
||||
int n, ///< Width of C in columns
|
||||
int k, ///< Width (height) of A (B)
|
||||
accum_t alpha, ///< Multiplicand scalar
|
||||
accum_t beta) ///< Addend scalar
|
||||
{
|
||||
uint64_t flop_base = 1ull << 41;
|
||||
int max_timing_iterations = 10000;
|
||||
int min_timing_iterations = 10;
|
||||
|
||||
bool test_error = false;
|
||||
|
||||
// Scale the number of timing iterations with respect to problem size (if not specified on commandline)
|
||||
if ((g_timing_iterations < 0) || g_schmoo)
|
||||
{
|
||||
uint64_t num_flops = (2 * uint64_t(m) * uint64_t(n) * uint64_t(k)) + (2 * uint64_t(m) * uint64_t(n));
|
||||
g_timing_iterations = (int) ((flop_base / sizeof(value_t)) / num_flops);
|
||||
|
||||
g_timing_iterations = (int) __NV_STD_MIN(max_timing_iterations, g_timing_iterations);
|
||||
g_timing_iterations = (int) __NV_STD_MAX(min_timing_iterations, g_timing_iterations);
|
||||
}
|
||||
|
||||
if (g_schmoo)
|
||||
{
|
||||
printf("%d, %d, %d, %c%c, %d, %d",
|
||||
m, n, k,
|
||||
(TransformA == matrix_transform_t::NonTranspose) ? 'n' : 't',
|
||||
(TransformB == matrix_transform_t::NonTranspose) ? 'n' : 't',
|
||||
m * n,
|
||||
g_timing_iterations);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("\n------------------------------------------------------------\n");
|
||||
printf("%dx%dx%d, GEMM_%c%c, %d C elements, %d timing iterations\n",
|
||||
m, n, k,
|
||||
(TransformA == matrix_transform_t::NonTranspose) ? 'n' : 't',
|
||||
(TransformB == matrix_transform_t::NonTranspose) ? 'n' : 't',
|
||||
m * n,
|
||||
g_timing_iterations);
|
||||
}
|
||||
fflush(stdout);
|
||||
|
||||
// CUBLAS
|
||||
test_error |= test<
|
||||
cublas_gemm<gemm::tiling_strategy::Unknown, math_op, TransformA, TransformB, value_t, accum_t>,
|
||||
TransformA,
|
||||
TransformB,
|
||||
value_t,
|
||||
accum_t>(m, n, k, accum_t(alpha), accum_t(beta));
|
||||
|
||||
// CUTLASS
|
||||
test_error |= test<
|
||||
cutlass_gemm_dispatch<gemm::tiling_strategy::Small, math_op, TransformA, TransformB, value_t, accum_t>,
|
||||
TransformA,
|
||||
TransformB,
|
||||
value_t,
|
||||
accum_t>(m, n, k, accum_t(alpha), accum_t(beta));
|
||||
|
||||
test_error |= test<
|
||||
cutlass_gemm_dispatch<gemm::tiling_strategy::Medium, math_op, TransformA, TransformB, value_t, accum_t>,
|
||||
TransformA,
|
||||
TransformB,
|
||||
value_t,
|
||||
accum_t>(m, n, k, accum_t(alpha), accum_t(beta));
|
||||
|
||||
test_error |= test<
|
||||
cutlass_gemm_dispatch<gemm::tiling_strategy::Large, math_op, TransformA, TransformB, value_t, accum_t>,
|
||||
TransformA,
|
||||
TransformB,
|
||||
value_t,
|
||||
accum_t>(m, n, k, accum_t(alpha), accum_t(beta));
|
||||
|
||||
test_error |= test<
|
||||
cutlass_gemm_dispatch<gemm::tiling_strategy::Tall, math_op, TransformA, TransformB, value_t, accum_t>,
|
||||
TransformA,
|
||||
TransformB,
|
||||
value_t,
|
||||
accum_t>(m, n, k, accum_t(alpha), accum_t(beta));
|
||||
|
||||
test_error |= test<
|
||||
cutlass_gemm_dispatch<gemm::tiling_strategy::Wide, math_op, TransformA, TransformB, value_t, accum_t>,
|
||||
TransformA,
|
||||
TransformB,
|
||||
value_t,
|
||||
accum_t>(m, n, k, accum_t(alpha), accum_t(beta));
|
||||
|
||||
test_error |= test<
|
||||
cutlass_gemm_dispatch<gemm::tiling_strategy::Huge, math_op, TransformA, TransformB, value_t, accum_t>,
|
||||
TransformA,
|
||||
TransformB,
|
||||
value_t,
|
||||
accum_t>(m, n, k, accum_t(alpha), accum_t(beta));
|
||||
|
||||
return test_error;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Main
|
||||
******************************************************************************/
|
||||
|
||||
|
||||
/**
|
||||
* Main
|
||||
*/
|
||||
int main(int argc, const char **argv)
|
||||
{
|
||||
//
|
||||
// Problem type (compiler-supplied so we don't compile everything)
|
||||
//
|
||||
|
||||
// Define value_t and accum_t (multiplicand and accumulator types, respectively)
|
||||
#if defined(TEST_SGEMM)
|
||||
typedef float value_t;
|
||||
typedef float accum_t;
|
||||
const math_operation_class_t math_op = math_operation_class_t::scalar;
|
||||
#elif defined(TEST_DGEMM)
|
||||
typedef double value_t;
|
||||
typedef double accum_t;
|
||||
const math_operation_class_t math_op = math_operation_class_t::scalar;
|
||||
#elif defined(TEST_HGEMM)
|
||||
typedef __half value_t;
|
||||
typedef __half accum_t;
|
||||
const math_operation_class_t math_op = math_operation_class_t::scalar;
|
||||
#elif defined(TEST_IGEMM)
|
||||
typedef int8_t value_t;
|
||||
typedef int32_t accum_t;
|
||||
const math_operation_class_t math_op = math_operation_class_t::scalar;
|
||||
#elif defined(TEST_WGEMM)
|
||||
typedef half value_t;
|
||||
typedef float accum_t;
|
||||
const math_operation_class_t math_op = math_operation_class_t::matrix;
|
||||
#else
|
||||
#error Unknown GEMM type requested.
|
||||
#endif
|
||||
|
||||
|
||||
// Define transpose constants
|
||||
#ifdef TRANSPOSE_A
|
||||
static const matrix_transform_t::kind_t TransformA = matrix_transform_t::Transpose;
|
||||
#else
|
||||
static const matrix_transform_t::kind_t TransformA = matrix_transform_t::NonTranspose;
|
||||
#endif
|
||||
|
||||
#ifdef TRANSPOSE_B
|
||||
static const matrix_transform_t::kind_t TransformB = matrix_transform_t::Transpose;
|
||||
#else
|
||||
static const matrix_transform_t::kind_t TransformB = matrix_transform_t::NonTranspose;
|
||||
#endif
|
||||
|
||||
|
||||
//
|
||||
// Commandline parsing
|
||||
//
|
||||
|
||||
// Initialize command line
|
||||
command_line args(argc, argv);
|
||||
|
||||
int m_factor = args.device_prop.multiProcessorCount * 128;
|
||||
int m = round_nearest(4096, m_factor);
|
||||
int k = 4096;
|
||||
int n = 4096;
|
||||
float alpha = 1.0;
|
||||
float beta = 0.0;
|
||||
|
||||
g_device_id = args.device_id;
|
||||
args.get_cmd_line_argument("m", m);
|
||||
args.get_cmd_line_argument("n", n);
|
||||
args.get_cmd_line_argument("k", k);
|
||||
args.get_cmd_line_argument("i", g_timing_iterations);
|
||||
args.get_cmd_line_argument("alpha", alpha);
|
||||
args.get_cmd_line_argument("beta", beta);
|
||||
args.get_cmd_line_argument("schmoo", g_schmoo);
|
||||
|
||||
// Print usage
|
||||
if (args.check_cmd_line_flag("help"))
|
||||
{
|
||||
printf("%s "
|
||||
"[--help] "
|
||||
"[--i=<timing iterations>] "
|
||||
"[--device=<device-id>] "
|
||||
"[--alpha=<alpha> --beta=<beta>] "
|
||||
"[--schmoo=<samples> || --m=<height> --n=<width> --k=<depth>]"
|
||||
"\n", argv[0]);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// Initialize cuBLAS
|
||||
if (cublasCreate(&g_cublas_handle) != CUBLAS_STATUS_SUCCESS)
|
||||
{
|
||||
fprintf(stderr, "cublasCreate() failed\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
bool test_error = false;
|
||||
|
||||
if (g_schmoo)
|
||||
{
|
||||
// Run a schmoo of problem sizes
|
||||
printf("M, N, K, transpose, total_flops, timing_iterations, sol_flop/s, cublas_sol, cutlass_small_sol, cutlass_med_sol, cutlass_large_sol, cutlass_tall_sol, cutlass_wide_sol, cutlass_huge_sol\n");
|
||||
|
||||
// Generate power-law distribution from [32, 16384)
|
||||
std::mt19937 gen(0);
|
||||
std::uniform_real_distribution<float> dis(5, 14);
|
||||
for (int i = 0; i < g_schmoo; ++i)
|
||||
{
|
||||
int m = int(pow(float(2), dis(gen)));
|
||||
int n = int(pow(float(2), dis(gen)));
|
||||
int k = int(pow(float(2), dis(gen)));
|
||||
|
||||
// Round m and n to nearest multiple of 32 if < 128, otherwise to the nearest 128
|
||||
m = (m < 128) ?
|
||||
round_nearest(m, 32) :
|
||||
round_nearest(m, 128);
|
||||
n = (n < 128) ?
|
||||
round_nearest(n, 32) :
|
||||
round_nearest(n, 128);
|
||||
|
||||
// Round k to the nearest 16
|
||||
k = (sizeof(value_t) == 1) ?
|
||||
round_nearest(k, 32) :
|
||||
round_nearest(k, 16);
|
||||
|
||||
test_error |= test<math_op, TransformA, TransformB, value_t, accum_t>(
|
||||
m, n, k,
|
||||
from_float<accum_t>(alpha),
|
||||
from_float<accum_t>(beta));
|
||||
|
||||
printf("\n"); fflush(stdout);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Test a single GEMM problem size
|
||||
test_error |= test<math_op, TransformA, TransformB, value_t, accum_t>(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
from_float<accum_t>(alpha),
|
||||
from_float<accum_t>(beta));
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
cublasDestroy(g_cublas_handle);
|
||||
|
||||
return test_error;
|
||||
}
|
||||
|
||||
@ -1,320 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* Utility for parsing command line arguments
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cutlass/util/debug.h>
|
||||
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/******************************************************************************
|
||||
* command_line
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Utility for parsing command line arguments
|
||||
*/
|
||||
struct command_line
|
||||
{
|
||||
|
||||
std::vector<std::string> keys;
|
||||
std::vector<std::string> values;
|
||||
std::vector<std::string> args;
|
||||
int device_id;
|
||||
cudaDeviceProp device_prop;
|
||||
float device_giga_bandwidth;
|
||||
size_t device_free_physmem;
|
||||
size_t device_total_physmem;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*/
|
||||
command_line(int argc, const char **argv, int device_id = -1) :
|
||||
keys(10),
|
||||
values(10),
|
||||
device_id(device_id)
|
||||
{
|
||||
using namespace std;
|
||||
|
||||
for (int i = 1; i < argc; i++)
|
||||
{
|
||||
string arg = argv[i];
|
||||
|
||||
if ((arg[0] != '-') || (arg[1] != '-'))
|
||||
{
|
||||
args.push_back(arg);
|
||||
continue;
|
||||
}
|
||||
|
||||
string::size_type pos;
|
||||
string key, val;
|
||||
if ((pos = arg.find('=')) == string::npos) {
|
||||
key = string(arg, 2, arg.length() - 2);
|
||||
val = "";
|
||||
} else {
|
||||
key = string(arg, 2, pos - 2);
|
||||
val = string(arg, pos + 1, arg.length() - 1);
|
||||
}
|
||||
|
||||
keys.push_back(key);
|
||||
values.push_back(val);
|
||||
}
|
||||
|
||||
// Initialize device
|
||||
CUDA_PERROR_EXIT(device_init());
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Checks whether a flag "--<flag>" is present in the commandline
|
||||
*/
|
||||
bool check_cmd_line_flag(const char* arg_name)
|
||||
{
|
||||
using namespace std;
|
||||
|
||||
for (int i = 0; i < int(keys.size()); ++i)
|
||||
{
|
||||
if (keys[i] == string(arg_name))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Returns number of naked (non-flag and non-key-value) commandline parameters
|
||||
*/
|
||||
template <typename value_t>
|
||||
int num_naked_args()
|
||||
{
|
||||
return args.size();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Returns the commandline parameter for a given index (not including flags)
|
||||
*/
|
||||
template <typename value_t>
|
||||
void get_cmd_line_argument(int index, value_t &val)
|
||||
{
|
||||
using namespace std;
|
||||
if (index < args.size()) {
|
||||
istringstream str_stream(args[index]);
|
||||
str_stream >> val;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the value specified for a given commandline parameter --<flag>=<value>
|
||||
*/
|
||||
template <typename value_t>
|
||||
void get_cmd_line_argument(const char *arg_name, value_t &val)
|
||||
{
|
||||
using namespace std;
|
||||
|
||||
for (int i = 0; i < int(keys.size()); ++i)
|
||||
{
|
||||
if (keys[i] == string(arg_name))
|
||||
{
|
||||
istringstream str_stream(values[i]);
|
||||
str_stream >> val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Returns the values specified for a given commandline parameter --<flag>=<value>,<value>*
|
||||
*/
|
||||
template <typename value_t>
|
||||
void get_cmd_line_arguments(
|
||||
const char *arg_name,
|
||||
std::vector<value_t> &vals,
|
||||
char sep = ',')
|
||||
{
|
||||
using namespace std;
|
||||
|
||||
if (check_cmd_line_flag(arg_name))
|
||||
{
|
||||
// Clear any default values
|
||||
vals.clear();
|
||||
|
||||
// Recover from multi-value string
|
||||
for (int i = 0; i < keys.size(); ++i)
|
||||
{
|
||||
if (keys[i] == string(arg_name))
|
||||
{
|
||||
string val_string(values[i]);
|
||||
istringstream str_stream(val_string);
|
||||
string::size_type old_pos = 0;
|
||||
string::size_type new_pos = 0;
|
||||
|
||||
// Iterate <sep>-delimited values
|
||||
value_t val;
|
||||
while ((new_pos = val_string.find(sep, old_pos)) != string::npos)
|
||||
{
|
||||
if (new_pos != old_pos)
|
||||
{
|
||||
str_stream.width(new_pos - old_pos);
|
||||
str_stream >> val;
|
||||
vals.push_back(val);
|
||||
}
|
||||
|
||||
// skip over delimiter
|
||||
str_stream.ignore(1);
|
||||
old_pos = new_pos + 1;
|
||||
}
|
||||
|
||||
// Read last value
|
||||
str_stream >> val;
|
||||
vals.push_back(val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* The number of pairs parsed
|
||||
*/
|
||||
int parsed_argc()
|
||||
{
|
||||
return (int) keys.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize device
|
||||
*/
|
||||
cudaError_t device_init()
|
||||
{
|
||||
cudaError_t error = cudaSuccess;
|
||||
|
||||
do
|
||||
{
|
||||
int deviceCount;
|
||||
if (CUDA_PERROR(error = cudaGetDeviceCount(&deviceCount))) break;
|
||||
|
||||
if (deviceCount == 0) {
|
||||
fprintf(stderr, "No devices supporting CUDA.\n");
|
||||
exit(1);
|
||||
}
|
||||
if (device_id < 0)
|
||||
{
|
||||
get_cmd_line_argument("device", device_id);
|
||||
}
|
||||
if ((device_id > deviceCount - 1) || (device_id < 0))
|
||||
{
|
||||
device_id = 0;
|
||||
}
|
||||
|
||||
if (CUDA_PERROR(error = cudaSetDevice(device_id))) break;
|
||||
|
||||
if (CUDA_PERROR(error = cudaMemGetInfo(&device_free_physmem, &device_total_physmem))) break;
|
||||
|
||||
if (CUDA_PERROR(error = cudaGetDeviceProperties(&device_prop, device_id))) break;
|
||||
|
||||
if (device_prop.major < 1) {
|
||||
fprintf(stderr, "Device does not support CUDA.\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
device_giga_bandwidth = float(device_prop.memoryBusWidth) * device_prop.memoryClockRate * 2 / 8 / 1000 / 1000;
|
||||
|
||||
} while (0);
|
||||
|
||||
return error;
|
||||
}
|
||||
|
||||
|
||||
//-------------------------------------------------------------------------
|
||||
// Utility functions
|
||||
//-------------------------------------------------------------------------
|
||||
|
||||
/// Tokenizes a comma-delimited list of string pairs delimited by ':'
|
||||
static void tokenize(
|
||||
std::vector<std::pair<std::string, std::string> > &tokens,
|
||||
std::string const &str,
|
||||
char delim = ',',
|
||||
char sep = ':')
|
||||
{
|
||||
// Home-built to avoid Boost dependency
|
||||
size_t s_idx = 0;
|
||||
size_t d_idx = std::string::npos;
|
||||
while (s_idx < str.size())
|
||||
{
|
||||
d_idx = str.find_first_of(delim, s_idx);
|
||||
|
||||
size_t end_idx = (d_idx != std::string::npos ? d_idx : str.size());
|
||||
size_t sep_idx = str.find_first_of(sep, s_idx);
|
||||
size_t offset = 1;
|
||||
if (sep_idx == std::string::npos || sep_idx >= end_idx)
|
||||
{
|
||||
sep_idx = end_idx;
|
||||
offset = 0;
|
||||
}
|
||||
|
||||
std::pair<std::string, std::string> item(
|
||||
str.substr(s_idx, sep_idx - s_idx),
|
||||
str.substr(sep_idx + offset, end_idx - sep_idx - offset));
|
||||
|
||||
tokens.push_back(item);
|
||||
s_idx = end_idx + 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Tokenizes a comma-delimited list of string pairs delimited by ':'
|
||||
static void tokenize(
|
||||
std::vector<std::string > &tokens,
|
||||
std::string const &str,
|
||||
char delim = ',',
|
||||
char sep = ':')
|
||||
{
|
||||
std::vector<std::pair<std::string, std::string> > token_pairs;
|
||||
tokenize(token_pairs, str, delim, sep);
|
||||
for (auto const &tok : token_pairs)
|
||||
{
|
||||
tokens.push_back(tok.first);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace cutlass
|
||||
@ -1,91 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* \brief C++ exception semantics for CUDA error codes
|
||||
*/
|
||||
|
||||
#include <iosfwd>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
|
||||
/// C++ exception wrapper for CUDA \p cudaError_t
|
||||
class cuda_exception : public std::exception
|
||||
{
|
||||
public:
|
||||
|
||||
/// Constructor
|
||||
cuda_exception(
|
||||
const char *msg = "",
|
||||
cudaError_t err = cudaErrorUnknown)
|
||||
:
|
||||
msg(msg), err(err)
|
||||
{}
|
||||
|
||||
/// Returns the explanatory string
|
||||
const char *what() const noexcept
|
||||
{
|
||||
return msg;
|
||||
}
|
||||
|
||||
/// Returns the underlying CUDA \p cudaError_t
|
||||
cudaError_t cudaError() const
|
||||
{
|
||||
return err;
|
||||
}
|
||||
|
||||
|
||||
protected:
|
||||
|
||||
/// Explanatory string
|
||||
const char *msg;
|
||||
|
||||
/// Underlying CUDA \p cudaError_t
|
||||
cudaError_t err;
|
||||
};
|
||||
|
||||
|
||||
/// Writes a cudaError_t to an output stream
|
||||
inline std::ostream & operator<<(std::ostream &out, cudaError_t result)
|
||||
{
|
||||
return out << cudaGetErrorString(result);
|
||||
}
|
||||
|
||||
/// Writes a cuda_exception instance to an output stream
|
||||
inline std::ostream & operator<<(std::ostream &out, cuda_exception const &e)
|
||||
{
|
||||
return out << e.what() << ": " << e.cudaError();
|
||||
}
|
||||
|
||||
|
||||
} // namespace cutlass
|
||||
@ -1,231 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* Utilities for interacting with the opaque CUDA __half type
|
||||
*/
|
||||
|
||||
#include <stdint.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <iosfwd>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* half_t
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* Host-based fp16 data type compatible and convertible with __half
|
||||
*/
|
||||
struct half_t
|
||||
{
|
||||
uint16_t __x;
|
||||
|
||||
/// Constructor from __half
|
||||
half_t(const __half &other)
|
||||
{
|
||||
__x = reinterpret_cast<const uint16_t&>(other);
|
||||
}
|
||||
|
||||
/// Constructor from integer
|
||||
half_t(int a)
|
||||
{
|
||||
*this = half_t(float(a));
|
||||
}
|
||||
|
||||
|
||||
/// Constructor from float
|
||||
half_t(float a)
|
||||
{
|
||||
uint32_t ia = *reinterpret_cast<uint32_t*>(&a);
|
||||
uint16_t ir;
|
||||
|
||||
ir = (ia >> 16) & 0x8000;
|
||||
|
||||
if ((ia & 0x7f800000) == 0x7f800000)
|
||||
{
|
||||
if ((ia & 0x7fffffff) == 0x7f800000)
|
||||
{
|
||||
ir |= 0x7c00; /* infinity */
|
||||
}
|
||||
else
|
||||
{
|
||||
ir = 0x7fff; /* canonical NaN */
|
||||
}
|
||||
}
|
||||
else if ((ia & 0x7f800000) >= 0x33000000)
|
||||
{
|
||||
int32_t shift = (int32_t) ((ia >> 23) & 0xff) - 127;
|
||||
if (shift > 15)
|
||||
{
|
||||
ir |= 0x7c00; /* infinity */
|
||||
}
|
||||
else
|
||||
{
|
||||
ia = (ia & 0x007fffff) | 0x00800000; /* extract mantissa */
|
||||
if (shift < -14)
|
||||
{ /* denormal */
|
||||
ir |= ia >> (-1 - shift);
|
||||
ia = ia << (32 - (-1 - shift));
|
||||
}
|
||||
else
|
||||
{ /* normal */
|
||||
ir |= ia >> (24 - 11);
|
||||
ia = ia << (32 - (24 - 11));
|
||||
ir = ir + ((14 + shift) << 10);
|
||||
}
|
||||
/* IEEE-754 round to nearest of even */
|
||||
if ((ia > 0x80000000) || ((ia == 0x80000000) && (ir & 1)))
|
||||
{
|
||||
ir++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this->__x = ir;
|
||||
}
|
||||
|
||||
/// Cast to __half
|
||||
operator __half() const
|
||||
{
|
||||
return reinterpret_cast<const __half&>(__x);
|
||||
}
|
||||
|
||||
/// Cast to float
|
||||
operator float() const
|
||||
{
|
||||
int sign = ((this->__x >> 15) & 1);
|
||||
int exp = ((this->__x >> 10) & 0x1f);
|
||||
int mantissa = (this->__x & 0x3ff);
|
||||
uint32_t f = 0;
|
||||
|
||||
if (exp > 0 && exp < 31)
|
||||
{
|
||||
// normal
|
||||
exp += 112;
|
||||
f = (sign << 31) | (exp << 23) | (mantissa << 13);
|
||||
}
|
||||
else if (exp == 0)
|
||||
{
|
||||
if (mantissa)
|
||||
{
|
||||
// subnormal
|
||||
exp += 113;
|
||||
while ((mantissa & (1 << 10)) == 0)
|
||||
{
|
||||
mantissa <<= 1;
|
||||
exp--;
|
||||
}
|
||||
mantissa &= 0x3ff;
|
||||
f = (sign << 31) | (exp << 23) | (mantissa << 13);
|
||||
}
|
||||
else
|
||||
{
|
||||
// zero
|
||||
f = 0;
|
||||
}
|
||||
}
|
||||
else if (exp == 31)
|
||||
{
|
||||
if (mantissa)
|
||||
{
|
||||
f = 0x7fffffff; // not a number
|
||||
}
|
||||
else
|
||||
{
|
||||
f = (0xff << 23) | (sign << 31); // inf
|
||||
}
|
||||
}
|
||||
return *reinterpret_cast<float const *>(&f);
|
||||
}
|
||||
|
||||
|
||||
/// Get raw storage
|
||||
uint16_t raw()
|
||||
{
|
||||
return this->__x;
|
||||
}
|
||||
|
||||
/// Assignment by sum
|
||||
bool operator ==(const half_t &other)
|
||||
{
|
||||
return (this->__x == other.__x);
|
||||
}
|
||||
|
||||
/// Increment
|
||||
half_t& operator +=(const half_t &rhs)
|
||||
{
|
||||
*this = half_t(float(*this) + float(rhs));
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Decrement
|
||||
half_t& operator -=(const half_t &rhs)
|
||||
{
|
||||
*this = half_t(float(*this) - float(rhs));
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Multiply
|
||||
half_t operator*(const half_t &other)
|
||||
{
|
||||
return half_t(float(*this) * float(other));
|
||||
}
|
||||
|
||||
/// Multiply
|
||||
half_t operator+(const half_t &other)
|
||||
{
|
||||
return half_t(float(*this) + float(other));
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* I/O stream overloads
|
||||
******************************************************************************/
|
||||
|
||||
/// Insert formatted \p half_t into the output stream
|
||||
std::ostream& operator<<(std::ostream &out, const half_t &x)
|
||||
{
|
||||
out << (float)x;
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
/// Insert formatted \p __half into the output stream
|
||||
std::ostream& operator<<(std::ostream &out, const __half &x)
|
||||
{
|
||||
return out << half_t(x);
|
||||
}
|
||||
|
||||
|
||||
} // namespace cutlass
|
||||
@ -1,503 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* Matrix data structure providing basic CPU-based algorithms and
|
||||
* operations that can be cloned and synchronized in GPU device memory
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
|
||||
#include <cutlass/util/debug.h>
|
||||
#include "../cutlass/util/matrix_transform.h"
|
||||
#include "half.h"
|
||||
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/**
|
||||
* \brief Matrix data structure providing basic CPU-based algorithms and
|
||||
* operations that be synchronized with a GPU-based replica
|
||||
*/
|
||||
template <typename value_t>
|
||||
struct matrix
|
||||
{
|
||||
// Host value type (must be convertible to/from value_t)
|
||||
typedef typename nv_std::conditional<
|
||||
(nv_std::is_same<value_t, __half>::value), // If (value_t == __half) ...
|
||||
half_t, // ... use half_t internally for host storage, else...
|
||||
value_t>::type // ... use value_t directly
|
||||
host_value_t;
|
||||
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Data members
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
private:
|
||||
|
||||
/// M dimension (height in rows)
|
||||
int _m;
|
||||
|
||||
/// N dimension (width in columns)
|
||||
int _n;
|
||||
|
||||
/// Data array on host
|
||||
std::vector<host_value_t> _h_data;
|
||||
|
||||
/// Clone of data array on GPU device
|
||||
value_t *_d_data;
|
||||
|
||||
/// GPU Device identifier that clone synchronizes with
|
||||
int _device_id;
|
||||
|
||||
public:
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Lifetime and synchronization
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Constructor: zero-initializes the matrix.
|
||||
*/
|
||||
matrix(
|
||||
int m, ///< Height of the matrix in rows
|
||||
int n) ///< Width of the matrix in columns
|
||||
:
|
||||
_m(m),
|
||||
_n(n),
|
||||
_d_data(NULL),
|
||||
_device_id(0)
|
||||
{
|
||||
_h_data.resize(_m * _n, 0);
|
||||
CUDA_PERROR_EXIT(cudaMalloc((void ** )&_d_data, sizeof(value_t) * _m * _n));
|
||||
CUDA_PERROR_EXIT(cudaGetDevice(&_device_id));
|
||||
}
|
||||
|
||||
/// Destructor
|
||||
~matrix()
|
||||
{
|
||||
if (_d_data)
|
||||
{
|
||||
CUDA_PERROR_EXIT(cudaFree(_d_data));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Synchronize the GPU-based replica with the current host-based matrix data
|
||||
*/
|
||||
void sync_device()
|
||||
{
|
||||
size_t bytes = _m * _n * sizeof(value_t);
|
||||
CUDA_PERROR_EXIT(cudaMemcpy(_d_data, &_h_data[0], bytes, cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Synchronize the host-based replica with the current GPU-based matrix data
|
||||
*/
|
||||
void sync_host()
|
||||
{
|
||||
size_t bytes = _m * _n * sizeof(value_t);
|
||||
CUDA_PERROR_EXIT(cudaMemcpy(&_h_data[0], _d_data, bytes, cudaMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Inspectors
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Return the height of the matrix, subject to the optional \p transpose_op
|
||||
*/
|
||||
int height(matrix_transform_t transpose_op = matrix_transform_t::NonTranspose) const
|
||||
{
|
||||
switch (transpose_op)
|
||||
{
|
||||
case matrix_transform_t::NonTranspose : return _m;
|
||||
case matrix_transform_t::Transpose : return _n;
|
||||
default: return -1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Return the width of the matrix, subject to the optional \p transpose_op
|
||||
*/
|
||||
int width(matrix_transform_t transpose_op = matrix_transform_t::NonTranspose) const
|
||||
{
|
||||
switch (transpose_op)
|
||||
{
|
||||
case matrix_transform_t::NonTranspose : return _n;
|
||||
case matrix_transform_t::Transpose : return _m;
|
||||
default: return -1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Return item at (x, y) coordinate of matrix, subject to the optional \p transform op
|
||||
*/
|
||||
host_value_t get(
|
||||
int x,
|
||||
int y,
|
||||
matrix_transform_t transpose_op = matrix_transform_t::NonTranspose) const
|
||||
{
|
||||
switch (transpose_op)
|
||||
{
|
||||
case matrix_transform_t::NonTranspose : return _h_data[y + (x * _m)];
|
||||
case matrix_transform_t::Transpose : return _h_data[x + (y * _m)];
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Return the distance (in items) within memory between elements of two
|
||||
* consecutive columns which have the same row index, subject to the optional \p transform op
|
||||
*/
|
||||
int leading_dim(matrix_transform_t transpose_op = matrix_transform_t::NonTranspose) const
|
||||
{
|
||||
switch (transpose_op)
|
||||
{
|
||||
case matrix_transform_t::NonTranspose : return _m;
|
||||
case matrix_transform_t::Transpose : return _n;
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get host data pointer
|
||||
*/
|
||||
value_t* h_data()
|
||||
{
|
||||
return _h_data.data();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Get host data pointer
|
||||
*/
|
||||
value_t const* h_data() const
|
||||
{
|
||||
return _h_data.data();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get device data pointer
|
||||
*/
|
||||
value_t const* d_data() const
|
||||
{
|
||||
return _d_data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get device data pointer
|
||||
*/
|
||||
value_t * d_data()
|
||||
{
|
||||
return _d_data;
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Initialization
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Initialize matrix values with a 2D "ramp" defined as
|
||||
* <tt>values(x, y) = (y * rs) + (x * cs)</tt>
|
||||
*/
|
||||
void fill_ramp(
|
||||
host_value_t rs,
|
||||
host_value_t cs)
|
||||
{
|
||||
for (int x = 0; x < _n; x++)
|
||||
{
|
||||
for (int y = 0; y < _m; y++)
|
||||
{
|
||||
_h_data[y + (x * _m)] = host_value_t((y * rs) + (x * cs));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Initialize matrix values such that all the elements of the principal diagonal
|
||||
* are ones and all other elements are zeros
|
||||
*/
|
||||
void fill_identity()
|
||||
{
|
||||
for (int j = 0; j < _n; j++)
|
||||
{
|
||||
for (int i = 0; i < _m; i++)
|
||||
{
|
||||
_h_data[i + j * _m] = host_value_t(i == j ? 1 : 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Initialize matrix values using the random number \p generator. The
|
||||
* \p generator reference is assumed to be a nullary functor that returns
|
||||
* values convertible to the matrix \p value_t.
|
||||
*/
|
||||
template <typename T>
|
||||
void fill_random(T & generator)
|
||||
{
|
||||
for (int j = 0; j < _n; j++)
|
||||
{
|
||||
for (int i = 0; i < _m; i++)
|
||||
{
|
||||
_h_data[i + j * _m] = (value_t) generator();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Element-wise matrix addition
|
||||
*/
|
||||
matrix & operator+=(matrix const &mat)
|
||||
{
|
||||
for (int j = 0; j < _n; j++)
|
||||
{
|
||||
for (int i = 0; i < _m; i++)
|
||||
{
|
||||
_h_data[i + j * _m] += mat._h_data[i + j * _m];
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Element-wise matrix subtraction
|
||||
*/
|
||||
matrix & operator-=(matrix const &mat)
|
||||
{
|
||||
for (int j = 0; j < _n; j++)
|
||||
{
|
||||
for (int i = 0; i < _m; i++)
|
||||
{
|
||||
_h_data[i + j * _m] -= mat._h_data[i + j * _m];
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Output
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Prints matrix in CSV to output stream
|
||||
*/
|
||||
template <typename _hv_t>
|
||||
std::ostream & write_matrix(std::ostream &out, _hv_t)
|
||||
{
|
||||
for (int i = 0; i < _m; i++)
|
||||
{
|
||||
for (int j = 0; j < _n; j++)
|
||||
{
|
||||
out << (j ? "," : "") << _h_data[i + j * _m];
|
||||
}
|
||||
out << "\n";
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Prints matrix in CSV to output stream
|
||||
*/
|
||||
std::ostream & write_matrix(std::ostream &out, int8_t)
|
||||
{
|
||||
for (int i = 0; i < _m; i++)
|
||||
{
|
||||
for (int j = 0; j < _n; j++)
|
||||
{
|
||||
out << (j ? "," : "") << int32_t(_h_data[i + j * _m]);
|
||||
}
|
||||
out << "\n";
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Prints matrix in CSV to output stream
|
||||
*/
|
||||
std::ostream & write_matrix(std::ostream &out)
|
||||
{
|
||||
return write_matrix(out, _h_data[0]);
|
||||
}
|
||||
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Floating point "almost-equal" utilities
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
static bool almost_equal_ulps(half_t a, half_t b, int max_ulps)
|
||||
{
|
||||
if (a == b)
|
||||
return true;
|
||||
|
||||
int32_t int_diff = abs(a.raw() - b.raw());
|
||||
if (int_diff <= max_ulps)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
static bool almost_equal_ulps(float a, float b, int max_ulps)
|
||||
{
|
||||
if (a == b)
|
||||
return true;
|
||||
int32_t int_diff = abs(*(int32_t*)&a - *(int32_t*)&b);
|
||||
if (int_diff <= max_ulps)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
static bool almost_equal_ulps(double a, double b, int max_ulps)
|
||||
{
|
||||
if (a == b)
|
||||
return true;
|
||||
int64_t int_diff = abs(*(int64_t*)&a - *(int64_t*)&b);
|
||||
if (int_diff <= max_ulps)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool almost_equal_ulps(int32_t a, int32_t b, int max_ulps)
|
||||
{
|
||||
return (a == b);
|
||||
}
|
||||
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// matrix operations
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
|
||||
/**
|
||||
* Returns matrix equality
|
||||
*/
|
||||
bool operator==(const matrix<value_t> &mat) const
|
||||
{
|
||||
int max_ulps = 30;
|
||||
|
||||
if (_m != mat._m || _n != mat._n)
|
||||
{
|
||||
fprintf(stderr, "Error: dimension mismatch during matrix comparison.\n"); exit(1);
|
||||
}
|
||||
|
||||
for (int j = 0; j < _n; j++)
|
||||
{
|
||||
for (int i = 0; i < _m; i++)
|
||||
{
|
||||
if (!almost_equal_ulps(_h_data[i + j * _m], mat._h_data[i + j * _m], max_ulps))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Returns matrix inequality
|
||||
*/
|
||||
bool operator!=(const matrix<value_t> &mat) const
|
||||
{
|
||||
return !(*this == mat);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Computes this = (alpha * op(A) * op(B)) + (beta * this), specialized for gemm_nn
|
||||
*/
|
||||
template <typename multiplicand_t>
|
||||
void gemm(
|
||||
matrix_transform_t transform_a,
|
||||
matrix_transform_t transform_b,
|
||||
host_value_t alpha,
|
||||
const matrix<multiplicand_t> &A,
|
||||
const matrix<multiplicand_t> &B,
|
||||
host_value_t beta)
|
||||
{
|
||||
// Sanity check dimensions
|
||||
if ((_m != A.height(transform_a)) ||
|
||||
(_n != B.width(transform_b)) ||
|
||||
(A.width(transform_a) != B.height(transform_b)))
|
||||
{
|
||||
fprintf(stderr, "Error: dimension mismatch during gemm.\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
int M = A.height(transform_a);
|
||||
int K = A.width(transform_a);
|
||||
int N = B.width(transform_b);
|
||||
|
||||
// Even the host-side implementation utilizes a blocking structure to improve
|
||||
// verification performance
|
||||
int DimBlockM = (M % 16 == 0) ? 16 : 1;
|
||||
int DimBlockN = (N % 16 == 0) ? 16 : 1;
|
||||
|
||||
for (int i = 0; i < M; i += DimBlockM)
|
||||
{
|
||||
for (int j = 0; j < N; j += DimBlockN)
|
||||
{
|
||||
for (int block_y = 0; block_y < DimBlockM; block_y++)
|
||||
{
|
||||
for (int block_x = 0; block_x < DimBlockN; block_x++)
|
||||
{
|
||||
int y = i + block_y;
|
||||
int x = j + block_x;
|
||||
|
||||
host_value_t accum(0);
|
||||
for (int k = 0; k < K; k++)
|
||||
{
|
||||
accum += host_value_t(A.get(k, y, transform_a)) * host_value_t(B.get(x, k, transform_b));
|
||||
}
|
||||
|
||||
_h_data[y + x * M] = (alpha * accum) + (beta * _h_data[y + x * M]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
} // namespace cutlass
|
||||
@ -1,107 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* GPU kernel timer
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <cutlass/util/debug.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* gpu_timer
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* GPU event-based timer
|
||||
*/
|
||||
struct gpu_timer
|
||||
{
|
||||
cudaEvent_t _start;
|
||||
cudaEvent_t _stop;
|
||||
|
||||
gpu_timer()
|
||||
{
|
||||
CUDA_PERROR_EXIT(cudaEventCreate(&_start));
|
||||
CUDA_PERROR_EXIT(cudaEventCreate(&_stop));
|
||||
}
|
||||
|
||||
~gpu_timer()
|
||||
{
|
||||
CUDA_PERROR_EXIT(cudaEventDestroy(_start));
|
||||
CUDA_PERROR_EXIT(cudaEventDestroy(_stop));
|
||||
}
|
||||
|
||||
void start()
|
||||
{
|
||||
CUDA_PERROR_EXIT(cudaEventRecord(_start, 0));
|
||||
}
|
||||
|
||||
void stop()
|
||||
{
|
||||
CUDA_PERROR_EXIT(cudaEventRecord(_stop, 0));
|
||||
}
|
||||
|
||||
float elapsed_millis()
|
||||
{
|
||||
float elapsed = 0.0;
|
||||
CUDA_PERROR_EXIT(cudaEventSynchronize(_stop));
|
||||
CUDA_PERROR_EXIT(cudaEventElapsedTime(&elapsed, _start, _stop));
|
||||
return elapsed;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* sleep_millis
|
||||
******************************************************************************/
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
|
||||
void sleep_millis(unsigned milliseconds)
|
||||
{
|
||||
Sleep(milliseconds);
|
||||
}
|
||||
#else
|
||||
#include <unistd.h>
|
||||
|
||||
void sleep_millis(unsigned milliseconds)
|
||||
{
|
||||
usleep(milliseconds * 1000); // takes microseconds
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
} // namespace cutlass
|
||||
@ -1,163 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* \file
|
||||
* \brief Utilities for converting between types and assessing traits
|
||||
*/
|
||||
|
||||
#include "half.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
/******************************************************************************
|
||||
* Float conversion utilities
|
||||
******************************************************************************/
|
||||
|
||||
/// Convert float to value type
|
||||
template <typename value_t>
|
||||
value_t from_float(float val)
|
||||
{
|
||||
return value_t(val);
|
||||
}
|
||||
|
||||
/// Convert float to value type (__half specialization)
|
||||
template <>
|
||||
__half from_float<__half>(float val)
|
||||
{
|
||||
return half_t(val);
|
||||
}
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Type conversion utilities
|
||||
******************************************************************************/
|
||||
|
||||
/// Member \p type is defined as the signed integer type having the same size as \p T
|
||||
template <typename T>
|
||||
struct integer_alias;
|
||||
|
||||
template <>
|
||||
struct integer_alias<int8_t> {
|
||||
using type = int8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct integer_alias<half_t> {
|
||||
using type = int16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct integer_alias<__half> {
|
||||
using type = int16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct integer_alias<float> {
|
||||
using type = int32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct integer_alias<int> {
|
||||
using type = int32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct integer_alias<double> {
|
||||
using type = int64_t;
|
||||
};
|
||||
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Type-info utilities
|
||||
******************************************************************************/
|
||||
|
||||
/// Returns a string to prefix 'gemm' to construct CUBLAS-like kernel names
|
||||
template <math_operation_class_t math_op, typename value_t, typename accum_t> char const *to_prefix_string();
|
||||
|
||||
template <> char const *to_prefix_string<math_operation_class_t::scalar, half_t, half_t>() {
|
||||
return "H";
|
||||
}
|
||||
|
||||
template <> char const *to_prefix_string<math_operation_class_t::scalar, __half, __half>() {
|
||||
return "H";
|
||||
}
|
||||
|
||||
template <> char const *to_prefix_string<math_operation_class_t::scalar, float, float>() {
|
||||
return "S";
|
||||
}
|
||||
|
||||
template <> char const *to_prefix_string<math_operation_class_t::matrix, __half, __half>() {
|
||||
return "WmmaH";
|
||||
}
|
||||
|
||||
template <> char const *to_prefix_string<math_operation_class_t::matrix, __half, float>() {
|
||||
return "WmmaS";
|
||||
}
|
||||
|
||||
template <> char const *to_prefix_string<math_operation_class_t::scalar, double, double>() {
|
||||
return "D";
|
||||
}
|
||||
|
||||
template <> char const *to_prefix_string<math_operation_class_t::scalar, int8_t, int32_t>() {
|
||||
return "I";
|
||||
}
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Maps value_t to the minimum vector size used to load operand
|
||||
******************************************************************************/
|
||||
|
||||
template <typename T>
|
||||
struct operand_load_type;
|
||||
|
||||
template <>
|
||||
struct operand_load_type<int8_t> { using type = int32_t; };
|
||||
|
||||
template <typename T>
|
||||
struct operand_load_type { using type = T; };
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Minimum alignment requirement, if any, determined from value_t.
|
||||
******************************************************************************/
|
||||
|
||||
template <typename value_t>
|
||||
struct gemm_alignment_requirement;
|
||||
|
||||
template <>
|
||||
struct gemm_alignment_requirement<uint8_t> { static const int value = 4; };
|
||||
|
||||
template <typename value_t>
|
||||
struct gemm_alignment_requirement { static const int value = 0; };
|
||||
|
||||
|
||||
|
||||
} // namespace cutlass
|
||||
1
docs/_config.yml
Normal file
1
docs/_config.yml
Normal file
@ -0,0 +1 @@
|
||||
theme: jekyll-theme-minimal
|
||||
378
docs/annotated.html
Normal file
378
docs/annotated.html
Normal file
@ -0,0 +1,378 @@
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head>
|
||||
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
|
||||
<meta name="generator" content="Doxygen 1.8.14"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<title>Cutlass: Class List</title>
|
||||
<link href="tabs.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="jquery.js"></script>
|
||||
<script type="text/javascript" src="dynsections.js"></script>
|
||||
<link href="search/search.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="search/searchdata.js"></script>
|
||||
<script type="text/javascript" src="search/search.js"></script>
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
extensions: ["tex2jax.js"],
|
||||
jax: ["input/TeX","output/HTML-CSS"],
|
||||
});
|
||||
</script><script type="text/javascript" async src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
|
||||
<link href="doxygen.css" rel="stylesheet" type="text/css" />
|
||||
</head>
|
||||
<body>
|
||||
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
|
||||
<div id="titlearea">
|
||||
<table cellspacing="0" cellpadding="0">
|
||||
<tbody>
|
||||
<tr style="height: 56px;">
|
||||
<td id="projectalign" style="padding-left: 0.5em;">
|
||||
<div id="projectname">Cutlass
|
||||
</div>
|
||||
<div id="projectbrief">CUDA Templates for Linear Algebra Subroutines and Solvers</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<!-- end header part -->
|
||||
<!-- Generated by Doxygen 1.8.14 -->
|
||||
<script type="text/javascript">
|
||||
/* @license magnet:?xt=urn:btih:cf05388f2679ee054f2beb29a391d25f4e673ac3&dn=gpl-2.0.txt GPL-v2 */
|
||||
var searchBox = new SearchBox("searchBox", "search",false,'Search');
|
||||
/* @license-end */
|
||||
</script>
|
||||
<script type="text/javascript" src="menudata.js"></script>
|
||||
<script type="text/javascript" src="menu.js"></script>
|
||||
<script type="text/javascript">
|
||||
/* @license magnet:?xt=urn:btih:cf05388f2679ee054f2beb29a391d25f4e673ac3&dn=gpl-2.0.txt GPL-v2 */
|
||||
$(function() {
|
||||
initMenu('',true,false,'search.php','Search');
|
||||
$(document).ready(function() { init_search(); });
|
||||
});
|
||||
/* @license-end */</script>
|
||||
<div id="main-nav"></div>
|
||||
</div><!-- top -->
|
||||
<!-- window showing the filter options -->
|
||||
<div id="MSearchSelectWindow"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
onkeydown="return searchBox.OnSearchSelectKey(event)">
|
||||
</div>
|
||||
|
||||
<!-- iframe showing the search results (closed by default) -->
|
||||
<div id="MSearchResultsWindow">
|
||||
<iframe src="javascript:void(0)" frameborder="0"
|
||||
name="MSearchResults" id="MSearchResults">
|
||||
</iframe>
|
||||
</div>
|
||||
|
||||
<div class="header">
|
||||
<div class="headertitle">
|
||||
<div class="title">Class List</div> </div>
|
||||
</div><!--header-->
|
||||
<div class="contents">
|
||||
<div class="textblock">Here are the classes, structs, unions and interfaces with brief descriptions:</div><div class="directory">
|
||||
<div class="levels">[detail level <span onclick="javascript:toggleLevel(1);">1</span><span onclick="javascript:toggleLevel(2);">2</span><span onclick="javascript:toggleLevel(3);">3</span><span onclick="javascript:toggleLevel(4);">4</span>]</div><table class="directory">
|
||||
<tr id="row_0_" class="even"><td class="entry"><span style="width:0px;display:inline-block;"> </span><span id="arr_0_" class="arrow" onclick="toggleFolder('0_')">▼</span><span class="icona"><span class="icon">N</span></span><a class="el" href="namespacecutlass.html" target="_self">cutlass</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_"><td class="entry"><span style="width:16px;display:inline-block;"> </span><span id="arr_0_0_" class="arrow" onclick="toggleFolder('0_0_')">▶</span><span class="icona"><span class="icon">N</span></span><a class="el" href="namespacecutlass_1_1gemm.html" target="_self">gemm</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_0_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_0_" class="arrow" onclick="toggleFolder('0_0_0_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1ClearAccumulators.html" target="_self">ClearAccumulators</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_0_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1ClearAccumulators_1_1SharedStorage.html" target="_self">SharedStorage</a></td><td class="desc">The shared storage </td></tr>
|
||||
<tr id="row_0_0_1_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1DgemmConfig.html" target="_self">DgemmConfig</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_2_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1DgemmTraits.html" target="_self">DgemmTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_3_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1FragmentMultiplyAdd.html" target="_self">FragmentMultiplyAdd</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_4_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1FragmentMultiplyAdd_3_01half_01_4.html" target="_self">FragmentMultiplyAdd< half ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_5_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_5_" class="arrow" onclick="toggleFolder('0_0_5_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1Gemm.html" target="_self">Gemm</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_5_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1Gemm_1_1Params.html" target="_self">Params</a></td><td class="desc">The params </td></tr>
|
||||
<tr id="row_0_0_6_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmConfig.html" target="_self">GemmConfig</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_7_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmDesc.html" target="_self">GemmDesc</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_8_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmEpilogue.html" target="_self">GemmEpilogue</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_9_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_9_" class="arrow" onclick="toggleFolder('0_0_9_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmEpilogueTraits.html" target="_self">GemmEpilogueTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_9_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmEpilogueTraits_1_1Params.html" target="_self">Params</a></td><td class="desc">The params </td></tr>
|
||||
<tr id="row_0_0_9_1_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmEpilogueTraits_1_1SharedStorage.html" target="_self">SharedStorage</a></td><td class="desc">The shared memory to swizzle the data in the epilogue </td></tr>
|
||||
<tr id="row_0_0_9_2_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="unioncutlass_1_1gemm_1_1GemmEpilogueTraits_1_1StreamSharedStorage.html" target="_self">StreamSharedStorage</a></td><td class="desc">The shared memory storage to exchange data </td></tr>
|
||||
<tr id="row_0_0_10_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmEpilogueTraitsHelper.html" target="_self">GemmEpilogueTraitsHelper</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_11_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_11_" class="arrow" onclick="toggleFolder('0_0_11_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmGlobalIteratorAb.html" target="_self">GemmGlobalIteratorAb</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_11_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmGlobalIteratorAb_1_1Params.html" target="_self">Params</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_12_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_12_" class="arrow" onclick="toggleFolder('0_0_12_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmGlobalIteratorCd.html" target="_self">GemmGlobalIteratorCd</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_12_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmGlobalIteratorCd_1_1Params.html" target="_self">Params</a></td><td class="desc">The params </td></tr>
|
||||
<tr id="row_0_0_13_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_13_" class="arrow" onclick="toggleFolder('0_0_13_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmGlobalTileCdTraits.html" target="_self">GemmGlobalTileCdTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_13_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmGlobalTileCdTraits_1_1ThreadOffset.html" target="_self">ThreadOffset</a></td><td class="desc">Computes the thread offset in (H, W) based on thread ID </td></tr>
|
||||
<tr id="row_0_0_14_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_14_" class="arrow" onclick="toggleFolder('0_0_14_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmGlobalTileTraits.html" target="_self">GemmGlobalTileTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_14_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmGlobalTileTraits_1_1ThreadOffset.html" target="_self">ThreadOffset</a></td><td class="desc">Computes the thread offset in (H, W) based on thread ID </td></tr>
|
||||
<tr id="row_0_0_15_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmMultiplicandTraits.html" target="_self">GemmMultiplicandTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_16_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmOperandTraitsAb.html" target="_self">GemmOperandTraitsAb</a></td><td class="desc">Helper to describe attributes of GEMM matrix operands </td></tr>
|
||||
<tr id="row_0_0_17_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_17_" class="arrow" onclick="toggleFolder('0_0_17_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmSharedLoadTileATraits.html" target="_self">GemmSharedLoadTileATraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_17_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmSharedLoadTileATraits_1_1ThreadOffset.html" target="_self">ThreadOffset</a></td><td class="desc">Computes the thread offset in (H, W) based on thread ID </td></tr>
|
||||
<tr id="row_0_0_18_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_18_" class="arrow" onclick="toggleFolder('0_0_18_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmSharedLoadTileBTraits.html" target="_self">GemmSharedLoadTileBTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_18_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmSharedLoadTileBTraits_1_1ThreadOffset.html" target="_self">ThreadOffset</a></td><td class="desc">Computes the thread offset in (H, W) based on thread ID </td></tr>
|
||||
<tr id="row_0_0_19_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_19_" class="arrow" onclick="toggleFolder('0_0_19_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmSharedLoadTileDTraits.html" target="_self">GemmSharedLoadTileDTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_19_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmSharedLoadTileDTraits_1_1ThreadOffset.html" target="_self">ThreadOffset</a></td><td class="desc">Computes the thread offset in (H, W) based on thread ID </td></tr>
|
||||
<tr id="row_0_0_20_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_20_" class="arrow" onclick="toggleFolder('0_0_20_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmSharedStoreTileAbTraits.html" target="_self">GemmSharedStoreTileAbTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_20_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmSharedStoreTileAbTraits_1_1ThreadOffset.html" target="_self">ThreadOffset</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_21_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_21_" class="arrow" onclick="toggleFolder('0_0_21_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmSharedStoreTileDTraits.html" target="_self">GemmSharedStoreTileDTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_21_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmSharedStoreTileDTraits_1_1ThreadOffset.html" target="_self">ThreadOffset</a></td><td class="desc">Computes the thread offset in (H, W) based on thread ID </td></tr>
|
||||
<tr id="row_0_0_22_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_22_" class="arrow" onclick="toggleFolder('0_0_22_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmSharedStoreWithSkewTileAbTraits.html" target="_self">GemmSharedStoreWithSkewTileAbTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_22_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmSharedStoreWithSkewTileAbTraits_1_1ThreadOffset.html" target="_self">ThreadOffset</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_23_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmTileTraitsHelperA.html" target="_self">GemmTileTraitsHelperA</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_24_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmTileTraitsHelperA_3_01MatrixLayout_1_1kColumnMajor_00_01GemmConfig___01_4.html" target="_self">GemmTileTraitsHelperA< MatrixLayout::kColumnMajor, GemmConfig_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_25_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmTileTraitsHelperA_3_01MatrixLayout_1_1kRowMajor_00_01GemmConfig___01_4.html" target="_self">GemmTileTraitsHelperA< MatrixLayout::kRowMajor, GemmConfig_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_26_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmTileTraitsHelperB.html" target="_self">GemmTileTraitsHelperB</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_27_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmTileTraitsHelperB_3_01MatrixLayout_1_1kColumnMajor_00_01GemmConfig___01_4.html" target="_self">GemmTileTraitsHelperB< MatrixLayout::kColumnMajor, GemmConfig_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_28_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmTileTraitsHelperB_3_01MatrixLayout_1_1kRowMajor_00_01GemmConfig___01_4.html" target="_self">GemmTileTraitsHelperB< MatrixLayout::kRowMajor, GemmConfig_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_29_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_29_" class="arrow" onclick="toggleFolder('0_0_29_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmTraits.html" target="_self">GemmTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_29_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmTraits_1_1GlobalLoadStream.html" target="_self">GlobalLoadStream</a></td><td class="desc">Assemble the global load streams for A/B </td></tr>
|
||||
<tr id="row_0_0_29_1_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmTraits_1_1MainLoopSharedStorage.html" target="_self">MainLoopSharedStorage</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_29_2_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmTraits_1_1Params.html" target="_self">Params</a></td><td class="desc">The params </td></tr>
|
||||
<tr id="row_0_0_29_3_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GemmTraits_1_1SharedLoadStream.html" target="_self">SharedLoadStream</a></td><td class="desc">Assemble the shared load stream for A/B </td></tr>
|
||||
<tr id="row_0_0_29_4_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="unioncutlass_1_1gemm_1_1GemmTraits_1_1SharedStorage.html" target="_self">SharedStorage</a></td><td class="desc">The storage in shared memory </td></tr>
|
||||
<tr id="row_0_0_29_5_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="unioncutlass_1_1gemm_1_1GemmTraits_1_1StreamSharedStorage.html" target="_self">StreamSharedStorage</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_30_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GetExtent.html" target="_self">GetExtent</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_31_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GetExtent_3_01GemmOperand_1_1kA_00_01Tile___01_4.html" target="_self">GetExtent< GemmOperand::kA, Tile_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_32_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GetExtent_3_01GemmOperand_1_1kB_00_01Tile___01_4.html" target="_self">GetExtent< GemmOperand::kB, Tile_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_33_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GlobalLoadStream.html" target="_self">GlobalLoadStream</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_34_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_34_" class="arrow" onclick="toggleFolder('0_0_34_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GlobalLoadStreamBase.html" target="_self">GlobalLoadStreamBase</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_34_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1GlobalLoadStreamBase_1_1Params.html" target="_self">Params</a></td><td class="desc">The params </td></tr>
|
||||
<tr id="row_0_0_34_1_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="unioncutlass_1_1gemm_1_1GlobalLoadStreamBase_1_1SharedStorage.html" target="_self">SharedStorage</a></td><td class="desc">The storage in shared memory needed by that stream </td></tr>
|
||||
<tr id="row_0_0_35_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1HgemmConfig.html" target="_self">HgemmConfig</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_36_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_36_" class="arrow" onclick="toggleFolder('0_0_36_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1HgemmCrosswiseGlobalTileTraits.html" target="_self">HgemmCrosswiseGlobalTileTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_36_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1HgemmCrosswiseGlobalTileTraits_1_1ThreadOffset.html" target="_self">ThreadOffset</a></td><td class="desc">Computes the thread offset in (H, W) based on thread ID </td></tr>
|
||||
<tr id="row_0_0_37_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1HgemmSwizzle.html" target="_self">HgemmSwizzle</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_38_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1HgemmTileTraitsHelperA.html" target="_self">HgemmTileTraitsHelperA</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_39_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1HgemmTileTraitsHelperA_3_01MatrixLayout_1_1kRowMajor_00_01GemmConfig___01_4.html" target="_self">HgemmTileTraitsHelperA< MatrixLayout::kRowMajor, GemmConfig_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_40_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1HgemmTileTraitsHelperB.html" target="_self">HgemmTileTraitsHelperB</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_41_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1HgemmTileTraitsHelperB_3_01MatrixLayout_1_1kColumnMajor_00_01GemmConfig___01_4.html" target="_self">HgemmTileTraitsHelperB< MatrixLayout::kColumnMajor, GemmConfig_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_42_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1HgemmTraits.html" target="_self">HgemmTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_43_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1HgemmTraitsHelper.html" target="_self">HgemmTraitsHelper</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_44_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1HgemmTransformerA.html" target="_self">HgemmTransformerA</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_45_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1HgemmTransformerA_3_01MatrixLayout_1_1kColumnMajor_00_01Iterator___01_4.html" target="_self">HgemmTransformerA< MatrixLayout::kColumnMajor, Iterator_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_46_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1HgemmTransformerA_3_01MatrixLayout_1_1kRowMajor_00_01Iterator___01_4.html" target="_self">HgemmTransformerA< MatrixLayout::kRowMajor, Iterator_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_47_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1HgemmTransformerB.html" target="_self">HgemmTransformerB</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_48_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1HgemmTransformerB_3_01MatrixLayout_1_1kColumnMajor_00_01Iterator___01_4.html" target="_self">HgemmTransformerB< MatrixLayout::kColumnMajor, Iterator_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_49_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1HgemmTransformerB_3_01MatrixLayout_1_1kRowMajor_00_01Iterator___01_4.html" target="_self">HgemmTransformerB< MatrixLayout::kRowMajor, Iterator_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_50_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IdentityBlockSwizzle.html" target="_self">IdentityBlockSwizzle</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_51_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmConfig.html" target="_self">IgemmConfig</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_52_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmConfig_3_01OutputTile___00_01int8__t_00_01AccumulatorsPerThread___01_4.html" target="_self">IgemmConfig< OutputTile_, int8_t, AccumulatorsPerThread_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_53_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_53_" class="arrow" onclick="toggleFolder('0_0_53_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmContiguousGlobalTileTraits.html" target="_self">IgemmContiguousGlobalTileTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_53_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmContiguousGlobalTileTraits_1_1ThreadOffset.html" target="_self">ThreadOffset</a></td><td class="desc">Computes the thread offset in (H, W) based on thread ID </td></tr>
|
||||
<tr id="row_0_0_54_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmEpilogue.html" target="_self">IgemmEpilogue</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_55_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmEpilogue_3_01GemmEpilogueTraits___00_01true_01_4.html" target="_self">IgemmEpilogue< GemmEpilogueTraits_, true ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_56_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmEpilogueScalar.html" target="_self">IgemmEpilogueScalar</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_57_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmEpilogueScalar_3_01int_01_4.html" target="_self">IgemmEpilogueScalar< int ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_58_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmEpilogueTraits.html" target="_self">IgemmEpilogueTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_59_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmEpilogueTraitsHelper.html" target="_self">IgemmEpilogueTraitsHelper</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_60_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmFloatToInt8Converter.html" target="_self">IgemmFloatToInt8Converter</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_61_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmGlobalLoadTransformer.html" target="_self">IgemmGlobalLoadTransformer</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_62_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmGlobalLoadTransformer_3_01Fragment_3_01int8__t_00_01kElements___01_4_00_01float_01_4.html" target="_self">IgemmGlobalLoadTransformer< Fragment< int8_t, kElements_ >, float ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_63_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmGlobalStoreTransformer.html" target="_self">IgemmGlobalStoreTransformer</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_64_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmGlobalStoreTransformer_3_01float_00_01Fragment_3_01int8__t_00_01kElements___01_4_01_4.html" target="_self">IgemmGlobalStoreTransformer< float, Fragment< int8_t, kElements_ > ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_65_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmInt8ToFloatConverter.html" target="_self">IgemmInt8ToFloatConverter</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_66_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmSharedStoreTransformer.html" target="_self">IgemmSharedStoreTransformer</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_67_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmSwizzle.html" target="_self">IgemmSwizzle</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_68_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmTileTraitsHelperA.html" target="_self">IgemmTileTraitsHelperA</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_69_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmTileTraitsHelperA_3_01MatrixLayout_1_1kColumnMajor_00_01GemmConfig___01_4.html" target="_self">IgemmTileTraitsHelperA< MatrixLayout::kColumnMajor, GemmConfig_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_70_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmTileTraitsHelperB.html" target="_self">IgemmTileTraitsHelperB</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_71_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmTileTraitsHelperB_3_01MatrixLayout_1_1kRowMajor_00_01GemmConfig___01_4.html" target="_self">IgemmTileTraitsHelperB< MatrixLayout::kRowMajor, GemmConfig_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_72_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmTraits.html" target="_self">IgemmTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_73_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmTraitsHelper.html" target="_self">IgemmTraitsHelper</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_74_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmTransformerA.html" target="_self">IgemmTransformerA</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_75_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmTransformerA_3_01MatrixLayout_1_1kColumnMajor_00_01Iterator___01_4.html" target="_self">IgemmTransformerA< MatrixLayout::kColumnMajor, Iterator_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_76_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmTransformerA_3_01MatrixLayout_1_1kRowMajor_00_01Iterator___01_4.html" target="_self">IgemmTransformerA< MatrixLayout::kRowMajor, Iterator_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_77_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmTransformerB.html" target="_self">IgemmTransformerB</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_78_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmTransformerB_3_01MatrixLayout_1_1kColumnMajor_00_01Iterator___01_4.html" target="_self">IgemmTransformerB< MatrixLayout::kColumnMajor, Iterator_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_79_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1IgemmTransformerB_3_01MatrixLayout_1_1kRowMajor_00_01Iterator___01_4.html" target="_self">IgemmTransformerB< MatrixLayout::kRowMajor, Iterator_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_80_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_80_" class="arrow" onclick="toggleFolder('0_0_80_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1LinearScaling.html" target="_self">LinearScaling</a></td><td class="desc">Functor to compute linear combination of fragments </td></tr>
|
||||
<tr id="row_0_0_80_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1LinearScaling_1_1Params.html" target="_self">Params</a></td><td class="desc">The parameters </td></tr>
|
||||
<tr id="row_0_0_81_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1ProjectOperand.html" target="_self">ProjectOperand</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_82_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1ProjectOperand_3_01GemmOperand_1_1kA_00_01Kstrided_01_4.html" target="_self">ProjectOperand< GemmOperand::kA, Kstrided ></a></td><td class="desc">Project A operand - (0, K, M) </td></tr>
|
||||
<tr id="row_0_0_83_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1ProjectOperand_3_01GemmOperand_1_1kB_00_01Kstrided_01_4.html" target="_self">ProjectOperand< GemmOperand::kB, Kstrided ></a></td><td class="desc">Project B operand - (0, K, N) </td></tr>
|
||||
<tr id="row_0_0_84_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1ProjectOperand_3_01GemmOperand_1_1kC_00_01true_01_4.html" target="_self">ProjectOperand< GemmOperand::kC, true ></a></td><td class="desc">Project C operand - (0, N, M) </td></tr>
|
||||
<tr id="row_0_0_85_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1ProjectOperand_3_01GemmOperand_1_1kD_00_01true_01_4.html" target="_self">ProjectOperand< GemmOperand::kD, true ></a></td><td class="desc">Project D operand - (0, N, M) </td></tr>
|
||||
<tr id="row_0_0_86_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1ReshapeThreads.html" target="_self">ReshapeThreads</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_87_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1ReshapeThreads_3_01Tile___00_01Threads___00_01true_01_4.html" target="_self">ReshapeThreads< Tile_, Threads_, true ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_88_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1SgemmConfig.html" target="_self">SgemmConfig</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_89_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1SgemmTraits.html" target="_self">SgemmTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_90_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_90_" class="arrow" onclick="toggleFolder('0_0_90_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1SharedLoadStream.html" target="_self">SharedLoadStream</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_90_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1SharedLoadStream_1_1Params.html" target="_self">Params</a></td><td class="desc">The params </td></tr>
|
||||
<tr id="row_0_0_91_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1SimplifiedGemmEpilogueTraits.html" target="_self">SimplifiedGemmEpilogueTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_92_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1SimplifiedGemmTraits.html" target="_self">SimplifiedGemmTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_93_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1SimplifiedGemmTraitsHelper.html" target="_self">SimplifiedGemmTraitsHelper</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_94_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1ThreadMultiplyAdd.html" target="_self">ThreadMultiplyAdd</a></td><td class="desc">Template performing matrix multiply-add operation within a thread </td></tr>
|
||||
<tr id="row_0_0_95_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1ThreadMultiplyAdd_3_01AccumulatorsPerThread___00_01ThreadsPerWarp___00_01half_00_01half_00_01half_01_4.html" target="_self">ThreadMultiplyAdd< AccumulatorsPerThread_, ThreadsPerWarp_, half, half, half ></a></td><td class="desc">Template performing matrix multiply-add operation within a thread </td></tr>
|
||||
<tr id="row_0_0_96_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1ThreadMultiplyAdd_3_01AccumulatorsPerThread___00_01ThreadsPerWarp___00_f5353db950bbf0023472029cac4814b6.html" target="_self">ThreadMultiplyAdd< AccumulatorsPerThread_, ThreadsPerWarp_, int8_t, int8_t, int ></a></td><td class="desc">Template performing matrix multiply-add operation within a thread </td></tr>
|
||||
<tr id="row_0_0_97_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_97_" class="arrow" onclick="toggleFolder('0_0_97_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1WmmaGemmGlobalIteratorCd.html" target="_self">WmmaGemmGlobalIteratorCd</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_97_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1WmmaGemmGlobalIteratorCd_1_1Params.html" target="_self">Params</a></td><td class="desc">The params </td></tr>
|
||||
<tr id="row_0_0_98_" class="even" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_0_98_" class="arrow" onclick="toggleFolder('0_0_98_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1WmmaGemmGlobalIteratorCdTraits.html" target="_self">WmmaGemmGlobalIteratorCdTraits</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_0_98_0_" class="even" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1gemm_1_1WmmaGemmGlobalIteratorCdTraits_1_1ThreadOffset.html" target="_self">ThreadOffset</a></td><td class="desc">Computes the thread offset in (H, W) based on thread ID </td></tr>
|
||||
<tr id="row_0_1_" class="even"><td class="entry"><span style="width:16px;display:inline-block;"> </span><span id="arr_0_1_" class="arrow" onclick="toggleFolder('0_1_')">▶</span><span class="icona"><span class="icon">N</span></span><a class="el" href="namespacecutlass_1_1platform.html" target="_self">platform</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_0_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1aligned__chunk.html" target="_self">aligned_chunk</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_1_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1aligned__storage.html" target="_self">aligned_storage</a></td><td class="desc">Std::aligned_storage </td></tr>
|
||||
<tr id="row_0_1_2_" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_1_2_" class="arrow" onclick="toggleFolder('0_1_2_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1alignment__of.html" target="_self">alignment_of</a></td><td class="desc">Std::alignment_of </td></tr>
|
||||
<tr id="row_0_1_2_0_" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1alignment__of_1_1pad.html" target="_self">pad</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_3_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1alignment__of_3_01const_01value__t_01_4.html" target="_self">alignment_of< const value_t ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_4_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1alignment__of_3_01const_01volatile_01value__t_01_4.html" target="_self">alignment_of< const volatile value_t ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_5_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1alignment__of_3_01double2_01_4.html" target="_self">alignment_of< double2 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_6_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1alignment__of_3_01double4_01_4.html" target="_self">alignment_of< double4 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_7_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1alignment__of_3_01float4_01_4.html" target="_self">alignment_of< float4 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_8_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1alignment__of_3_01int4_01_4.html" target="_self">alignment_of< int4 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_9_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1alignment__of_3_01long4_01_4.html" target="_self">alignment_of< long4 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_10_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1alignment__of_3_01longlong2_01_4.html" target="_self">alignment_of< longlong2 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_11_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1alignment__of_3_01longlong4_01_4.html" target="_self">alignment_of< longlong4 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_12_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1alignment__of_3_01uint4_01_4.html" target="_self">alignment_of< uint4 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_13_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1alignment__of_3_01ulong4_01_4.html" target="_self">alignment_of< ulong4 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_14_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1alignment__of_3_01ulonglong2_01_4.html" target="_self">alignment_of< ulonglong2 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_15_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1alignment__of_3_01ulonglong4_01_4.html" target="_self">alignment_of< ulonglong4 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_16_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1alignment__of_3_01volatile_01value__t_01_4.html" target="_self">alignment_of< volatile value_t ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_17_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1bool__constant.html" target="_self">bool_constant</a></td><td class="desc">Std::bool_constant </td></tr>
|
||||
<tr id="row_0_1_18_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1conditional.html" target="_self">conditional</a></td><td class="desc">Std::conditional (true specialization) </td></tr>
|
||||
<tr id="row_0_1_19_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1conditional_3_01false_00_01T_00_01F_01_4.html" target="_self">conditional< false, T, F ></a></td><td class="desc">Std::conditional (false specialization) </td></tr>
|
||||
<tr id="row_0_1_20_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1default__delete.html" target="_self">default_delete</a></td><td class="desc">Default deleter </td></tr>
|
||||
<tr id="row_0_1_21_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1default__delete_3_01T[]_4.html" target="_self">default_delete< T[]></a></td><td class="desc">Partial specialization for deleting array types </td></tr>
|
||||
<tr id="row_0_1_22_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1enable__if.html" target="_self">enable_if</a></td><td class="desc">Std::enable_if (true specialization) </td></tr>
|
||||
<tr id="row_0_1_23_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1enable__if_3_01false_00_01T_01_4.html" target="_self">enable_if< false, T ></a></td><td class="desc">Std::enable_if (false specialization) </td></tr>
|
||||
<tr id="row_0_1_24_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1greater.html" target="_self">greater</a></td><td class="desc">Std::greater </td></tr>
|
||||
<tr id="row_0_1_25_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1integral__constant.html" target="_self">integral_constant</a></td><td class="desc">Std::integral_constant </td></tr>
|
||||
<tr id="row_0_1_26_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__arithmetic.html" target="_self">is_arithmetic</a></td><td class="desc">Std::is_arithmetic </td></tr>
|
||||
<tr id="row_0_1_27_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__base__of.html" target="_self">is_base_of</a></td><td class="desc">Std::is_base_of </td></tr>
|
||||
<tr id="row_0_1_28_" style="display:none;"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span id="arr_0_1_28_" class="arrow" onclick="toggleFolder('0_1_28_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__base__of__helper.html" target="_self">is_base_of_helper</a></td><td class="desc">Helper for std::is_base_of </td></tr>
|
||||
<tr id="row_0_1_28_0_" style="display:none;"><td class="entry"><span style="width:64px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__base__of__helper_1_1dummy.html" target="_self">dummy</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_29_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__floating__point.html" target="_self">is_floating_point</a></td><td class="desc">Std::is_floating_point </td></tr>
|
||||
<tr id="row_0_1_30_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__fundamental.html" target="_self">is_fundamental</a></td><td class="desc">Std::is_fundamental </td></tr>
|
||||
<tr id="row_0_1_31_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__integral.html" target="_self">is_integral</a></td><td class="desc">Std::is_integral </td></tr>
|
||||
<tr id="row_0_1_32_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__integral_3_01char_01_4.html" target="_self">is_integral< char ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_33_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__integral_3_01const_01T_01_4.html" target="_self">is_integral< const T ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_34_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__integral_3_01const_01volatile_01T_01_4.html" target="_self">is_integral< const volatile T ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_35_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__integral_3_01int_01_4.html" target="_self">is_integral< int ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_36_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__integral_3_01long_01_4.html" target="_self">is_integral< long ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_37_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__integral_3_01long_01long_01_4.html" target="_self">is_integral< long long ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_38_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__integral_3_01short_01_4.html" target="_self">is_integral< short ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_39_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__integral_3_01signed_01char_01_4.html" target="_self">is_integral< signed char ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_40_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__integral_3_01unsigned_01char_01_4.html" target="_self">is_integral< unsigned char ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_41_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__integral_3_01unsigned_01int_01_4.html" target="_self">is_integral< unsigned int ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_42_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__integral_3_01unsigned_01long_01_4.html" target="_self">is_integral< unsigned long ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_43_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__integral_3_01unsigned_01long_01long_01_4.html" target="_self">is_integral< unsigned long long ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_44_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__integral_3_01unsigned_01short_01_4.html" target="_self">is_integral< unsigned short ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_45_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__integral_3_01volatile_01T_01_4.html" target="_self">is_integral< volatile T ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_46_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__pointer.html" target="_self">is_pointer</a></td><td class="desc">Std::is_pointer </td></tr>
|
||||
<tr id="row_0_1_47_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__pointer__helper.html" target="_self">is_pointer_helper</a></td><td class="desc">Helper for std::is_pointer (false specialization) </td></tr>
|
||||
<tr id="row_0_1_48_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__pointer__helper_3_01T_01_5_01_4.html" target="_self">is_pointer_helper< T * ></a></td><td class="desc">Helper for std::is_pointer (true specialization) </td></tr>
|
||||
<tr id="row_0_1_49_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__same.html" target="_self">is_same</a></td><td class="desc">Std::is_same (false specialization) </td></tr>
|
||||
<tr id="row_0_1_50_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__same_3_01A_00_01A_01_4.html" target="_self">is_same< A, A ></a></td><td class="desc">Std::is_same (true specialization) </td></tr>
|
||||
<tr id="row_0_1_51_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__trivially__copyable.html" target="_self">is_trivially_copyable</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_52_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__void.html" target="_self">is_void</a></td><td class="desc">Std::is_void </td></tr>
|
||||
<tr id="row_0_1_53_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__volatile.html" target="_self">is_volatile</a></td><td class="desc">Std::is_volatile </td></tr>
|
||||
<tr id="row_0_1_54_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1is__volatile_3_01volatile_01T_01_4.html" target="_self">is_volatile< volatile T ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_1_55_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1less.html" target="_self">less</a></td><td class="desc">Std::less </td></tr>
|
||||
<tr id="row_0_1_56_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1nullptr__t.html" target="_self">nullptr_t</a></td><td class="desc">Std::nullptr_t </td></tr>
|
||||
<tr id="row_0_1_57_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1plus.html" target="_self">plus</a></td><td class="desc">Platform::plus </td></tr>
|
||||
<tr id="row_0_1_58_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1remove__const.html" target="_self">remove_const</a></td><td class="desc">Std::remove_const (non-const specialization) </td></tr>
|
||||
<tr id="row_0_1_59_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1remove__const_3_01const_01T_01_4.html" target="_self">remove_const< const T ></a></td><td class="desc">Std::remove_const (const specialization) </td></tr>
|
||||
<tr id="row_0_1_60_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1remove__cv.html" target="_self">remove_cv</a></td><td class="desc">Std::remove_cv </td></tr>
|
||||
<tr id="row_0_1_61_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1remove__volatile.html" target="_self">remove_volatile</a></td><td class="desc">Std::remove_volatile (non-volatile specialization) </td></tr>
|
||||
<tr id="row_0_1_62_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1platform_1_1remove__volatile_3_01volatile_01T_01_4.html" target="_self">remove_volatile< volatile T ></a></td><td class="desc">Std::remove_volatile (volatile specialization) </td></tr>
|
||||
<tr id="row_0_1_63_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="classcutlass_1_1platform_1_1unique__ptr.html" target="_self">unique_ptr</a></td><td class="desc">Std::unique_ptr </td></tr>
|
||||
<tr id="row_0_2_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1AlignedStruct.html" target="_self">AlignedStruct</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_3_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ComputeOffsetFromShape.html" target="_self">ComputeOffsetFromShape</a></td><td class="desc">Compute the offset for the given coordinates in a cube </td></tr>
|
||||
<tr id="row_0_4_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ComputeOffsetFromShape_3_01Shape_3_011_00_01kSh___00_01kSw___00_011_01_4_01_4.html" target="_self">ComputeOffsetFromShape< Shape< 1, kSh_, kSw_, 1 > ></a></td><td class="desc">Compute the offset for the given coordinates in a cube with one channel and a depth of 1 </td></tr>
|
||||
<tr id="row_0_5_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ComputeOffsetFromShape_3_01Shape_3_011_00_01kSh___00_01kSw___00_01kSc___01_4_01_4.html" target="_self">ComputeOffsetFromShape< Shape< 1, kSh_, kSw_, kSc_ > ></a></td><td class="desc">Compute the offset for the given coordinates in a cube with a depth of 1 </td></tr>
|
||||
<tr id="row_0_6_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ComputeOffsetFromStrides.html" target="_self">ComputeOffsetFromStrides</a></td><td class="desc">Compute the offset for the given coordinates in a cube </td></tr>
|
||||
<tr id="row_0_7_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ComputeOffsetFromStrides_3_01Shape_3_011_00_01S__h___00_01S__w___00_011_01_4_01_4.html" target="_self">ComputeOffsetFromStrides< Shape< 1, S_h_, S_w_, 1 > ></a></td><td class="desc">Compute the offset for the given coordinates in a cube with one channel and a depth of 1 </td></tr>
|
||||
<tr id="row_0_8_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ComputeOffsetFromStrides_3_01Shape_3_011_00_01S__h___00_01S__w___00_01S__c___01_4_01_4.html" target="_self">ComputeOffsetFromStrides< Shape< 1, S_h_, S_w_, S_c_ > ></a></td><td class="desc">Compute the offset for the given coordinates in a cube with a depth of 1 </td></tr>
|
||||
<tr id="row_0_9_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ComputeThreadOffsetFromStrides.html" target="_self">ComputeThreadOffsetFromStrides</a></td><td class="desc">Decompose threadId.x into coordinate of a cube whose dimensions are specified by Threads_. Afterwards compute the offset of those coordinates using Strides_ </td></tr>
|
||||
<tr id="row_0_10_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ComputeThreadOffsetFromStrides_3_01Shape_3_011_00_01T__h___00_01T__w___00_011_0e75281d7e02fa191f5d498e10e25dc1b.html" target="_self">ComputeThreadOffsetFromStrides< Shape< 1, T_h_, T_w_, 1 >, Shape< 1, S_h_, S_w_, 1 > ></a></td><td class="desc">Specialization for D=1 and C=1 </td></tr>
|
||||
<tr id="row_0_11_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ComputeThreadOffsetFromStrides_3_01Shape_3_011_00_01T__h___00_01T__w___00_01T__dd54c41f6edb97d3c208cb7c6fe4ab9b.html" target="_self">ComputeThreadOffsetFromStrides< Shape< 1, T_h_, T_w_, T_c_ >, Shape< 1, S_h_, S_w_, S_c_ > ></a></td><td class="desc">Specialization for D=1 </td></tr>
|
||||
<tr id="row_0_12_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ConstPredicateTileAdapter.html" target="_self">ConstPredicateTileAdapter</a></td><td class="desc">Adapter to enable random access to predicates via logical coordinate within a tile </td></tr>
|
||||
<tr id="row_0_13_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Convert.html" target="_self">Convert</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_14_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Convert_3_01Fragment_3_01InputScalar___00_01kScalars___01_4_00_01Fragment_3_01Ofca5985d18bcb54bc1f49355f3cee121.html" target="_self">Convert< Fragment< InputScalar_, kScalars_ >, Fragment< OutputScalar_, kScalars_ > ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_15_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Coord.html" target="_self">Coord</a></td><td class="desc">Statically-sized array specifying Coords within a tensor </td></tr>
|
||||
<tr id="row_0_16_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Copy.html" target="_self">Copy</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_17_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1divide__assert.html" target="_self">divide_assert</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_18_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Extent.html" target="_self">Extent</a></td><td class="desc">Returns the extent of a scalar or vector </td></tr>
|
||||
<tr id="row_0_19_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Extent_3_01Vector_3_01T_00_01Lanes_01_4_01_4.html" target="_self">Extent< Vector< T, Lanes > ></a></td><td class="desc">Returns the number of lanes of a vector if need be </td></tr>
|
||||
<tr id="row_0_20_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Extent_3_01Vector_3_01T_00_01Lanes_01_4_01const_01_4.html" target="_self">Extent< Vector< T, Lanes > const ></a></td><td class="desc">Returns the number of lanes of a vector if need be </td></tr>
|
||||
<tr id="row_0_21_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Fragment.html" target="_self">Fragment</a></td><td class="desc">A template defining <a class="el" href="group__fragment__concept.html">Fragment Concept</a> </td></tr>
|
||||
<tr id="row_0_22_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1FragmentConstIterator.html" target="_self">FragmentConstIterator</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_23_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1FragmentIterator.html" target="_self">FragmentIterator</a></td><td class="desc">A template defining <a class="el" href="group__fragment__iterator__concept.html">Fragment Iterator Concept</a> </td></tr>
|
||||
<tr id="row_0_24_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1FragmentLoad.html" target="_self">FragmentLoad</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_25_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1FragmentLoad_3_01IteratorFragment_1_1kScalar_00_01kAccessSize_00_01Scalar___00_9bf6f8f94e2cd7f3702b853d418a9863.html" target="_self">FragmentLoad< IteratorFragment::kScalar, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_26_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1FragmentLoad_3_01IteratorFragment_1_1kWmmaMatrix_00_01kAccessSize_00_01Scalar__a157bdca477e8efca5bc9cda0db6db8e.html" target="_self">FragmentLoad< IteratorFragment::kWmmaMatrix, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_27_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1FragmentStore.html" target="_self">FragmentStore</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_28_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1FragmentStore_3_01IteratorFragment_1_1kScalar_00_01kAccessSize_00_01Scalar___0087787c90510d0c4c07703b5a90c263de.html" target="_self">FragmentStore< IteratorFragment::kScalar, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_29_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1FragmentStore_3_01IteratorFragment_1_1kWmmaMatrix_00_01kAccessSize_00_01Scalar_00c2299561c3ffbb17f8afc6add32eba.html" target="_self">FragmentStore< IteratorFragment::kWmmaMatrix, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_30_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1GemmOperand.html" target="_self">GemmOperand</a></td><td class="desc">Gemm operand - D = A * B + C </td></tr>
|
||||
<tr id="row_0_31_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Identity.html" target="_self">Identity</a></td><td class="desc">Describes identity elements </td></tr>
|
||||
<tr id="row_0_32_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1is__pow2.html" target="_self">is_pow2</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_33_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1IteratorAdvance.html" target="_self">IteratorAdvance</a></td><td class="desc">Specifies dimension in which post-increment accesses advance </td></tr>
|
||||
<tr id="row_0_34_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1IteratorFragment.html" target="_self">IteratorFragment</a></td><td class="desc">Specifies whether iterator storage fragment consists of Scalar values or WMMA matrix </td></tr>
|
||||
<tr id="row_0_35_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Load.html" target="_self">Load</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_36_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Load_3_01double_00_012_00_01Memory___00_01true_00_0116_01_4.html" target="_self">Load< double, 2, Memory_, true, 16 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_37_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Load_3_01Scalar___00_01Lanes___00_01Memory___00_01true_00_0116_01_4.html" target="_self">Load< Scalar_, Lanes_, Memory_, true, 16 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_38_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Load_3_01Scalar___00_01Lanes___00_01Memory___00_01true_00_014_01_4.html" target="_self">Load< Scalar_, Lanes_, Memory_, true, 4 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_39_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Load_3_01Scalar___00_01Lanes___00_01Memory___00_01true_00_018_01_4.html" target="_self">Load< Scalar_, Lanes_, Memory_, true, 8 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_40_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1log2__down.html" target="_self">log2_down</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_41_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1log2__down_3_01N_00_011_00_01Count_01_4.html" target="_self">log2_down< N, 1, Count ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_42_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1log2__up.html" target="_self">log2_up</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_43_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1log2__up_3_01N_00_011_00_01Count_01_4.html" target="_self">log2_up< N, 1, Count ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_44_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1MatrixLayout.html" target="_self">MatrixLayout</a></td><td class="desc">Describes layouts of matrices </td></tr>
|
||||
<tr id="row_0_45_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1MemorySpace.html" target="_self">MemorySpace</a></td><td class="desc">Enum to specify which memory space data resides in </td></tr>
|
||||
<tr id="row_0_46_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1PredicateTileAdapter.html" target="_self">PredicateTileAdapter</a></td><td class="desc">Adapter to enable random access to predicates via logical coordinate within a tile </td></tr>
|
||||
<tr id="row_0_47_" class="even"><td class="entry"><span style="width:16px;display:inline-block;"> </span><span id="arr_0_47_" class="arrow" onclick="toggleFolder('0_47_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1PredicateVector.html" target="_self">PredicateVector</a></td><td class="desc">Statically sized array of bits implementing </td></tr>
|
||||
<tr id="row_0_47_0_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html" target="_self">ConstIterator</a></td><td class="desc">A const iterator implementing <a class="el" href="group__predicate__iterator__concept.html">Predicate Iterator Concept</a> enabling sequential read-only access to prediactes </td></tr>
|
||||
<tr id="row_0_47_1_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html" target="_self">Iterator</a></td><td class="desc">An iterator implementing <a class="el" href="group__predicate__iterator__concept.html">Predicate Iterator Concept</a> enabling sequential read and write access to predicates </td></tr>
|
||||
<tr id="row_0_47_2_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1PredicateVector_1_1TrivialIterator.html" target="_self">TrivialIterator</a></td><td class="desc"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html" title="An iterator implementing Predicate Iterator Concept enabling sequential read and write access to pred...">Iterator</a> that always returns true </td></tr>
|
||||
<tr id="row_0_48_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ReshapeTile.html" target="_self">ReshapeTile</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_49_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ReshapeTile_3_01Tile___00_01kAccessSize___00_01true_01_4.html" target="_self">ReshapeTile< Tile_, kAccessSize_, true ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_50_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Shape.html" target="_self">Shape</a></td><td class="desc">A <a class="el" href="structcutlass_1_1Shape.html" title="A Shape implementing Layout Concept describing the dimensions of a cube. ">Shape</a> implementing <a class="el" href="group__layout__concept.html">Layout Concept</a> describing the dimensions of a cube </td></tr>
|
||||
<tr id="row_0_51_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ShapeAdd.html" target="_self">ShapeAdd</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_52_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ShapeCount.html" target="_self">ShapeCount</a></td><td class="desc">Compute derived counted of a <a class="el" href="group__layout__concept.html">Layout Concept</a> based class </td></tr>
|
||||
<tr id="row_0_53_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ShapeDiv.html" target="_self">ShapeDiv</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_54_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ShapeMax.html" target="_self">ShapeMax</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_55_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ShapeMin.html" target="_self">ShapeMin</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_56_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ShapeMul.html" target="_self">ShapeMul</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_57_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ShapeScale.html" target="_self">ShapeScale</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_58_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ShapeStrides.html" target="_self">ShapeStrides</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_59_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1ShapeSub.html" target="_self">ShapeSub</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_60_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1sqrt__est.html" target="_self">sqrt_est</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_61_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1StorageType.html" target="_self">StorageType</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_62_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1StorageType_3_011_01_4.html" target="_self">StorageType< 1 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_63_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1StorageType_3_012_01_4.html" target="_self">StorageType< 2 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_64_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1StorageType_3_014_01_4.html" target="_self">StorageType< 4 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_65_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Store.html" target="_self">Store</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_66_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Store_3_01double_00_012_00_01Memory___00_01true_00_0116_01_4.html" target="_self">Store< double, 2, Memory_, true, 16 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_67_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Store_3_01Scalar___00_01Lanes___00_01Memory___00_01true_00_0116_01_4.html" target="_self">Store< Scalar_, Lanes_, Memory_, true, 16 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_68_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Store_3_01Scalar___00_01Lanes___00_01Memory___00_01true_00_014_01_4.html" target="_self">Store< Scalar_, Lanes_, Memory_, true, 4 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_69_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Store_3_01Scalar___00_01Lanes___00_01Memory___00_01true_00_018_01_4.html" target="_self">Store< Scalar_, Lanes_, Memory_, true, 8 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_70_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="classcutlass_1_1TensorRef.html" target="_self">TensorRef</a></td><td class="desc">Structure modeling a pointer and stride into a tensor </td></tr>
|
||||
<tr id="row_0_71_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="classcutlass_1_1TensorView.html" target="_self">TensorView</a></td><td class="desc">Host-side reference implementation of tensor operations </td></tr>
|
||||
<tr id="row_0_72_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1TiledThreadOffset.html" target="_self">TiledThreadOffset</a></td><td class="desc">Basic thread offset function computed from a thread shape </td></tr>
|
||||
<tr id="row_0_73_" class="even"><td class="entry"><span style="width:16px;display:inline-block;"> </span><span id="arr_0_73_" class="arrow" onclick="toggleFolder('0_73_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1TileIteratorBase.html" target="_self">TileIteratorBase</a></td><td class="desc">Iterator for accessing a stripmined tile in memory </td></tr>
|
||||
<tr id="row_0_73_0_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1TileIteratorBase_1_1Params.html" target="_self">Params</a></td><td class="desc">Parameters to the iterator </td></tr>
|
||||
<tr id="row_0_74_"><td class="entry"><span style="width:16px;display:inline-block;"> </span><span id="arr_0_74_" class="arrow" onclick="toggleFolder('0_74_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1TileLoadIterator.html" target="_self">TileLoadIterator</a></td><td class="desc">An iterator implementing <a class="el" href="group__tile__load__iterator__concept.html">Tile Load Iterator Concept</a> for loading a tile from memory </td></tr>
|
||||
<tr id="row_0_74_0_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1TileLoadIterator_1_1Params.html" target="_self">Params</a></td><td class="desc">Parameters </td></tr>
|
||||
<tr id="row_0_75_" class="even"><td class="entry"><span style="width:16px;display:inline-block;"> </span><span id="arr_0_75_" class="arrow" onclick="toggleFolder('0_75_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1TileStoreIterator.html" target="_self">TileStoreIterator</a></td><td class="desc">An iterator implementing <a class="el" href="group__tile__store__iterator__concept.html">Tile Store Iterator Concept</a> for storing a tile to memory </td></tr>
|
||||
<tr id="row_0_75_0_" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1TileStoreIterator_1_1Params.html" target="_self">Params</a></td><td class="desc">Parameters </td></tr>
|
||||
<tr id="row_0_76_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1TileTraits.html" target="_self">TileTraits</a></td><td class="desc">A template defining <a class="el" href="group__tile__traits__concept.html">Tile Traits Concept</a> </td></tr>
|
||||
<tr id="row_0_77_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1TileTraitsContiguousMajor.html" target="_self">TileTraitsContiguousMajor</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_78_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1TileTraitsStandard.html" target="_self">TileTraitsStandard</a></td><td class="desc">Chooses 'best' shape to enable warp raking along contiguous dimension if possible </td></tr>
|
||||
<tr id="row_0_79_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1TileTraitsStrideMajor.html" target="_self">TileTraitsStrideMajor</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_80_"><td class="entry"><span style="width:16px;display:inline-block;"> </span><span id="arr_0_80_" class="arrow" onclick="toggleFolder('0_80_')">▶</span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1TileTraitsWarpRake.html" target="_self">TileTraitsWarpRake</a></td><td class="desc">Tiling in which warps rake across the contiguous dimension </td></tr>
|
||||
<tr id="row_0_80_0_" class="even" style="display:none;"><td class="entry"><span style="width:48px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1TileTraitsWarpRake_1_1ThreadOffset.html" target="_self">ThreadOffset</a></td><td class="desc">Computes the thread offset in (H, W) based on thread ID </td></tr>
|
||||
<tr id="row_0_81_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1TrivialPredicateTileAdapter.html" target="_self">TrivialPredicateTileAdapter</a></td><td class="desc">Always returns true predicate </td></tr>
|
||||
<tr id="row_0_82_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="unioncutlass_1_1Vector.html" target="_self">Vector</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_83_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="unioncutlass_1_1Vector_3_01half_00_01kLanes___01_4.html" target="_self">Vector< half, kLanes_ ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_84_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Vectorize.html" target="_self">Vectorize</a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_85_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1Vectorize_3_01Element___00_011_01_4.html" target="_self">Vectorize< Element_, 1 ></a></td><td class="desc"></td></tr>
|
||||
<tr id="row_0_86_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1VectorTraits.html" target="_self">VectorTraits</a></td><td class="desc">Traits describing properties of vectors and scalar-as-vectors </td></tr>
|
||||
<tr id="row_0_87_" class="even"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1VectorTraits_3_01Vector_3_01T_00_01Lanes_01_4_01_4.html" target="_self">VectorTraits< Vector< T, Lanes > ></a></td><td class="desc">Partial specialization for actual <a class="el" href="unioncutlass_1_1Vector.html">cutlass::Vector</a> </td></tr>
|
||||
<tr id="row_0_88_"><td class="entry"><span style="width:32px;display:inline-block;"> </span><span class="icona"><span class="icon">C</span></span><a class="el" href="structcutlass_1_1VectorTraits_3_01Vector_3_01T_00_01Lanes_01_4_01const_01_4.html" target="_self">VectorTraits< Vector< T, Lanes > const ></a></td><td class="desc">Partial specialization for actual <a class="el" href="unioncutlass_1_1Vector.html">cutlass::Vector</a> </td></tr>
|
||||
</table>
|
||||
</div><!-- directory -->
|
||||
</div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
<hr class="footer"/><address class="footer"><small>
|
||||
Generated on Thu May 3 2018 16:36:15 for Cutlass by  <a href="http://www.doxygen.org/index.html">
|
||||
<img class="footer" src="doxygen.png" alt="doxygen"/>
|
||||
</a> 1.8.14
|
||||
</small></address>
|
||||
</body>
|
||||
</html>
|
||||
BIN
docs/bc_s.png
Normal file
BIN
docs/bc_s.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 682 B |
BIN
docs/bdwn.png
Normal file
BIN
docs/bdwn.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 147 B |
@ -0,0 +1,99 @@
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head>
|
||||
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
|
||||
<meta name="generator" content="Doxygen 1.8.14"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<title>Cutlass: Member List</title>
|
||||
<link href="tabs.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="jquery.js"></script>
|
||||
<script type="text/javascript" src="dynsections.js"></script>
|
||||
<link href="search/search.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="search/searchdata.js"></script>
|
||||
<script type="text/javascript" src="search/search.js"></script>
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
extensions: ["tex2jax.js"],
|
||||
jax: ["input/TeX","output/HTML-CSS"],
|
||||
});
|
||||
</script><script type="text/javascript" async src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
|
||||
<link href="doxygen.css" rel="stylesheet" type="text/css" />
|
||||
</head>
|
||||
<body>
|
||||
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
|
||||
<div id="titlearea">
|
||||
<table cellspacing="0" cellpadding="0">
|
||||
<tbody>
|
||||
<tr style="height: 56px;">
|
||||
<td id="projectalign" style="padding-left: 0.5em;">
|
||||
<div id="projectname">Cutlass
|
||||
</div>
|
||||
<div id="projectbrief">CUDA Templates for Linear Algebra Subroutines and Solvers</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<!-- end header part -->
|
||||
<!-- Generated by Doxygen 1.8.14 -->
|
||||
<script type="text/javascript">
|
||||
/* @license magnet:?xt=urn:btih:cf05388f2679ee054f2beb29a391d25f4e673ac3&dn=gpl-2.0.txt GPL-v2 */
|
||||
var searchBox = new SearchBox("searchBox", "search",false,'Search');
|
||||
/* @license-end */
|
||||
</script>
|
||||
<script type="text/javascript" src="menudata.js"></script>
|
||||
<script type="text/javascript" src="menu.js"></script>
|
||||
<script type="text/javascript">
|
||||
/* @license magnet:?xt=urn:btih:cf05388f2679ee054f2beb29a391d25f4e673ac3&dn=gpl-2.0.txt GPL-v2 */
|
||||
$(function() {
|
||||
initMenu('',true,false,'search.php','Search');
|
||||
$(document).ready(function() { init_search(); });
|
||||
});
|
||||
/* @license-end */</script>
|
||||
<div id="main-nav"></div>
|
||||
<!-- window showing the filter options -->
|
||||
<div id="MSearchSelectWindow"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
onkeydown="return searchBox.OnSearchSelectKey(event)">
|
||||
</div>
|
||||
|
||||
<!-- iframe showing the search results (closed by default) -->
|
||||
<div id="MSearchResultsWindow">
|
||||
<iframe src="javascript:void(0)" frameborder="0"
|
||||
name="MSearchResults" id="MSearchResults">
|
||||
</iframe>
|
||||
</div>
|
||||
|
||||
<div id="nav-path" class="navpath">
|
||||
<ul>
|
||||
<li class="navelem"><a class="el" href="namespacecutlass.html">cutlass</a></li><li class="navelem"><a class="el" href="structcutlass_1_1PredicateVector.html">PredicateVector</a></li><li class="navelem"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">ConstIterator</a></li> </ul>
|
||||
</div>
|
||||
</div><!-- top -->
|
||||
<div class="header">
|
||||
<div class="headertitle">
|
||||
<div class="title">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator Member List</div> </div>
|
||||
</div><!--header-->
|
||||
<div class="contents">
|
||||
|
||||
<p>This is the complete list of members for <a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator</a>, including all inherited members.</p>
|
||||
<table class="directory">
|
||||
<tr class="even"><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#a1216aab9c567ec0d4232019008ef3ea7">ConstIterator</a>(ConstIterator const &it)</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#a590e4f4533c87162c0b79e8d876a8fda">ConstIterator</a>(PredicateVector const &_vec, int _start=0)</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr class="even"><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#a3d06715a77740034697686a7977cb685">operator!=</a>(ConstIterator const &it) const</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#abbc2bceb6cf8d7f168b8a00eb48c0946">operator*</a>() const</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr class="even"><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#a10ee4bb2f206432aa5ee1a83cb046b70">operator++</a>()</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#a977a99af3166a58d5bc5a613a1abe7d5">operator++</a>(int)</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr class="even"><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#a2763012a9284e97650b14e20c5668286">operator--</a>()</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#a2910a714d34a688b8ea560ea2933436b">operator--</a>(int)</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr class="even"><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#aa2d03d88ac23051803d010f78157c357">operator==</a>(ConstIterator const &it) const</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
</table></div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
<hr class="footer"/><address class="footer"><small>
|
||||
Generated on Thu May 3 2018 16:36:13 for Cutlass by  <a href="http://www.doxygen.org/index.html">
|
||||
<img class="footer" src="doxygen.png" alt="doxygen"/>
|
||||
</a> 1.8.14
|
||||
</small></address>
|
||||
</body>
|
||||
</html>
|
||||
389
docs/classcutlass_1_1PredicateVector_1_1ConstIterator.html
Normal file
389
docs/classcutlass_1_1PredicateVector_1_1ConstIterator.html
Normal file
@ -0,0 +1,389 @@
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head>
|
||||
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
|
||||
<meta name="generator" content="Doxygen 1.8.14"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<title>Cutlass: cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator Class Reference</title>
|
||||
<link href="tabs.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="jquery.js"></script>
|
||||
<script type="text/javascript" src="dynsections.js"></script>
|
||||
<link href="search/search.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="search/searchdata.js"></script>
|
||||
<script type="text/javascript" src="search/search.js"></script>
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
extensions: ["tex2jax.js"],
|
||||
jax: ["input/TeX","output/HTML-CSS"],
|
||||
});
|
||||
</script><script type="text/javascript" async src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
|
||||
<link href="doxygen.css" rel="stylesheet" type="text/css" />
|
||||
</head>
|
||||
<body>
|
||||
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
|
||||
<div id="titlearea">
|
||||
<table cellspacing="0" cellpadding="0">
|
||||
<tbody>
|
||||
<tr style="height: 56px;">
|
||||
<td id="projectalign" style="padding-left: 0.5em;">
|
||||
<div id="projectname">Cutlass
|
||||
</div>
|
||||
<div id="projectbrief">CUDA Templates for Linear Algebra Subroutines and Solvers</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<!-- end header part -->
|
||||
<!-- Generated by Doxygen 1.8.14 -->
|
||||
<script type="text/javascript">
|
||||
/* @license magnet:?xt=urn:btih:cf05388f2679ee054f2beb29a391d25f4e673ac3&dn=gpl-2.0.txt GPL-v2 */
|
||||
var searchBox = new SearchBox("searchBox", "search",false,'Search');
|
||||
/* @license-end */
|
||||
</script>
|
||||
<script type="text/javascript" src="menudata.js"></script>
|
||||
<script type="text/javascript" src="menu.js"></script>
|
||||
<script type="text/javascript">
|
||||
/* @license magnet:?xt=urn:btih:cf05388f2679ee054f2beb29a391d25f4e673ac3&dn=gpl-2.0.txt GPL-v2 */
|
||||
$(function() {
|
||||
initMenu('',true,false,'search.php','Search');
|
||||
$(document).ready(function() { init_search(); });
|
||||
});
|
||||
/* @license-end */</script>
|
||||
<div id="main-nav"></div>
|
||||
<!-- window showing the filter options -->
|
||||
<div id="MSearchSelectWindow"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
onkeydown="return searchBox.OnSearchSelectKey(event)">
|
||||
</div>
|
||||
|
||||
<!-- iframe showing the search results (closed by default) -->
|
||||
<div id="MSearchResultsWindow">
|
||||
<iframe src="javascript:void(0)" frameborder="0"
|
||||
name="MSearchResults" id="MSearchResults">
|
||||
</iframe>
|
||||
</div>
|
||||
|
||||
<div id="nav-path" class="navpath">
|
||||
<ul>
|
||||
<li class="navelem"><a class="el" href="namespacecutlass.html">cutlass</a></li><li class="navelem"><a class="el" href="structcutlass_1_1PredicateVector.html">PredicateVector</a></li><li class="navelem"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">ConstIterator</a></li> </ul>
|
||||
</div>
|
||||
</div><!-- top -->
|
||||
<div class="header">
|
||||
<div class="summary">
|
||||
<a href="#pub-methods">Public Member Functions</a> |
|
||||
<a href="classcutlass_1_1PredicateVector_1_1ConstIterator-members.html">List of all members</a> </div>
|
||||
<div class="headertitle">
|
||||
<div class="title">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator Class Reference<div class="ingroups"><a class="el" href="group__predicate__iterator__concept.html">Predicate Iterator Concept</a></div></div> </div>
|
||||
</div><!--header-->
|
||||
<div class="contents">
|
||||
|
||||
<p>A const iterator implementing <a class="el" href="group__predicate__iterator__concept.html">Predicate Iterator Concept</a> enabling sequential read-only access to prediactes.
|
||||
</p>
|
||||
|
||||
<p><code>#include <<a class="el" href="predicate__vector_8h_source.html">predicate_vector.h</a>></code></p>
|
||||
<table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="pub-methods"></a>
|
||||
Public Member Functions</h2></td></tr>
|
||||
<tr class="memitem:a1216aab9c567ec0d4232019008ef3ea7"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#a1216aab9c567ec0d4232019008ef3ea7">ConstIterator</a> (<a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">ConstIterator</a> const &it)</td></tr>
|
||||
<tr class="memdesc:a1216aab9c567ec0d4232019008ef3ea7"><td class="mdescLeft"> </td><td class="mdescRight"><a class="el" href="structcutlass_1_1Copy.html">Copy</a> constructor. <a href="#a1216aab9c567ec0d4232019008ef3ea7">More...</a><br /></td></tr>
|
||||
<tr class="separator:a1216aab9c567ec0d4232019008ef3ea7"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a590e4f4533c87162c0b79e8d876a8fda"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#a590e4f4533c87162c0b79e8d876a8fda">ConstIterator</a> (<a class="el" href="structcutlass_1_1PredicateVector.html">PredicateVector</a> const &_vec, int _start=0)</td></tr>
|
||||
<tr class="separator:a590e4f4533c87162c0b79e8d876a8fda"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a10ee4bb2f206432aa5ee1a83cb046b70"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">ConstIterator</a> & </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#a10ee4bb2f206432aa5ee1a83cb046b70">operator++</a> ()</td></tr>
|
||||
<tr class="memdesc:a10ee4bb2f206432aa5ee1a83cb046b70"><td class="mdescLeft"> </td><td class="mdescRight">Pre-increment. <a href="#a10ee4bb2f206432aa5ee1a83cb046b70">More...</a><br /></td></tr>
|
||||
<tr class="separator:a10ee4bb2f206432aa5ee1a83cb046b70"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a2763012a9284e97650b14e20c5668286"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">ConstIterator</a> & </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#a2763012a9284e97650b14e20c5668286">operator--</a> ()</td></tr>
|
||||
<tr class="memdesc:a2763012a9284e97650b14e20c5668286"><td class="mdescLeft"> </td><td class="mdescRight">Pre-decrement. <a href="#a2763012a9284e97650b14e20c5668286">More...</a><br /></td></tr>
|
||||
<tr class="separator:a2763012a9284e97650b14e20c5668286"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a977a99af3166a58d5bc5a613a1abe7d5"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">ConstIterator</a> </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#a977a99af3166a58d5bc5a613a1abe7d5">operator++</a> (int)</td></tr>
|
||||
<tr class="memdesc:a977a99af3166a58d5bc5a613a1abe7d5"><td class="mdescLeft"> </td><td class="mdescRight">Post-increment. <a href="#a977a99af3166a58d5bc5a613a1abe7d5">More...</a><br /></td></tr>
|
||||
<tr class="separator:a977a99af3166a58d5bc5a613a1abe7d5"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a2910a714d34a688b8ea560ea2933436b"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">ConstIterator</a> </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#a2910a714d34a688b8ea560ea2933436b">operator--</a> (int)</td></tr>
|
||||
<tr class="memdesc:a2910a714d34a688b8ea560ea2933436b"><td class="mdescLeft"> </td><td class="mdescRight">Post-decrement. <a href="#a2910a714d34a688b8ea560ea2933436b">More...</a><br /></td></tr>
|
||||
<tr class="separator:a2910a714d34a688b8ea560ea2933436b"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:aa2d03d88ac23051803d010f78157c357"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> bool </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#aa2d03d88ac23051803d010f78157c357">operator==</a> (<a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">ConstIterator</a> const &it) const</td></tr>
|
||||
<tr class="memdesc:aa2d03d88ac23051803d010f78157c357"><td class="mdescLeft"> </td><td class="mdescRight">Returns true if iterators point to the same bit. <a href="#aa2d03d88ac23051803d010f78157c357">More...</a><br /></td></tr>
|
||||
<tr class="separator:aa2d03d88ac23051803d010f78157c357"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a3d06715a77740034697686a7977cb685"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> bool </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#a3d06715a77740034697686a7977cb685">operator!=</a> (<a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">ConstIterator</a> const &it) const</td></tr>
|
||||
<tr class="memdesc:a3d06715a77740034697686a7977cb685"><td class="mdescLeft"> </td><td class="mdescRight">Returns false if iterators point to the same bit. <a href="#a3d06715a77740034697686a7977cb685">More...</a><br /></td></tr>
|
||||
<tr class="separator:a3d06715a77740034697686a7977cb685"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:abbc2bceb6cf8d7f168b8a00eb48c0946"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> bool </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html#abbc2bceb6cf8d7f168b8a00eb48c0946">operator*</a> () const</td></tr>
|
||||
<tr class="memdesc:abbc2bceb6cf8d7f168b8a00eb48c0946"><td class="mdescLeft"> </td><td class="mdescRight">Dereferences iterator. <a href="#abbc2bceb6cf8d7f168b8a00eb48c0946">More...</a><br /></td></tr>
|
||||
<tr class="separator:abbc2bceb6cf8d7f168b8a00eb48c0946"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table>
|
||||
<h2 class="groupheader">Constructor & Destructor Documentation</h2>
|
||||
<a id="a1216aab9c567ec0d4232019008ef3ea7"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a1216aab9c567ec0d4232019008ef3ea7">◆ </a></span>ConstIterator() <span class="overload">[1/2]</span></h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator::ConstIterator </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">ConstIterator</a> const & </td>
|
||||
<td class="paramname"><em>it</em></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a590e4f4533c87162c0b79e8d876a8fda"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a590e4f4533c87162c0b79e8d876a8fda">◆ </a></span>ConstIterator() <span class="overload">[2/2]</span></h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator::ConstIterator </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype"><a class="el" href="structcutlass_1_1PredicateVector.html">PredicateVector</a> const & </td>
|
||||
<td class="paramname"><em>_vec</em>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">int </td>
|
||||
<td class="paramname"><em>_start</em> = <code>0</code> </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td></td>
|
||||
<td>)</td>
|
||||
<td></td><td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<h2 class="groupheader">Member Function Documentation</h2>
|
||||
<a id="a3d06715a77740034697686a7977cb685"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a3d06715a77740034697686a7977cb685">◆ </a></span>operator!=()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> bool <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator::operator!= </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">ConstIterator</a> const & </td>
|
||||
<td class="paramname"><em>it</em></td><td>)</td>
|
||||
<td> const</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="abbc2bceb6cf8d7f168b8a00eb48c0946"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#abbc2bceb6cf8d7f168b8a00eb48c0946">◆ </a></span>operator*()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> bool <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator::operator* </td>
|
||||
<td>(</td>
|
||||
<td class="paramname"></td><td>)</td>
|
||||
<td> const</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a10ee4bb2f206432aa5ee1a83cb046b70"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a10ee4bb2f206432aa5ee1a83cb046b70">◆ </a></span>operator++() <span class="overload">[1/2]</span></h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">ConstIterator</a>& <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator::operator++ </td>
|
||||
<td>(</td>
|
||||
<td class="paramname"></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a977a99af3166a58d5bc5a613a1abe7d5"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a977a99af3166a58d5bc5a613a1abe7d5">◆ </a></span>operator++() <span class="overload">[2/2]</span></h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">ConstIterator</a> <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator::operator++ </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">int </td>
|
||||
<td class="paramname"></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a2763012a9284e97650b14e20c5668286"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a2763012a9284e97650b14e20c5668286">◆ </a></span>operator--() <span class="overload">[1/2]</span></h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">ConstIterator</a>& <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator::operator-- </td>
|
||||
<td>(</td>
|
||||
<td class="paramname"></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a2910a714d34a688b8ea560ea2933436b"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a2910a714d34a688b8ea560ea2933436b">◆ </a></span>operator--() <span class="overload">[2/2]</span></h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">ConstIterator</a> <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator::operator-- </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">int </td>
|
||||
<td class="paramname"></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="aa2d03d88ac23051803d010f78157c357"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#aa2d03d88ac23051803d010f78157c357">◆ </a></span>operator==()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> bool <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIterator::operator== </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype"><a class="el" href="classcutlass_1_1PredicateVector_1_1ConstIterator.html">ConstIterator</a> const & </td>
|
||||
<td class="paramname"><em>it</em></td><td>)</td>
|
||||
<td> const</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<hr/>The documentation for this class was generated from the following file:<ul>
|
||||
<li><a class="el" href="predicate__vector_8h_source.html">predicate_vector.h</a></li>
|
||||
</ul>
|
||||
</div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
<hr class="footer"/><address class="footer"><small>
|
||||
Generated on Thu May 3 2018 16:36:13 for Cutlass by  <a href="http://www.doxygen.org/index.html">
|
||||
<img class="footer" src="doxygen.png" alt="doxygen"/>
|
||||
</a> 1.8.14
|
||||
</small></address>
|
||||
</body>
|
||||
</html>
|
||||
101
docs/classcutlass_1_1PredicateVector_1_1Iterator-members.html
Normal file
101
docs/classcutlass_1_1PredicateVector_1_1Iterator-members.html
Normal file
@ -0,0 +1,101 @@
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head>
|
||||
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
|
||||
<meta name="generator" content="Doxygen 1.8.14"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<title>Cutlass: Member List</title>
|
||||
<link href="tabs.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="jquery.js"></script>
|
||||
<script type="text/javascript" src="dynsections.js"></script>
|
||||
<link href="search/search.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="search/searchdata.js"></script>
|
||||
<script type="text/javascript" src="search/search.js"></script>
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
extensions: ["tex2jax.js"],
|
||||
jax: ["input/TeX","output/HTML-CSS"],
|
||||
});
|
||||
</script><script type="text/javascript" async src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
|
||||
<link href="doxygen.css" rel="stylesheet" type="text/css" />
|
||||
</head>
|
||||
<body>
|
||||
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
|
||||
<div id="titlearea">
|
||||
<table cellspacing="0" cellpadding="0">
|
||||
<tbody>
|
||||
<tr style="height: 56px;">
|
||||
<td id="projectalign" style="padding-left: 0.5em;">
|
||||
<div id="projectname">Cutlass
|
||||
</div>
|
||||
<div id="projectbrief">CUDA Templates for Linear Algebra Subroutines and Solvers</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<!-- end header part -->
|
||||
<!-- Generated by Doxygen 1.8.14 -->
|
||||
<script type="text/javascript">
|
||||
/* @license magnet:?xt=urn:btih:cf05388f2679ee054f2beb29a391d25f4e673ac3&dn=gpl-2.0.txt GPL-v2 */
|
||||
var searchBox = new SearchBox("searchBox", "search",false,'Search');
|
||||
/* @license-end */
|
||||
</script>
|
||||
<script type="text/javascript" src="menudata.js"></script>
|
||||
<script type="text/javascript" src="menu.js"></script>
|
||||
<script type="text/javascript">
|
||||
/* @license magnet:?xt=urn:btih:cf05388f2679ee054f2beb29a391d25f4e673ac3&dn=gpl-2.0.txt GPL-v2 */
|
||||
$(function() {
|
||||
initMenu('',true,false,'search.php','Search');
|
||||
$(document).ready(function() { init_search(); });
|
||||
});
|
||||
/* @license-end */</script>
|
||||
<div id="main-nav"></div>
|
||||
<!-- window showing the filter options -->
|
||||
<div id="MSearchSelectWindow"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
onkeydown="return searchBox.OnSearchSelectKey(event)">
|
||||
</div>
|
||||
|
||||
<!-- iframe showing the search results (closed by default) -->
|
||||
<div id="MSearchResultsWindow">
|
||||
<iframe src="javascript:void(0)" frameborder="0"
|
||||
name="MSearchResults" id="MSearchResults">
|
||||
</iframe>
|
||||
</div>
|
||||
|
||||
<div id="nav-path" class="navpath">
|
||||
<ul>
|
||||
<li class="navelem"><a class="el" href="namespacecutlass.html">cutlass</a></li><li class="navelem"><a class="el" href="structcutlass_1_1PredicateVector.html">PredicateVector</a></li><li class="navelem"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">Iterator</a></li> </ul>
|
||||
</div>
|
||||
</div><!-- top -->
|
||||
<div class="header">
|
||||
<div class="headertitle">
|
||||
<div class="title">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator Member List</div> </div>
|
||||
</div><!--header-->
|
||||
<div class="contents">
|
||||
|
||||
<p>This is the complete list of members for <a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator</a>, including all inherited members.</p>
|
||||
<table class="directory">
|
||||
<tr class="even"><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#af035589126434bd2dbef4000cd864b8b">get</a>()</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#a91b7d25cbd64e696ef23c87671f0b077">Iterator</a>(Iterator const &it)</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr class="even"><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#a83c2f584bd061f0b9b6b2a6cddf5b038">Iterator</a>(PredicateVector &_vec, int _start=0)</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#a08cb4d1395b88a4451fbb1a27e010887">operator!=</a>(Iterator const &it) const</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr class="even"><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#a049b568e0f5de011ee76ce79bcedbab4">operator*</a>() const</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#a7dddc0a6b5c958156beef29bedfd1bd3">operator++</a>()</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr class="even"><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#a6c7333ad14d545cafc707e78752bf1e3">operator++</a>(int)</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#a69fb5b24eeb43331b7401768e8584e61">operator--</a>()</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr class="even"><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#aad709a11f43b84c88e3ce3a0394f8e8a">operator--</a>(int)</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#a5c5266fcef67c7b263682c4bc4a5000e">operator==</a>(Iterator const &it) const</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
<tr class="even"><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#aadfd039b5622098c9e46706a27122575">set</a>(bool value=true)</td><td class="entry"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
|
||||
</table></div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
<hr class="footer"/><address class="footer"><small>
|
||||
Generated on Thu May 3 2018 16:36:13 for Cutlass by  <a href="http://www.doxygen.org/index.html">
|
||||
<img class="footer" src="doxygen.png" alt="doxygen"/>
|
||||
</a> 1.8.14
|
||||
</small></address>
|
||||
</body>
|
||||
</html>
|
||||
451
docs/classcutlass_1_1PredicateVector_1_1Iterator.html
Normal file
451
docs/classcutlass_1_1PredicateVector_1_1Iterator.html
Normal file
@ -0,0 +1,451 @@
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head>
|
||||
<meta http-equiv="Content-Type" content="text/xhtml;charset=UTF-8"/>
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=9"/>
|
||||
<meta name="generator" content="Doxygen 1.8.14"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1"/>
|
||||
<title>Cutlass: cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator Class Reference</title>
|
||||
<link href="tabs.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="jquery.js"></script>
|
||||
<script type="text/javascript" src="dynsections.js"></script>
|
||||
<link href="search/search.css" rel="stylesheet" type="text/css"/>
|
||||
<script type="text/javascript" src="search/searchdata.js"></script>
|
||||
<script type="text/javascript" src="search/search.js"></script>
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
extensions: ["tex2jax.js"],
|
||||
jax: ["input/TeX","output/HTML-CSS"],
|
||||
});
|
||||
</script><script type="text/javascript" async src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>
|
||||
<link href="doxygen.css" rel="stylesheet" type="text/css" />
|
||||
</head>
|
||||
<body>
|
||||
<div id="top"><!-- do not remove this div, it is closed by doxygen! -->
|
||||
<div id="titlearea">
|
||||
<table cellspacing="0" cellpadding="0">
|
||||
<tbody>
|
||||
<tr style="height: 56px;">
|
||||
<td id="projectalign" style="padding-left: 0.5em;">
|
||||
<div id="projectname">Cutlass
|
||||
</div>
|
||||
<div id="projectbrief">CUDA Templates for Linear Algebra Subroutines and Solvers</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<!-- end header part -->
|
||||
<!-- Generated by Doxygen 1.8.14 -->
|
||||
<script type="text/javascript">
|
||||
/* @license magnet:?xt=urn:btih:cf05388f2679ee054f2beb29a391d25f4e673ac3&dn=gpl-2.0.txt GPL-v2 */
|
||||
var searchBox = new SearchBox("searchBox", "search",false,'Search');
|
||||
/* @license-end */
|
||||
</script>
|
||||
<script type="text/javascript" src="menudata.js"></script>
|
||||
<script type="text/javascript" src="menu.js"></script>
|
||||
<script type="text/javascript">
|
||||
/* @license magnet:?xt=urn:btih:cf05388f2679ee054f2beb29a391d25f4e673ac3&dn=gpl-2.0.txt GPL-v2 */
|
||||
$(function() {
|
||||
initMenu('',true,false,'search.php','Search');
|
||||
$(document).ready(function() { init_search(); });
|
||||
});
|
||||
/* @license-end */</script>
|
||||
<div id="main-nav"></div>
|
||||
<!-- window showing the filter options -->
|
||||
<div id="MSearchSelectWindow"
|
||||
onmouseover="return searchBox.OnSearchSelectShow()"
|
||||
onmouseout="return searchBox.OnSearchSelectHide()"
|
||||
onkeydown="return searchBox.OnSearchSelectKey(event)">
|
||||
</div>
|
||||
|
||||
<!-- iframe showing the search results (closed by default) -->
|
||||
<div id="MSearchResultsWindow">
|
||||
<iframe src="javascript:void(0)" frameborder="0"
|
||||
name="MSearchResults" id="MSearchResults">
|
||||
</iframe>
|
||||
</div>
|
||||
|
||||
<div id="nav-path" class="navpath">
|
||||
<ul>
|
||||
<li class="navelem"><a class="el" href="namespacecutlass.html">cutlass</a></li><li class="navelem"><a class="el" href="structcutlass_1_1PredicateVector.html">PredicateVector</a></li><li class="navelem"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">Iterator</a></li> </ul>
|
||||
</div>
|
||||
</div><!-- top -->
|
||||
<div class="header">
|
||||
<div class="summary">
|
||||
<a href="#pub-methods">Public Member Functions</a> |
|
||||
<a href="classcutlass_1_1PredicateVector_1_1Iterator-members.html">List of all members</a> </div>
|
||||
<div class="headertitle">
|
||||
<div class="title">cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator Class Reference<div class="ingroups"><a class="el" href="group__predicate__iterator__concept.html">Predicate Iterator Concept</a></div></div> </div>
|
||||
</div><!--header-->
|
||||
<div class="contents">
|
||||
|
||||
<p>An iterator implementing <a class="el" href="group__predicate__iterator__concept.html">Predicate Iterator Concept</a> enabling sequential read and write access to predicates.
|
||||
</p>
|
||||
|
||||
<p><code>#include <<a class="el" href="predicate__vector_8h_source.html">predicate_vector.h</a>></code></p>
|
||||
<table class="memberdecls">
|
||||
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a name="pub-methods"></a>
|
||||
Public Member Functions</h2></td></tr>
|
||||
<tr class="memitem:a91b7d25cbd64e696ef23c87671f0b077"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#a91b7d25cbd64e696ef23c87671f0b077">Iterator</a> (<a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">Iterator</a> const &it)</td></tr>
|
||||
<tr class="memdesc:a91b7d25cbd64e696ef23c87671f0b077"><td class="mdescLeft"> </td><td class="mdescRight"><a class="el" href="structcutlass_1_1Copy.html">Copy</a> constructor. <a href="#a91b7d25cbd64e696ef23c87671f0b077">More...</a><br /></td></tr>
|
||||
<tr class="separator:a91b7d25cbd64e696ef23c87671f0b077"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a83c2f584bd061f0b9b6b2a6cddf5b038"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#a83c2f584bd061f0b9b6b2a6cddf5b038">Iterator</a> (<a class="el" href="structcutlass_1_1PredicateVector.html">PredicateVector</a> &_vec, int _start=0)</td></tr>
|
||||
<tr class="memdesc:a83c2f584bd061f0b9b6b2a6cddf5b038"><td class="mdescLeft"> </td><td class="mdescRight">Constructs an iterator from a <a class="el" href="structcutlass_1_1PredicateVector.html" title="Statically sized array of bits implementing. ">PredicateVector</a>. <a href="#a83c2f584bd061f0b9b6b2a6cddf5b038">More...</a><br /></td></tr>
|
||||
<tr class="separator:a83c2f584bd061f0b9b6b2a6cddf5b038"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a7dddc0a6b5c958156beef29bedfd1bd3"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">Iterator</a> & </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#a7dddc0a6b5c958156beef29bedfd1bd3">operator++</a> ()</td></tr>
|
||||
<tr class="memdesc:a7dddc0a6b5c958156beef29bedfd1bd3"><td class="mdescLeft"> </td><td class="mdescRight">Pre-increment. <a href="#a7dddc0a6b5c958156beef29bedfd1bd3">More...</a><br /></td></tr>
|
||||
<tr class="separator:a7dddc0a6b5c958156beef29bedfd1bd3"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a69fb5b24eeb43331b7401768e8584e61"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">Iterator</a> & </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#a69fb5b24eeb43331b7401768e8584e61">operator--</a> ()</td></tr>
|
||||
<tr class="memdesc:a69fb5b24eeb43331b7401768e8584e61"><td class="mdescLeft"> </td><td class="mdescRight">Pre-decrement. <a href="#a69fb5b24eeb43331b7401768e8584e61">More...</a><br /></td></tr>
|
||||
<tr class="separator:a69fb5b24eeb43331b7401768e8584e61"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a6c7333ad14d545cafc707e78752bf1e3"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">Iterator</a> </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#a6c7333ad14d545cafc707e78752bf1e3">operator++</a> (int)</td></tr>
|
||||
<tr class="memdesc:a6c7333ad14d545cafc707e78752bf1e3"><td class="mdescLeft"> </td><td class="mdescRight">Post-increment. <a href="#a6c7333ad14d545cafc707e78752bf1e3">More...</a><br /></td></tr>
|
||||
<tr class="separator:a6c7333ad14d545cafc707e78752bf1e3"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:aad709a11f43b84c88e3ce3a0394f8e8a"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">Iterator</a> </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#aad709a11f43b84c88e3ce3a0394f8e8a">operator--</a> (int)</td></tr>
|
||||
<tr class="memdesc:aad709a11f43b84c88e3ce3a0394f8e8a"><td class="mdescLeft"> </td><td class="mdescRight">Post-decrement. <a href="#aad709a11f43b84c88e3ce3a0394f8e8a">More...</a><br /></td></tr>
|
||||
<tr class="separator:aad709a11f43b84c88e3ce3a0394f8e8a"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a5c5266fcef67c7b263682c4bc4a5000e"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> bool </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#a5c5266fcef67c7b263682c4bc4a5000e">operator==</a> (<a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">Iterator</a> const &it) const</td></tr>
|
||||
<tr class="memdesc:a5c5266fcef67c7b263682c4bc4a5000e"><td class="mdescLeft"> </td><td class="mdescRight">Returns true if iterators point to the same bit. <a href="#a5c5266fcef67c7b263682c4bc4a5000e">More...</a><br /></td></tr>
|
||||
<tr class="separator:a5c5266fcef67c7b263682c4bc4a5000e"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a08cb4d1395b88a4451fbb1a27e010887"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> bool </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#a08cb4d1395b88a4451fbb1a27e010887">operator!=</a> (<a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">Iterator</a> const &it) const</td></tr>
|
||||
<tr class="memdesc:a08cb4d1395b88a4451fbb1a27e010887"><td class="mdescLeft"> </td><td class="mdescRight">Returns false if iterators point to the same bit. <a href="#a08cb4d1395b88a4451fbb1a27e010887">More...</a><br /></td></tr>
|
||||
<tr class="separator:a08cb4d1395b88a4451fbb1a27e010887"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:af035589126434bd2dbef4000cd864b8b"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> bool </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#af035589126434bd2dbef4000cd864b8b">get</a> ()</td></tr>
|
||||
<tr class="memdesc:af035589126434bd2dbef4000cd864b8b"><td class="mdescLeft"> </td><td class="mdescRight">Gets the bit at the pointed to location. <a href="#af035589126434bd2dbef4000cd864b8b">More...</a><br /></td></tr>
|
||||
<tr class="separator:af035589126434bd2dbef4000cd864b8b"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:a049b568e0f5de011ee76ce79bcedbab4"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> bool </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#a049b568e0f5de011ee76ce79bcedbab4">operator*</a> () const</td></tr>
|
||||
<tr class="memdesc:a049b568e0f5de011ee76ce79bcedbab4"><td class="mdescLeft"> </td><td class="mdescRight">Dereferences iterator. <a href="#a049b568e0f5de011ee76ce79bcedbab4">More...</a><br /></td></tr>
|
||||
<tr class="separator:a049b568e0f5de011ee76ce79bcedbab4"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
<tr class="memitem:aadfd039b5622098c9e46706a27122575"><td class="memItemLeft" align="right" valign="top"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> void </td><td class="memItemRight" valign="bottom"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html#aadfd039b5622098c9e46706a27122575">set</a> (bool value=true)</td></tr>
|
||||
<tr class="memdesc:aadfd039b5622098c9e46706a27122575"><td class="mdescLeft"> </td><td class="mdescRight">Sets the bit at the pointed to location. <a href="#aadfd039b5622098c9e46706a27122575">More...</a><br /></td></tr>
|
||||
<tr class="separator:aadfd039b5622098c9e46706a27122575"><td class="memSeparator" colspan="2"> </td></tr>
|
||||
</table>
|
||||
<h2 class="groupheader">Constructor & Destructor Documentation</h2>
|
||||
<a id="a91b7d25cbd64e696ef23c87671f0b077"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a91b7d25cbd64e696ef23c87671f0b077">◆ </a></span>Iterator() <span class="overload">[1/2]</span></h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator::Iterator </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">Iterator</a> const & </td>
|
||||
<td class="paramname"><em>it</em></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a83c2f584bd061f0b9b6b2a6cddf5b038"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a83c2f584bd061f0b9b6b2a6cddf5b038">◆ </a></span>Iterator() <span class="overload">[2/2]</span></h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator::Iterator </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype"><a class="el" href="structcutlass_1_1PredicateVector.html">PredicateVector</a> & </td>
|
||||
<td class="paramname"><em>_vec</em>, </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="paramkey"></td>
|
||||
<td></td>
|
||||
<td class="paramtype">int </td>
|
||||
<td class="paramname"><em>_start</em> = <code>0</code> </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td></td>
|
||||
<td>)</td>
|
||||
<td></td><td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<h2 class="groupheader">Member Function Documentation</h2>
|
||||
<a id="af035589126434bd2dbef4000cd864b8b"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#af035589126434bd2dbef4000cd864b8b">◆ </a></span>get()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> bool <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator::get </td>
|
||||
<td>(</td>
|
||||
<td class="paramname"></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a08cb4d1395b88a4451fbb1a27e010887"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a08cb4d1395b88a4451fbb1a27e010887">◆ </a></span>operator!=()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> bool <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator::operator!= </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">Iterator</a> const & </td>
|
||||
<td class="paramname"><em>it</em></td><td>)</td>
|
||||
<td> const</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a049b568e0f5de011ee76ce79bcedbab4"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a049b568e0f5de011ee76ce79bcedbab4">◆ </a></span>operator*()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> bool <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator::operator* </td>
|
||||
<td>(</td>
|
||||
<td class="paramname"></td><td>)</td>
|
||||
<td> const</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a7dddc0a6b5c958156beef29bedfd1bd3"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a7dddc0a6b5c958156beef29bedfd1bd3">◆ </a></span>operator++() <span class="overload">[1/2]</span></h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">Iterator</a>& <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator::operator++ </td>
|
||||
<td>(</td>
|
||||
<td class="paramname"></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a6c7333ad14d545cafc707e78752bf1e3"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a6c7333ad14d545cafc707e78752bf1e3">◆ </a></span>operator++() <span class="overload">[2/2]</span></h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">Iterator</a> <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator::operator++ </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">int </td>
|
||||
<td class="paramname"></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a69fb5b24eeb43331b7401768e8584e61"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a69fb5b24eeb43331b7401768e8584e61">◆ </a></span>operator--() <span class="overload">[1/2]</span></h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">Iterator</a>& <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator::operator-- </td>
|
||||
<td>(</td>
|
||||
<td class="paramname"></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="aad709a11f43b84c88e3ce3a0394f8e8a"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#aad709a11f43b84c88e3ce3a0394f8e8a">◆ </a></span>operator--() <span class="overload">[2/2]</span></h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> <a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">Iterator</a> <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator::operator-- </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">int </td>
|
||||
<td class="paramname"></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="a5c5266fcef67c7b263682c4bc4a5000e"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#a5c5266fcef67c7b263682c4bc4a5000e">◆ </a></span>operator==()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> bool <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator::operator== </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype"><a class="el" href="classcutlass_1_1PredicateVector_1_1Iterator.html">Iterator</a> const & </td>
|
||||
<td class="paramname"><em>it</em></td><td>)</td>
|
||||
<td> const</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<a id="aadfd039b5622098c9e46706a27122575"></a>
|
||||
<h2 class="memtitle"><span class="permalink"><a href="#aadfd039b5622098c9e46706a27122575">◆ </a></span>set()</h2>
|
||||
|
||||
<div class="memitem">
|
||||
<div class="memproto">
|
||||
<div class="memtemplate">
|
||||
template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ = 0> </div>
|
||||
<table class="mlabels">
|
||||
<tr>
|
||||
<td class="mlabels-left">
|
||||
<table class="memname">
|
||||
<tr>
|
||||
<td class="memname"><a class="el" href="cutlass_8h.html#a28c2443a142676d3d71effdae1a986b1">CUTLASS_HOST_DEVICE</a> void <a class="el" href="structcutlass_1_1PredicateVector.html">cutlass::PredicateVector</a>< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator::set </td>
|
||||
<td>(</td>
|
||||
<td class="paramtype">bool </td>
|
||||
<td class="paramname"><em>value</em> = <code>true</code></td><td>)</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td class="mlabels-right">
|
||||
<span class="mlabels"><span class="mlabel">inline</span></span> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</div><div class="memdoc">
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<hr/>The documentation for this class was generated from the following file:<ul>
|
||||
<li><a class="el" href="predicate__vector_8h_source.html">predicate_vector.h</a></li>
|
||||
</ul>
|
||||
</div><!-- contents -->
|
||||
<!-- start footer part -->
|
||||
<hr class="footer"/><address class="footer"><small>
|
||||
Generated on Thu May 3 2018 16:36:13 for Cutlass by  <a href="http://www.doxygen.org/index.html">
|
||||
<img class="footer" src="doxygen.png" alt="doxygen"/>
|
||||
</a> 1.8.14
|
||||
</small></address>
|
||||
</body>
|
||||
</html>
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user