diff --git a/CMake/bin2hex.cmake b/CMake/bin2hex.cmake
new file mode 100644
index 00000000..603c9a6c
--- /dev/null
+++ b/CMake/bin2hex.cmake
@@ -0,0 +1,26 @@
+# A small utility function which generates a C-header from an input file
+function(FILE_TO_C_STRING FILENAME VARIABLE_NAME OUTPUT_STRING ZERO_TERMINATED)
+ FILE(READ "${FILENAME}" HEX_INPUT HEX)
+ if (${ZERO_TERMINATED})
+ string(APPEND HEX_INPUT "00")
+ endif()
+
+ string(REGEX REPLACE "(....)" "\\1\n" HEX_OUTPUT ${HEX_INPUT})
+ string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1," HEX_OUTPUT ${HEX_OUTPUT})
+
+ set(HEX_OUTPUT "static char const ${VARIABLE_NAME}[] = {\n ${HEX_OUTPUT}\n};\n")
+
+ set(${OUTPUT_STRING} "${HEX_OUTPUT}" PARENT_SCOPE)
+endfunction()
+
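+# This file is typically invoked in CMake script mode with the required variables predefined,
+# for example (illustrative values only):
+#   cmake -DFILE_IN=gemm.cu -DFILE_OUT=gemm_nvrtc.h -DVARIABLE_NAME=kGemmSource \
+#         -DZERO_TERMINATED=ON -P CMake/bin2hex.cmake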
+message("Create header file for ${FILE_IN}")
+message("Create header file for ${FILE_OUT}")
+file_to_c_string(${FILE_IN} ${VARIABLE_NAME} OUTPUT_STRING ZERO_TERMINATED)
+
+set(RESULT "#pragma once\n")
+string(APPEND RESULT "namespace cutlass {\n")
+string(APPEND RESULT "namespace nvrtc {\n")
+string(APPEND RESULT "${OUTPUT_STRING}")
+string(APPEND RESULT "} // namespace nvrtc\n")
+string(APPEND RESULT "} // namespace cutlass\n")
+file(WRITE "${FILE_OUT}" "${RESULT}")
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 00000000..5a53fae5
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,182 @@
+# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without modification, are permitted
+# provided that the following conditions are met:
+# * Redistributions of source code must retain the above copyright notice, this list of
+# conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice, this list of
+# conditions and the following disclaimer in the documentation and/or other materials
+# provided with the distribution.
+# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+# to endorse or promote products derived from this software without specific prior written
+# permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+cmake_minimum_required(VERSION 3.3.0)
+
+set(CUTLASS_LANGUAGES CXX)
+
+# CMake 3.9.0 has native CUDA support and no longer needs the FindCUDA package. Use it!
+if(WIN32 AND NOT ${CMAKE_VERSION} VERSION_LESS "3.9.0")
+ list(APPEND CUTLASS_LANGUAGES CUDA)
+ set(CUTLASS_NATIVE_CUDA TRUE)
+
+ macro(cutlass_add_executable)
+ add_executable(${ARGN})
+ endmacro()
+else()
+ # FindCUDA fails to detect VS 2017 due to a changed directory format of the toolkits.
+ # For this configuration we need CMake >= 3.9.0 to use the native CUDA support.
+ if (WIN32 AND MSVC_VERSION GREATER 1800)
+ message(FATAL_ERROR "Please upgrade CMake to version >= 3.9.0 to support Visual Studio 2017 or higher")
+ endif()
+
+ # Fall back to the FindCUDA version to create an executable with CUDA files
+ macro(cutlass_add_executable)
+ cuda_add_executable(${ARGN})
+ endmacro()
+endif()
+
+project(CUTLASS ${CUTLASS_LANGUAGES})
+
+# check if the configuration is supported
+if( NOT CMAKE_SIZEOF_VOID_P EQUAL 8 )
+ message(FATAL_ERROR "CUTLASS requires a 64-bit compiler!")
+endif()
+
+find_package(CUDA)
+find_package(Doxygen QUIET)
+
+# By default we want to build in Release mode to ensure that we're getting the best performance.
+if (NOT (CMAKE_BUILD_TYPE OR CONFIGURATION_TYPES))
+ set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose build level" FORCE)
+ # We do support Debug or Release builds
+ set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release")
+endif()
+
+if(WIN32)
+ # On Windows we link against the shared (DLL) runtime. Change gtest settings to match this.
+ set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib even when Google Test is built as static lib" FORCE)
+endif()
+
+if (WIN32)
+ # Enable more warnings and treat as errors
+ string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX")
+
+ # Disable excess x86 floating point precision that can lead to results being labeled incorrectly
+ string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict")
+
+ # Verbose option
+ if (${CUTLASS_NVCC_VERBOSE})
+ string(APPEND NVCC_FLAGS " -v")
+ endif()
+endif(WIN32)
+
+# Configure CUDA options
+set(CUTLASS_NVCC_ARCHS "50;60;61;70" CACHE STRING "The SM architectures to build code for.")
+set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
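+# For example: cmake .. -DCUTLASS_NVCC_ARCHS="61;70" builds kernels only for SM 6.1 and SM 7.0.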
+
+foreach(ARCH ${CUTLASS_NVCC_ARCHS})
+ string(APPEND NVCC_FLAGS " -gencode arch=compute_${ARCH},code=sm_${ARCH}")
+endforeach()
+
+
+if (CUTLASS_NVCC_KEEP)
+ string(APPEND NVCC_FLAGS " -keep")
+endif()
+
+if (WIN32 AND CUTLASS_NATIVE_CUDA)
+ string(APPEND NVCC_FLAGS_RELEASE " -lineinfo")
+else()
+ string(APPEND NVCC_FLAGS " -lineinfo")
+endif()
+
+if (UNIX)
+ string(APPEND NVCC_FLAGS " -Xcompiler -Wconversion")
+endif()
+
+string(APPEND NVCC_FLAGS_DEBUG " -g")
+string(APPEND NVCC_FLAGS_RELEASE " -O3")
+
+# define NDEBUG for release mode to disable assertions
+string(APPEND NVCC_FLAGS_RELEASE " -DNDEBUG")
+
+if (CUTLASS_NATIVE_CUDA)
+ set(CMAKE_CUDA_FLAGS "${NVCC_FLAGS}")
+ set(CMAKE_CUDA_FLAGS_DEBUG "${NVCC_FLAGS_DEBUG}")
+ set(CMAKE_CUDA_FLAGS_RELEASE "${NVCC_FLAGS_RELEASE}")
+else()
+ set(CUDA_NVCC_FLAGS ${NVCC_FLAGS})
+ set(CUDA_NVCC_FLAGS_DEBUG ${NVCC_FLAGS_DEBUG})
+ set(CUDA_NVCC_FLAGS_RELEASE ${NVCC_FLAGS_RELEASE})
+endif()
+
+#
+# The following items should eventually be pushed into cutlass/CMakeLists.txt
+#
+
+# GLOB for CUTLASS header files. Should we use a static list instead?
+file(GLOB CUTLASS_GEMM RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/gemm/*.h)
+file(GLOB CUTLASS_UTIL RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/util/*.h)
+file(GLOB CUTLASS_DEVICE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/device/*.h)
+file(GLOB CUTLASS_CORE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/*.h)
+
+source_group("cutlass\\gemm" FILES ${CUTLASS_GEMM})
+source_group("cutlass\\util" FILES ${CUTLASS_UTIL})
+source_group("cutlass\\device" FILES ${CUTLASS_DEVICE})
+source_group("cutlass" FILES ${CUTLASS_CORE})
+
+add_library(CUTLASS INTERFACE)
+include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
+target_sources(CUTLASS INTERFACE
+ ${CUTLASS_GEMM}
+ ${CUTLASS_UTIL}
+ ${CUTLASS_DEVICE}
+ ${CUTLASS_CORE}
+)
+
+target_include_directories(CUTLASS INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
+
+# Create a custom target to ensure that the CUTLASS sources are visible in an IDE
+add_custom_target(cutlass_ide SOURCES
+ ${CUTLASS_GEMM}
+ ${CUTLASS_UTIL}
+ ${CUTLASS_DEVICE}
+ ${CUTLASS_CORE}
+)
+# If Doxygen is available, generate documentation.
+if (DOXYGEN_FOUND)
+ # DOT is available. Enable graph generation in the documentation
+ if (DOXYGEN_DOT_EXECUTABLE)
+ set(CUTLASS_ENABLE_DOXYGEN_DOT ON CACHE BOOL "Use dot to generate graphs in the doxygen documentation.")
+ else()
+ set(CUTLASS_ENABLE_DOXYGEN_DOT OFF CACHE BOOL "Use dot to generate graphs in the doxygen documentation." FORCE)
+ endif()
+
+ if (CUTLASS_ENABLE_DOXYGEN_DOT)
+ set(HAVE_DOT "YES")
+ else()
+ set(HAVE_DOT "NO")
+ endif()
+
+ # Add custom target for Doxygen.
+ add_custom_target(cutlass_docs ${CMAKE_COMMAND} -E env
+ "DOT_PATH=${DOXYGEN_DOT_EXECUTABLE}"
+ "HAVE_DOT=${HAVE_DOT}"
+ ${DOXYGEN_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/Doxyfile
+ WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+ VERBATIM
+ )
+endif()
+
+
+#add_subdirectory(examples/gemm)
+add_subdirectory(tools)
diff --git a/Doxyfile b/Doxyfile
index 15650b20..51cec529 100644
--- a/Doxyfile
+++ b/Doxyfile
@@ -58,7 +58,7 @@ PROJECT_LOGO =
# entered, it will be relative to the location where doxygen was started. If
# left blank the current directory will be used.
-OUTPUT_DIRECTORY = doxygen
+OUTPUT_DIRECTORY = docs
# If the CREATE_SUBDIRS tag is set to YES, then doxygen will create 4096 sub-
# directories (in 2 levels) under the output directory of each output format and
@@ -218,7 +218,8 @@ TAB_SIZE = 4
# "Side Effects:". You can put \n's in the value part of an alias to insert
# newlines.
-ALIASES =
+#ALIASES += "concept{1}=@ingroup \1\n@par Implemented concepts:\n@ref \1"
+ALIASES += "concept{1}=@ingroup \1"
# This tag can be used to specify a number of word-keyword mappings (TCL only).
# A mapping has the form "name=value". For example adding "class=itcl::class"
@@ -396,7 +397,7 @@ LOOKUP_CACHE_SIZE = 0
# normally produced when WARNINGS is set to YES.
# The default value is: NO.
-EXTRACT_ALL = NO
+EXTRACT_ALL = YES
# If the EXTRACT_PRIVATE tag is set to YES all private members of a class will
# be included in the documentation.
@@ -733,7 +734,7 @@ WARN_LOGFILE =
# spaces.
# Note: If this tag is empty the current directory is searched.
-INPUT = cutlass cutlass/gemm cutlass/util
+INPUT = cutlass
# This tag can be used to specify the character encoding of the source files
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
@@ -759,7 +760,7 @@ FILE_PATTERNS =
# be searched for input files as well.
# The default value is: NO.
-RECURSIVE = NO
+RECURSIVE = YES
# The EXCLUDE tag can be used to specify files and/or directories that should be
# excluded from the INPUT source files. This way you can easily exclude a
@@ -2032,7 +2033,7 @@ HIDE_UNDOC_RELATIONS = YES
# set to NO
# The default value is: NO.
-HAVE_DOT = NO
+HAVE_DOT = $(HAVE_DOT)
# The DOT_NUM_THREADS specifies the number of dot invocations doxygen is allowed
# to run in parallel. When set to 0 doxygen will base this on the number of
@@ -2204,7 +2205,7 @@ INTERACTIVE_SVG = NO
# found. If left blank, it is assumed the dot tool can be found in the path.
# This tag requires that the tag HAVE_DOT is set to YES.
-DOT_PATH =
+DOT_PATH = $(DOT_PATH)
# The DOTFILE_DIRS tag can be used to specify one or more directories that
# contain dot files that are included in the documentation (see the \dotfile
diff --git a/README.md b/README.md
index a13c9a4c..05a0d3a3 100644
--- a/README.md
+++ b/README.md
@@ -1,106 +1,213 @@
-
+
-# Introduction
+# CUTLASS 1.0
-CUTLASS is a collection of CUDA C++ template abstractions for implementing
-high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
-It incorporates strategies for hierarchical decomposition and data movement similar
-to those used to implement cuBLAS. CUTLASS decomposes these "moving parts" into
+CUTLASS 1.0 is a collection of CUDA C++ template abstractions for implementing
+high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
+It incorporates strategies for hierarchical decomposition and data movement similar
+to those used to implement cuBLAS. CUTLASS decomposes these "moving parts" into
reusable, modular software components abstracted by C++ template classes. These
-thread-wide, warp-wide, block-wide, and device-wide primitives can be specialized
-and tuned via custom tiling sizes, data types, and other algorithmic policy. The
-resulting flexibility simplifies their use as building blocks within custom kernels
+thread-wide, warp-wide, block-wide, and device-wide primitives can be specialized
+and tuned via custom tiling sizes, data types, and other algorithmic policy. The
+resulting flexibility simplifies their use as building blocks within custom kernels
and applications.
To support a wide variety of applications, CUTLASS provides extensive support for
-mixed-precision computations, providing specialized data-movement and
-multiply-accumulate abstractions for 8-bit integer, half-precision floating
-point (FP16), single-precision floating point (FP32), and double-precision floating
+mixed-precision computations, providing specialized data-movement and
+multiply-accumulate abstractions for 8-bit integer, half-precision floating
+point (FP16), single-precision floating point (FP32), and double-precision floating
point (FP64) types. Furthermore, CUTLASS demonstrates CUDA's WMMA API for targeting
-the programmable, high-throughput _Tensor Cores_ provided by NVIDIA's Volta architecture
+the programmable, high-throughput _Tensor Cores_ provided by NVIDIA's Volta architecture
and beyond.
-For more exposition, see our Parallel Forall blog post [CUTLASS: Fast Linear Algebra
-in CUDA C++](https://devblogs.nvidia.com/parallelforall/cutlass-linear-algebra-cuda).
+CUTLASS 1.0 has changed substantially from our preview release described in
+the [CUTLASS Parallel For All](https://devblogs.nvidia.com/parallelforall/cutlass-linear-algebra-cuda)
+post. We have decomposed the structure of the GEMM computation into deeper, structured
+primitives for loading data, computing predicate masks, streaming data at each level of
+the GEMM hierarchy, and updating the output matrix.
+
+CUTLASS 1.0 is described in the [Doxygen documentation](https://github.com/NVIDIA/cutlass/docs)
+and our talk at the GPU Technology Conference 2018 (login required).
# Performance
-

+
+
+CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
+they exhibit performance comparable to cuBLAS for scalar GEMM
+computations. The above figure shows CUTLASS performance relative to cuBLAS
+for large matrix dimensions (M=10240, N=K=4096) running on an NVIDIA Titan V GPU
+when compiled with CUDA 9.2.
+
+# Compatibility
+
+CUTLASS requires CUDA 9 and performs best with the [CUDA 9.2 Toolkit](https://developer.nvidia.com/cuda-toolkit) or later.
+
+|**Operating System** | **Compiler** |
+|-----------------|----------|
+| Windows 10 | Microsoft Visual Studio 2015|
+| | Microsoft Visual Studio 2017|
+| Ubuntu 14.04 | GCC 4.8.2 |
+| Ubuntu 16.04 | GCC 5.4.0 |
+
+
+CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
+any Maxwell-, Pascal-, or Volta-architecture NVIDIA GPU.
+
+|**GPU**|
+|---|
+|NVIDIA GeForce 1080|
+|NVIDIA TitanXP|
+|NVIDIA Tesla P100|
+|NVIDIA Tesla V100|
+|NVIDIA TitanV|
+
+
+# Building CUTLASS
+
+CUTLASS is a header-only template library and does not need to be built to be used by other
+projects. However, we distribute extensive unit tests and utility programs to demonstrate
+CUTLASS. These instructions are for building those test programs.
+
+CUTLASS's unit tests depend on Google Test, which is included as a git submodule. You can fetch
+the submodules as follows.
+
+```
+$ git submodule update --init --recursive
+```
+
+CUTLASS can be built with CMake starting with version 3.10. By default, CUTLASS builds kernels
+for CUDA architecture versions 5.0, 6.0, 6.1, and 7.0. To reduce compile time, you can specify
+the architectures to build CUTLASS for by changing the CMake configuration setting
+`CUTLASS_NVCC_ARCHS`.
+
+Create a build directory within the CUTLASS project, then run CMake once.
+
+```
+$ mkdir build && cd build
+$ cmake ..
+```
+
+Compile the CUTLASS project by running Make. Include the -j argument to compile sources in
+parallel and speed up the build process.
+
+```
+$ make -j12
+...
+$
+```
+
+Verify CUTLASS has been built correctly by running the unit tests from the build/ directory.
+
+```
+$ ./tools/test/unit/cutlass_unit_test
+...
+...
+...
+[----------] Global test environment tear-down
+[==========] 481 tests from 24 test cases ran. (5954 ms total)
+[ PASSED ] 481 tests.
+```
+
+All tests should pass, though the exact number of tests may vary over time.
-CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
-they exhibit performance comparable to cuBLAS for scalar GEMM
-computations. The above figure shows CUTLASS performance relative to cuBLAS
-for large matrix dimensions (M=10240, N=K=4096) running on an NVIDIA Tesla V100 GPU
-when compiled with CUDA 9.0.
# Project Structure
CUTLASS is arranged as a header-only library with several example test programs
-that demonstrate instantiating a GEMM task within a CUDA kernel. Comments inline
-with the source explain the individual components.
+that demonstrate instantiating a GEMM task within a CUDA kernel. The Doxygen documentation
+provides a complete list of files, classes, and template concepts defined in the CUTLASS
+project. A brief summary is described below.
-The repository is organized in the following arrangement.
-
- cutlass/ Root of header-only source library for matrix multiply
- gemm/ Implementation of GEMM __device__ code and supporting components
- util/ Utility components for CUDA device-side CUDA development
-
-A test program is provided to illustrate the use of CUTLASS. This is implemented
-in the following directory.
-
- cutlass_test Root of test programs depicting CUTLASS kernels
- util/ Utilities
- gemm.cu Simple example calling CUTLASS and CUBLAS GEMM kernels
- Makefile Build script for test programs
-
-
-# Makefile usage
-
-There are different sample targets for different GEMM data types and
-transposititions. Be sure to specify your target architecture.
-
- make sm=<60|61|70> \
- [transpose=] [verbose=<0|1>] [keep=<0|1>]
-
-
-# Program usage
-
- Program usage:
-
- gemm_
- [--help]
- [--schmoo=<#schmoo-samples> || --m= --n= --k=]
- [--i=]
- [--device=]
- [--alpha= --beta=]
-
-
-# Open Source License
-
-CUTLASS is released by NVIDIA Corporation under the "New BSD" open-source license:
+The CUTLASS library is defined in the cutlass/ directory and consists of CUDA C++ template
+classes and other definitions for implementing efficient GPU GEMM kernels. A set of core
+classes and templates define basic primitives that are then applied to compute GEMM via
+templates in the cutlass/gemm directory.
```
-Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
- * Redistributions of source code must retain the above copyright
- notice, this list of conditions and the following disclaimer.
- * Redistributions in binary form must reproduce the above copyright
- notice, this list of conditions and the following disclaimer in the
- documentation and/or other materials provided with the distribution.
- * Neither the name of the NVIDIA CORPORATION nor the
- names of its contributors may be used to endorse or promote products
- derived from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
-DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+cutlass/
+ gemm/
+ util/
+
```
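+
+As a minimal illustration of how these core primitives are used (the snippet below is an
+illustrative sketch, not part of the library; it relies only on `cutlass/coord.h` and assumes
+the repository root is on the include path), host code can manipulate tensor coordinates
+directly:
+
+```
+#include <cutlass/coord.h>
+
+#include <iostream>
+
+int main() {
+  // Describe a 3-D extent and a packed stride, then compute a volume and a linear offset.
+  cutlass::Coord<3> extent = cutlass::make_Coord(2, 4, 8);
+  cutlass::Coord<3> stride = cutlass::make_Coord(32, 8, 1);
+
+  std::cout << "volume = " << extent.count()             // 2 * 4 * 8 = 64
+            << ", offset = " << extent.dot<int>(stride)  // 2*32 + 4*8 + 8*1 = 104
+            << std::endl;
+  return 0;
+}
+```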
+
+Several tools and test programs are also distributed with the CUTLASS library. They are
+contained in the following directories.
+
+```
+tools/
+ test/
+ unit/
+ core/
+ gemm/
+ perf/
+ util/
+
+```
+
+The `test/unit/` directory consists of unit tests implemented with Google Test that demonstrate
+basic usage of Core API components and complete tests of the CUTLASS GEMM computations.
+
+# Performance Profiling
+
+The `test/perf/` directory contains a command-line utility for launching each of the GEMM kernels.
+Its usage is shown below.
+
+Program usage:
+
+```
+ cutlass_perf_test [options]
+
+ --help
+  --append=<true|false>                    If true, appends output to existing CSV file. If false, overwrites.
+  --alpha=<value>                          Value for alpha to be used in GEMM experiments
+  --beta=<value>                           Value for beta to be used in GEMM experiments
+  --output=<filename.csv>                  Writes summary of profiling to specified .csv file
+  --iterations=<count>                     Maximum number of iterations to execute when profiling
+  --m=<height>[:max height[:step]]         Height of GEMM problem (number of rows of C). May specify a range with optional step size.
+  --n=<width>[:max width[:step]]           Width of GEMM problem (number of columns of C). May specify a range with optional step size.
+  --k=<depth>[:max depth[:step]]           Size of inner dimension of A and B. May specify a range with optional step size.
+  --kernels=<{s|d|h|i|wmma}gemm_{nn,nt,tn,tt}>   Select GEMM datatype and layout to use for tests
+  --peak=<true|false>                      If true, only reports peak performance per kernel after profiling specified problem space.
+  --seed=<value>                           Random seed used by the random number generator in initializing input matrices.
+  --tags=<tags>                            Inserts leading columns in output table and uniform values for each column. Useful for generating pivot tables.
+
+
+ Example usage:
+
+ # Runs one problem size for all kernels
+ $ ./tools/test/perf/cutlass_perf_test --m=10240 --n=1024 --k=1024
+
+ # Varies GEMM K dimension for SGEMM and IGEMM with column-major multiplicands
+ $ ./tools/test/perf/cutlass_perf_test --m=10240 --n=4096 --k=1024:8192:128 --kernels=sgemm_nn,igemm_nn
+```
+
+# About
+
+CUTLASS is released by NVIDIA Corporation as Open Source software under the
+3-clause "New" BSD license.
+
+
+# Copyright
+
+Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted
+provided that the following conditions are met:
+ * Redistributions of source code must retain the above copyright notice, this list of
+ conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright notice, this list of
+ conditions and the following disclaimer in the documentation and/or other materials
+ provided with the distribution.
+ * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ to endorse or promote products derived from this software without specific prior written
+ permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/clang-format.sh b/clang-format.sh
new file mode 100755
index 00000000..b2570d91
--- /dev/null
+++ b/clang-format.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+set -e
+
+function formatFiles {
+ for f in `find "$1" -type f -name "*.$2"` ; do
+ COMMAND="clang-format -i $f"
+ echo $COMMAND
+ $COMMAND
+ done
+}
+
+formatFiles "cutlass" "h"
+formatFiles "tools/test" "h"
+formatFiles "tools/test" "cpp"
+formatFiles "tools/util" "h"
+
diff --git a/common.mk b/common.mk
deleted file mode 100644
index 672ea5b2..00000000
--- a/common.mk
+++ /dev/null
@@ -1,181 +0,0 @@
-#/******************************************************************************
-# * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
-# *
-# * Redistribution and use in source and binary forms, with or without
-# * modification, are permitted provided that the following conditions are met:
-# * * Redistributions of source code must retain the above copyright
-# * notice, this list of conditions and the following disclaimer.
-# * * Redistributions in binary form must reproduce the above copyright
-# * notice, this list of conditions and the following disclaimer in the
-# * documentation and/or other materials provided with the distribution.
-# * * Neither the name of the NVIDIA CORPORATION nor the
-# * names of its contributors may be used to endorse or promote products
-# * derived from this software without specific prior written permission.
-# *
-# * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-# * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-# * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
-# * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-# * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-# * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-# * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-# * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-# *
-#******************************************************************************/
-
-
-#-------------------------------------------------------------------------------
-# Commandline Options
-#-------------------------------------------------------------------------------
-
-# sm= Compute-capability to compile for, e.g., "sm=200,300,350" (SM2.0 by default).
-
-COMMA := ,
-ifdef sm
- SM_ARCH := $(subst $(COMMA),-,$(sm))
-else
- $(error Please specify SM architecture makefile argument: "sm=XX")
-endif
-
-ifeq (70, $(findstring 70, $(SM_ARCH)))
- SM_TARGETS += -gencode=arch=compute_70,code=\"sm_70,compute_70\"
- CLANG_SM_TARGETS += --cuda-gpu-arch=sm_70
-endif
-ifeq (62, $(findstring 62, $(SM_ARCH)))
- SM_TARGETS += -gencode=arch=compute_62,code=\"sm_62,compute_62\"
- CLANG_SM_TARGETS += --cuda-gpu-arch=sm_62
-endif
-ifeq (61, $(findstring 61, $(SM_ARCH)))
- SM_TARGETS += -gencode=arch=compute_61,code=\"sm_61,compute_61\"
- CLANG_SM_TARGETS += --cuda-gpu-arch=sm_61
-endif
-ifeq (60, $(findstring 60, $(SM_ARCH)))
- SM_TARGETS += -gencode=arch=compute_60,code=\"sm_60,compute_60\"
- CLANG_SM_TARGETS += --cuda-gpu-arch=sm_60
-endif
-ifeq (52, $(findstring 52, $(SM_ARCH)))
- SM_TARGETS += -gencode=arch=compute_52,code=\"sm_52,compute_52\"
- CLANG_SM_TARGETS += --cuda-gpu-arch=sm_52
-endif
-ifeq (37, $(findstring 37, $(SM_ARCH)))
- SM_TARGETS += -gencode=arch=compute_37,code=\"sm_37,compute_37\"
- CLANG_SM_TARGETS += --cuda-gpu-arch=sm_37
-endif
-ifeq (35, $(findstring 35, $(SM_ARCH)))
- SM_TARGETS += -gencode=arch=compute_35,code=\"sm_35,compute_35\"
- CLANG_SM_TARGETS += --cuda-gpu-arch=sm_35
-endif
-ifeq (30, $(findstring 30, $(SM_ARCH)))
- SM_TARGETS += -gencode=arch=compute_30,code=\"sm_30,compute_30\"
- CLANG_SM_TARGETS += --cuda-gpu-arch=sm_30
-endif
-ifeq (21, $(findstring 21, $(SM_ARCH)))
- SM_TARGETS += -gencode=arch=compute_20,code=\"sm_21,compute_20\"
- CLANG_SM_TARGETS += --cuda-gpu-arch=sm_21
-endif
-ifeq (20, $(findstring 20, $(SM_ARCH)))
- SM_TARGETS += -gencode=arch=compute_20,code=\"sm_20,compute_20\"
- CLANG_SM_TARGETS += --cuda-gpu-arch=sm_20
-endif
-
-
-# [verbose=<0|1>] Verbose toolchain output from nvcc option
-ifeq ($(verbose), 1)
- NVCCFLAGS += -v
- CLANG_CFLAGS += -v
-endif
-
-
-# [keep=<0|1>] Keep intermediate compilation artifacts option
-ifeq ($(keep), 1)
- NVCCFLAGS += -keep
- CLANG_CFLAGS += --save-temps
-endif
-
-
-# [debug=<0|1>] Generate debug mode code
-ifeq ($(debug), 1)
- NVCCFLAGS += -G
- CLANG_CFLAGS += --cuda-noopt-device-debug
-endif
-
-
-#-------------------------------------------------------------------------------
-# Compiler and compilation platform
-#-------------------------------------------------------------------------------
-
-BASE_DIR := $(dir $(lastword $(MAKEFILE_LIST)))
-
-NVCC := "$(shell which nvcc)"
-ifdef nvccver
- NVCC_VERSION := $(nvccver)
-else
- NVCC_VERSION := $(strip $(shell nvcc --version | grep release | sed 's/.*release //' | sed 's/,.*//'))
-endif
-
-# Detect OS
-OSUPPER := $(shell uname -s 2>/dev/null | tr [:lower:] [:upper:])
-
-# Default flags: verbose kernel properties (regs, smem, cmem, etc.); runtimes for compilation phases
-NVCCFLAGS += -O3 -Xptxas -v
-CLANG_CFLAGS += -O3 -Xcuda-ptxas -v
-ifeq (WIN_NT, $(findstring WIN_NT, $(OSUPPER)))
- # For MSVC
-
- # Enable more warnings and treat as errors
- NVCCFLAGS += -Xcompiler /W3 -Xcompiler /WX
-
- # Disable excess x86 floating point precision that can lead to results being labeled incorrectly
- NVCCFLAGS += -Xcompiler /fp:strict
-
- # Compiler
- CC := cl
-
- # Multithreaded runtime
- NVCCFLAGS += -Xcompiler /MT
-
- CUDART_CYG := "$(shell dirname $(NVCC))/../lib/x64/cudart.lib"
- CUDART := "$(shell cygpath -w $(CUDART_CYG))"
-
-else
- # For g++
-
- # Disable excess x86 floating point precision that can lead to results being labeled incorrectly
- #NVCCFLAGS += -Xcompiler -ffloat-store
-
- # Compiler
- CC := g++
-
- CUDART := "$(shell dirname $(NVCC))/../lib64/libcudart_static.a"
-
-endif
-
-# compiler=clang Enables compilation with clang.
-
-ifeq ($(compiler), clang)
- # NVCC_VERSION is used as the proxy for the CUDA version.
- BIN_SUFFIX := sm$(SM_ARCH)_clang_cuda_$(NVCC_VERSION)
- # Clangs needs few extra flags to point it to CUDA SDK
- # and link the binaries with CUDA runtime.
- CUDA_BASE=$(realpath $(join $(dir $(shell which nvcc)), ..))
- CLANG_CFLAGS += --cuda-path=$(CUDA_BASE)
- LIBINC += -L$(CUDA_BASE)/lib64 -Wl,-rpath=$(CUDA_BASE)/lib64
- LIBS += -lcudart
-
- # Replace NVCC and its options with clang++.
- NVCC = clang++
- NVCCFLAGS = $(CLANG_CFLAGS)
- SM_TARGETS = $(CLANG_SM_TARGETS)
-else
- # Suffix to append to each binary
- BIN_SUFFIX := sm$(SM_ARCH)_nvcc_$(NVCC_VERSION)
-endif
-
-
-#-------------------------------------------------------------------------------
-# Function for computing dependency Lists
-#-------------------------------------------------------------------------------
-
-rwildcard=$(foreach d,$(wildcard $1*),$(call rwildcard,$d/,$2) $(filter $(subst *,%,$2),$d))
diff --git a/cutlass/convert.h b/cutlass/convert.h
new file mode 100644
index 00000000..933d68a8
--- /dev/null
+++ b/cutlass/convert.h
@@ -0,0 +1,102 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*!
+ \file
+ \brief Defines conversion operations among Fragments of different base type.
+*/
+#pragma once
+
+#include <cutlass/fragment.h>
+
+namespace cutlass {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename InputFragment_, typename OutputFragment_>
+struct Convert {};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename InputScalar_, typename OutputScalar_, int kScalars_>
+struct Convert<Fragment<InputScalar_, kScalars_>, Fragment<OutputScalar_, kScalars_> > {
+ /// The input fragment.
+ typedef Fragment<InputScalar_, kScalars_> InputFragment;
+ /// The output fragment.
+ typedef Fragment<OutputScalar_, kScalars_> OutputFragment;
+
+ /// Ctor.
+ CUTLASS_DEVICE Convert() {}
+
+ /// Transform a fragment.
+ CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
+ transform(src, 0, dst);
+ }
+
+ /// Transform a fragment.
+ template <typename Fragment_>
+ CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
+ for (int i = 0; i < kScalars_; ++i) {
+ dst[i] = static_cast<OutputScalar_>(src[i + offset]);
+ }
+ }
+};
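+
+// Usage sketch (illustrative only):
+//
+//   Fragment<float, 4> src;
+//   Fragment<double, 4> dst;
+//   Convert<Fragment<float, 4>, Fragment<double, 4> > converter;
+//   converter.transform(src, dst);  // dst[i] = static_cast<double>(src[i])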
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Fragment_>
+struct Copy {
+ /// The input fragment.
+ typedef Fragment_ InputFragment;
+ /// The output fragment.
+ typedef Fragment_ OutputFragment;
+
+ /// Ctor.
+ CUTLASS_DEVICE Copy() {}
+
+ /// Transform a fragment.
+ CUTLASS_DEVICE void transform(Fragment_ const& src, Fragment_& dst) { transform(src, 0, dst); }
+
+ /// Transform a fragment.
+ template <typename InputFragment_>
+ CUTLASS_DEVICE void transform(InputFragment_ const& src, int offset, Fragment_& dst) {
+ if (sizeof(typename Fragment_::Element) == 8) {
+ uint64_t const* src_ptr = reinterpret_cast<uint64_t const*>(&src[offset]);
+ uint64_t* dst_ptr = reinterpret_cast<uint64_t*>(&dst[0]);
+ for (int i = 0; i < sizeof(Fragment_) / 8; ++i) {
+ dst_ptr[i] = src_ptr[i];
+ }
+ } else {
+ uint32_t const* src_ptr = reinterpret_cast<uint32_t const*>(&src[offset]);
+ uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&dst[0]);
+ for (int i = 0; i < sizeof(Fragment_) / 4; ++i) {
+ dst_ptr[i] = src_ptr[i];
+ }
+ }
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace cutlass
diff --git a/cutlass/coord.h b/cutlass/coord.h
new file mode 100644
index 00000000..431c9bf1
--- /dev/null
+++ b/cutlass/coord.h
@@ -0,0 +1,287 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief A Coord is a coordinate of arbitrary rank into a tensor or matrix
+*/
+
+#pragma once
+
+#include <cutlass/cutlass.h>
+#include <cutlass/util/platform.h>
+
+namespace cutlass {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Describes identity elements
+struct Identity {
+ /// Enumeration describing identity elements. Value assignments are significant.
+ /// Feel free to add or multiply by these, respectively.
+ enum Kind { Additive = 0, Multiplicative = 1 };
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Statically-sized array specifying Coords within a tensor
+template <int N_>
+struct Coord {
+ //
+ // Type and constant definitions
+ //
+
+ static int const N = N_;
+
+ //
+ // Data members
+ //
+
+ /// Indices
+ int idx[N];
+
+ //
+ // Methods
+ //
+
+ /// Default ctor initializes uniformly
+ CUTLASS_HOST_DEVICE
+ Coord(int value = 0) {
+ for (int i = 0; i < N; ++i) {
+ idx[i] = value;
+ }
+ }
+
+ /// Constructs from an array of integers
+ CUTLASS_HOST_DEVICE
+ Coord(int _idx[]) {
+ for (int i = 0; i < N; ++i) {
+ idx[i] = _idx[i];
+ }
+ }
+
+ /// Element-wise addition
+ CUTLASS_HOST_DEVICE
+ Coord operator+(Coord const& b) const {
+ Coord c;
+ for (int i = 0; i < N; ++i) {
+ c.idx[i] = idx[i] + b.idx[i];
+ }
+ return c;
+ }
+
+ /// Element-wise subtraction
+ CUTLASS_HOST_DEVICE
+ Coord operator-(Coord const& b) const {
+ Coord c;
+ for (int i = 0; i < N; ++i) {
+ c.idx[i] = idx[i] - b.idx[i];
+ }
+ return c;
+ }
+
+ /// Element-wise multiplication
+ CUTLASS_HOST_DEVICE
+ Coord operator*(Coord const& b) const {
+ Coord c;
+ for (int i = 0; i < N; ++i) {
+ c.idx[i] = idx[i] * b.idx[i];
+ }
+ return c;
+ }
+
+ /// Element-wise division
+ CUTLASS_HOST_DEVICE
+ Coord operator/(Coord const& b) const {
+ Coord c;
+ for (int i = 0; i < N; ++i) {
+ c.idx[i] = idx[i] / b.idx[i];
+ }
+ return c;
+ }
+
+ /// In-place addition
+ CUTLASS_HOST_DEVICE
+ Coord& operator+=(Coord const& b) {
+ for (int i = 0; i < N; ++i) {
+ idx[i] += b.idx[i];
+ }
+ return *this;
+ }
+
+ /// In-place subtraction
+ CUTLASS_HOST_DEVICE
+ Coord& operator-=(Coord const& b) {
+ for (int i = 0; i < N; ++i) {
+ idx[i] -= b.idx[i];
+ }
+ return *this;
+ }
+
+ /// In-place multiplication
+ CUTLASS_HOST_DEVICE
+ Coord& operator*=(Coord const& b) {
+ for (int i = 0; i < N; ++i) {
+ idx[i] *= b.idx[i];
+ }
+ return *this;
+ }
+
+ /// In-place division
+ CUTLASS_HOST_DEVICE
+ Coord& operator/=(Coord const& b) {
+ for (int i = 0; i < N; ++i) {
+ idx[i] /= b.idx[i];
+ }
+ return *this;
+ }
+
+ /// Member access operator
+ CUTLASS_HOST_DEVICE int& operator[](int dim) { return idx[dim]; }
+
+ /// Member access operator
+ CUTLASS_HOST_DEVICE int const& operator[](int dim) const { return idx[dim]; }
+
+ /// Computes the dot product of two Coord instances
+ template <typename T>
+ CUTLASS_HOST_DEVICE T dot(Coord const& b, T sum) const {
+ for (int i = 0; i < N; ++i) {
+ sum += idx[i] * b.idx[i];
+ }
+ return sum;
+ }
+
+ /// Computes the dot product of two Coord instances
+ template <typename T>
+ CUTLASS_HOST_DEVICE T dot(Coord const& b) const {
+ T sum = T(0);
+ for (int i = 0; i < N; ++i) {
+ sum += idx[i] * b.idx[i];
+ }
+ return sum;
+ }
+
+ /// Gets the index of a given Coord element
+ template <int Dim>
+ CUTLASS_HOST_DEVICE int& at() {
+ return idx[Dim];
+ }
+
+ /// Access via index; may limit unrolling potential
+ CUTLASS_HOST_DEVICE
+ int& at(int dim) { return idx[dim]; }
+
+ /// Gets the index of a given Coord element
+ template <int Dim>
+ CUTLASS_HOST_DEVICE int const& at() const {
+ return idx[Dim];
+ }
+
+ /// Access via index; may limit unrolling potential
+ CUTLASS_HOST_DEVICE
+ int const& at(int dim) const { return idx[dim]; }
+
+ /// Determines if two Coord<> objects are equal
+ CUTLASS_HOST_DEVICE
+ bool operator==(Coord const& b) const {
+ bool equal = true;
+ for (int i = 0; equal && i < N; ++i) {
+ equal = (idx[i] == b.idx[i]);
+ }
+ return equal;
+ }
+
+ /// Not equal
+ CUTLASS_HOST_DEVICE
+ bool operator!=(Coord const& b) const { return !(*this == b); }
+
+ /// Clamps a coordinate to a range specified by maximum and minimum values
+ CUTLASS_HOST_DEVICE
+ Coord& clamp(Coord const& max, Coord const& min = Coord()) {
+ for (int i = 0; i < N; ++i) {
+ idx[i] = __NV_STD_MAX(__NV_STD_MIN(idx[i], max.idx[i]), min.idx[i]);
+ }
+ return *this;
+ }
+
+ /// Returns the product of all elements
+ CUTLASS_HOST_DEVICE
+ int count() const {
+ int product = idx[0];
+ for (int i = 1; i < N; ++i) {
+ product *= idx[i];
+ }
+ return product;
+ }
+};
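+
+// Usage sketch (illustrative only): Coord supports element-wise arithmetic, dot products,
+// clamping, and volume computation.
+//
+//   Coord<3> extent = make_Coord(2, 4, 8);     // make_Coord helpers are defined below
+//   Coord<3> stride = make_Coord(32, 8, 1);
+//   int volume = extent.count();               // 2 * 4 * 8 = 64
+//   int offset = extent.dot<int>(stride);      // 2*32 + 4*8 + 8*1 = 104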
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Helper to make a 1-element coordinate
+CUTLASS_HOST_DEVICE
+Coord<1> make_Coord(int _0) {
+ int values[1] = {_0};
+ return Coord<1>(values);
+}
+
+/// Helper to make a 2-element coordinate
+CUTLASS_HOST_DEVICE
+Coord<2> make_Coord(int _0, int _1) {
+ int values[2] = {_0, _1};
+ return Coord<2>(values);
+}
+
+/// Helper to make a 3-element coordinate
+CUTLASS_HOST_DEVICE
+Coord<3> make_Coord(int _0, int _1, int _2) {
+ int values[3] = {_0, _1, _2};
+ return Coord<3>(values);
+}
+
+/// Helper to make a 4-element coordinate
+CUTLASS_HOST_DEVICE
+Coord<4> make_Coord(int _0, int _1, int _2, int _3) {
+ int values[4] = {_0, _1, _2, _3};
+ return Coord<4>(values);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Getter
+CUTLASS_HOST_DEVICE
+Coord<2> get_Coord_hw(Coord<3> const& coord) { return make_Coord(coord[1], coord[2]); }
+
+/// Getter
+CUTLASS_HOST_DEVICE
+Coord<2> get_Coord_hw(Coord<4> const& coord) { return make_Coord(coord[1], coord[2]); }
+
+/// Getter
+CUTLASS_HOST_DEVICE
+Coord<3> get_Coord_hwc(Coord<4> const& coord) { return make_Coord(coord[1], coord[2], coord[3]); }
+
+/// Getter
+CUTLASS_HOST_DEVICE
+Coord<3> get_Coord_dhw(Coord<4> const& coord) { return make_Coord(coord[0], coord[1], coord[2]); }
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace cutlass
diff --git a/cutlass/core_io.h b/cutlass/core_io.h
new file mode 100644
index 00000000..cceea4c0
--- /dev/null
+++ b/cutlass/core_io.h
@@ -0,0 +1,44 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Helpers for printing cutlass/core objects
+*/
+
+#pragma once
+
+#include <iostream>
+
+#include <cutlass/coord.h>
+
+template <int Rank>
+std::ostream& operator<<(std::ostream& out, cutlass::Coord<Rank> const& coord) {
+ for (int i = 0; i < Rank; ++i) {
+ out << (i ? ", " : "") << coord.idx[i];
+ }
+ return out;
+}
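+
+// Usage sketch (illustrative only):
+//
+//   cutlass::Coord<3> coord = cutlass::make_Coord(1, 2, 3);
+//   std::cout << coord << std::endl;   // prints "1, 2, 3"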
diff --git a/cutlass/cutlass.h b/cutlass/cutlass.h
new file mode 100644
index 00000000..1e428b16
--- /dev/null
+++ b/cutlass/cutlass.h
@@ -0,0 +1,73 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+ \brief Basic include for CUTLASS macros
+*/
+
+#pragma once
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#define CUTLASS_MAJOR 1
+#define CUTLASS_MINOR 0
+#define CUTLASS_PATCH 0
+#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
+
+#ifdef __NVCC__
+#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__
+#define CUTLASS_DEVICE __forceinline__ __device__
+#elif defined(__CUDACC_RTC__)
+#define CUTLASS_HOST_DEVICE __forceinline__ __device__
+#define CUTLASS_DEVICE __forceinline__ __device__
+#else
+#define CUTLASS_HOST_DEVICE
+// CUTLASS_DEVICE is an error if not compiling device code
+#endif
+
+// CUTLASS_PRAGMA_UNROLL inserts a "#pragma unroll" directive if supported by the compiler
+#if defined(__CUDA_ARCH__)
+#if defined(_MSC_VER)
+#define CUTLASS_PRAGMA_UNROLL __pragma("unroll")
+#define CUTLASS_PRAGMA_NO_UNROLL __pragma("unroll 1")
+#else
+#define CUTLASS_PRAGMA_UNROLL _Pragma("unroll")
+#define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1")
+#endif
+#else
+#define CUTLASS_PRAGMA_UNROLL
+#define CUTLASS_PRAGMA_NO_UNROLL
+#endif
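+
+// Usage sketch (illustrative only): the unroll macros decorate loops in device code.
+//
+//   CUTLASS_PRAGMA_UNROLL
+//   for (int i = 0; i < kElements; ++i) {
+//     dst[i] = src[i];
+//   }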
+
+#define CUTLASS_ASSERT(x) assert(x)
+
+namespace cutlass {
+
+/// NVIDIA GPU Warp size
+static const int kWarpSize = 32;
+
+} // namespace cutlass
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/cutlass/fragment.h b/cutlass/fragment.h
new file mode 100644
index 00000000..53fa380c
--- /dev/null
+++ b/cutlass/fragment.h
@@ -0,0 +1,278 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Defines Fragment, a statically-sized array for storing parts of matrices within a
+ thread's registers.
+*/
+#pragma once
+
+#include <cutlass/cutlass.h>
+#include <cutlass/shape.h>
+#include <cutlass/util/cutlass_math.h>
+#include <cutlass/util/platform.h>
+
+namespace cutlass {
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/*!@defgroup fragment_concept Fragment Concept
+@{
+
+\ref fragment_concept is a statically sized array for storing parts of tiles held by individual CUDA
+threads.
+
+@par \ref fragment_concept
+ Types satisfying \ref fragment_concept define the following members
+ - Element - type of each access held within the fragment
+ - kElements - number of elements stored by the fragment
+ - clear() - overwrites the fragment storage with zeros
+ - Element & operator[](int i) - by-reference access of the ith element
+ - Element const & operator[](int i) const - const by-reference access of the ith element
+@}
+*/
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/*!@defgroup fragment_iterator_concept Fragment Iterator Concept
+@{
+
+\ref fragment_iterator_concept provides structured access to the elements within a fragment with an
+optional bitcast to the desired access type
+
+@par \ref fragment_iterator_concept
+ Types satisfying \ref fragment_iterator_concept define the following members
+ - AccessType& operator[](int i) - provides access to the ith element of the fragment
+ - AccessType& at(int d, int h, int w, int c) - applies \ref layout_concept to fragment and
+provides access to element at (d, h, w, c)
+
+@}
+*/
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <size_t kSize_>
+struct StorageType {
+ typedef uint64_t Type;
+};
+template <>
+struct StorageType<4> {
+ typedef uint32_t Type;
+};
+template <>
+struct StorageType<2> {
+ typedef uint16_t Type;
+};
+template <>
+struct StorageType<1> {
+ typedef uint8_t Type;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/**
+* @brief A template defining \ref fragment_concept
+* @concept{fragment_concept}
+*/
+template <typename Element_, int kElements_, size_t kAlignment_ = 16>
+struct Fragment : public AlignedStruct<kAlignment_> {
+ /// Make sure the alignment makes sense wrt the size of elements.
+ static_assert(kAlignment_ == 16 || kAlignment_ >= sizeof(Element_), "Alignment is too small");
+ /// Alignment must be a power of two
+ static_assert(is_pow2<kAlignment_>::value, "Alignment must be a power of two");
+
+ /// This class.
+ typedef Fragment<Element_, kElements_, kAlignment_> This_;
+ /// The element.
+ typedef Element_ Element;
+ /// The number of elements.
+ static int const kElements = kElements_;
+
+ /// Clear a fragment.
+ CUTLASS_DEVICE void clear() {
+ // Avoid element-wise access for sub 32b element type
+ if (kAlignment_ >= 8 && (kElements * sizeof(Element)) % 8 == 0) {
+ uint64_t* ptr = reinterpret_cast<uint64_t*>(storage);
+ for (int i = 0; i < (kElements * sizeof(Element)) / 8; ++i) {
+ ptr[i] = uint64_t(0);
+ }
+ } else if (kAlignment_ >= 4 && (kElements * sizeof(Element)) % 4 == 0) {
+ uint32_t* ptr = reinterpret_cast<uint32_t*>(storage);
+ for (int i = 0; i < (kElements * sizeof(Element)) / 4; ++i) {
+ ptr[i] = uint32_t(0);
+ }
+ } else if (kAlignment_ >= 2 && (kElements * sizeof(Element)) % 2 == 0) {
+ uint16_t* ptr = reinterpret_cast<uint16_t*>(storage);
+ for (int i = 0; i < (kElements * sizeof(Element)) / 2; ++i) {
+ ptr[i] = uint16_t(0);
+ }
+ } else {
+ for (int i = 0; i < kElements; ++i) {
+ storage[i] = 0;
+ }
+ }
+ }
+
+ /// The accessor.
+ CUTLASS_DEVICE Element& operator[](int i) {
+ assert(i < kElements_);
+ return reinterpret_cast<Element*>(storage)[i];
+ }
+
+ /// The accessor.
+ CUTLASS_DEVICE Element const& operator[](int i) const {
+ assert(i < kElements_);
+ return reinterpret_cast<Element const*>(storage)[i];
+ }
+
+ private:
+ /// Storage type to use for Elements
+ typedef typename StorageType<kAlignment_>::Type StorageType;
+
+ /// Number of elements in the storage
+ static int const kStorageCount =
+ (sizeof(Element_) * kElements_ + sizeof(StorageType) - 1) / sizeof(StorageType);
+ /// The storage.
+ StorageType storage[kStorageCount];
+
+ /// Ensure that there's enough storage for all elements
+ static_assert(sizeof(StorageType) <= kAlignment_, "StorageType is too big for given alignment");
+};
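+
+// Usage sketch (illustrative only): a fragment of eight floats held in a thread's registers.
+//
+//   Fragment<float, 8> frag;
+//   frag.clear();       // zero-initializes the underlying storage
+//   frag[0] = 1.0f;     // element access through operator[]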
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/**
+* @brief A template defining \ref fragment_iterator_concept
+* @concept{fragment_iterator_concept}
+*/
+template <typename Fragment_, typename Iterations_, typename AccessType_>
+struct FragmentIterator {
+ /// This class.
+ typedef FragmentIterator<Fragment_, Iterations_, AccessType_> This_;
+ /// The fragment.
+ typedef Fragment_ Fragment;
+ /// The number of iterations.
+ typedef Iterations_ Iterations;
+ /// The access type.
+ typedef AccessType_ AccessType;
+
+ /// The element.
+ typedef typename Fragment::Element Element;
+ /// The number of elements per access.
+ static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
+ /// The shape of the fragment.
+ typedef typename ShapeMul<Iterations, Shape<1, 1, 1, kElementsPerAccess> >::Shape FragmentShape;
+ /// The linear strides for iterations.
+ typedef typename ShapeStrides<FragmentShape, kElementsPerAccess>::Shape Strides;
+
+ /// Ctor.
+ template <typename OtherFragment_>
+ CUTLASS_DEVICE FragmentIterator(OtherFragment_& fragment, int offset = 0)
+ : pointer(reinterpret_cast<Element*>(&fragment[offset])) {
+ static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
+ }
+
+ /// The accessor.
+ CUTLASS_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
+ int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
+ return reinterpret_cast<AccessType const&>(pointer[imm]);
+ }
+
+ /// The accessor.
+ CUTLASS_DEVICE AccessType& at(int d, int h, int w, int c = 0) {
+ int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
+ return reinterpret_cast<AccessType&>(pointer[imm]);
+ }
+
+ /// The accessor.
+ CUTLASS_DEVICE AccessType const& operator[](int i) const {
+ return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
+ }
+
+ /// The accessor.
+ CUTLASS_DEVICE AccessType& operator[](int i) {
+ return reinterpret_cast<AccessType&>(pointer[i * kElementsPerAccess]);
+ }
+
+ /// Is the iterator valid?
+ CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
+
+ /// The pointer.
+ Element* pointer;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Fragment_, typename Iterations_, typename AccessType_>
+struct FragmentConstIterator {
+ /// This class.
+ typedef FragmentConstIterator<Fragment_, Iterations_, AccessType_> This_;
+ /// The fragment.
+ typedef Fragment_ Fragment;
+ /// The number of iterations.
+ typedef Iterations_ Iterations;
+ /// The access type.
+ typedef AccessType_ AccessType;
+
+ /// The element.
+ typedef typename Fragment::Element Element;
+ /// The number of elements per access.
+ static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
+ /// The shape of the fragment.
+ typedef typename ShapeMul<Iterations, Shape<1, 1, 1, kElementsPerAccess> >::Shape FragmentShape;
+ /// The linear strides for iterations.
+ typedef typename ShapeStrides<FragmentShape, kElementsPerAccess>::Shape IterationsStrides;
+
+ /// Ctor.
+ template <typename OtherFragment_>
+ CUTLASS_DEVICE FragmentConstIterator(OtherFragment_& fragment, int offset = 0)
+ : pointer(reinterpret_cast<Element const*>(&fragment[offset])) {
+ static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
+ }
+ /// Create from non-constant FragmentIterator
+ CUTLASS_DEVICE FragmentConstIterator(
+ FragmentIterator<Fragment_, Iterations_, AccessType_> const& rhs_)
+ : pointer(reinterpret_cast<Element const*>(rhs_.offset)) {}
+
+ /// The accessor.
+ CUTLASS_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
+ int const imm = ComputeOffsetFromStrides<IterationsStrides>::get(d, h, w, c);
+ return reinterpret_cast<AccessType const&>(pointer[imm]);
+ }
+
+ /// The accessor.
+ CUTLASS_DEVICE AccessType const& operator[](int i) const {
+ return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
+ }
+
+ /// Is the iterator valid?
+ CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
+
+ /// The pointer.
+ Element const* pointer;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace cutlass
diff --git a/cutlass/fragment_load_store.h b/cutlass/fragment_load_store.h
new file mode 100644
index 00000000..a7d272e9
--- /dev/null
+++ b/cutlass/fragment_load_store.h
@@ -0,0 +1,135 @@
+/***************************************************************************************************
+ * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Defines accessors for loading and storing fragments to memory efficiently.
+*/
+#pragma once
+
+#include <cutlass/load_store.h>
+#include <cutlass/vector.h>
+
+namespace cutlass {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <IteratorFragment::Kind kIteratorFragment,
+ int kAccessSize,
+ typename Scalar_,
+ MemorySpace::Kind Memory_,
+ typename FragmentElement_,
+ int kStride>
+struct FragmentLoad {};
+
+template <int kAccessSize, typename Scalar_, MemorySpace::Kind Memory_, typename FragmentElement_, int kStride>
+struct FragmentLoad<IteratorFragment::kWmmaMatrix, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride> {
+ /// The output type.
+ typedef FragmentElement_ AccessType;
+
+ /// The load function.
+ static CUTLASS_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) {
+ value.load(&pointer[offset], kStride);
+ }
+};
+
+template <int kAccessSize, typename Scalar_, MemorySpace::Kind Memory_, typename FragmentElement_, int kStride>
+struct FragmentLoad<IteratorFragment::kScalar, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride> {
+ /// The output type.
+ typedef typename Vectorize<FragmentElement_, kAccessSize>::Type AccessType;
+
+ /// The load function.
+ static CUTLASS_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) {
+ Load<Scalar_, kAccessSize, Memory_>::load(value, pointer, offset);
+ }
+};
+
+template <IteratorFragment::Kind kIteratorFragment,
+ int kAccessSize,
+ typename Scalar_,
+ MemorySpace::Kind Memory_,
+ typename FragmentElement_,
+ int kStride>
+struct FragmentStore {};
+
+template <int kAccessSize, typename Scalar_, MemorySpace::Kind Memory_, typename FragmentElement_, int kStride>
+struct FragmentStore<IteratorFragment::kWmmaMatrix, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride> {
+ /// The input type.
+ typedef FragmentElement_ AccessType;
+
+ /// The store function.
+ static CUTLASS_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) {
+ value.store(&pointer[offset], kStride);
+ }
+};
+
+template <int kAccessSize, typename Scalar_, MemorySpace::Kind Memory_, typename FragmentElement_, int kStride>
+struct FragmentStore<IteratorFragment::kScalar, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride> {
+ /// The input type.
+ typedef typename Vectorize<FragmentElement_, kAccessSize>::Type AccessType;
+
+ /// The store function.
+ static CUTLASS_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) {
+ Store<Scalar_, kAccessSize, Memory_>::store(value, pointer, offset);
+ }
+};
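+
+// Usage sketch (illustration only; the instantiation below assumes the enumerant names
+// used by the primary templates above and an arbitrary access size of 4):
+//
+//   typedef FragmentLoad<IteratorFragment::kScalar, 4, float,
+//                        MemorySpace::kGlobal, float, 0> Loader;
+//   Loader::AccessType value;
+//   Loader::load(value, pointer, offset);   // vectorized load of 4 scalars from pointer + offset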
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace cutlass
diff --git a/cutlass/fragment_multiply_add.h b/cutlass/fragment_multiply_add.h
new file mode 100644
index 00000000..2d31e793
--- /dev/null
+++ b/cutlass/fragment_multiply_add.h
@@ -0,0 +1,131 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Defines multiply-add operations on fragments within a thread.
+*/
+#pragma once
+
+#include <cutlass/fragment.h>
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Scalar_>
+struct FragmentMultiplyAdd {
+ /// The shape of the instruction.
+ typedef Shape<1, 1, 1, 1> InstructionShape;
+ /// The type for A.
+ typedef Scalar_ ScalarA;
+ /// The type for B.
+ typedef Scalar_ ScalarB;
+ /// The type for C and D.
+ typedef Scalar_ ScalarC;
+
+ /// Ctor.
+ CUTLASS_DEVICE FragmentMultiplyAdd() {}
+
+ /// Multiply : d = a*b.
+ template <typename Fragment_>
+ CUTLASS_DEVICE void multiply(Scalar_ a, Fragment_ const& b, Fragment_& d) {
+ for (int j = 0; j < Fragment_::kElements; ++j) {
+ d[j] = a * b[j];
+ }
+ }
+
+ /// Multiply : d = a*b + c.
+ template <typename Fragment_>
+ CUTLASS_DEVICE void multiply_add(Scalar_ a,
+ Fragment_ const& b,
+ Fragment_ const& c,
+ Fragment_& d) {
+ for (int j = 0; j < Fragment_::kElements; ++j) {
+ d[j] = a * b[j] + c[j];
+ }
+ }
+};
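+
+// Usage sketch (illustration only; alpha and accum_* stand for any scalar and Fragment of
+// Scalar_): the generic path is a plain element-wise loop, so d = a * b + c is computed one
+// element at a time:
+//
+//   FragmentMultiplyAdd<float> mad;
+//   mad.multiply_add(alpha, accum_b, accum_c, accum_d);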
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
+template <>
+struct FragmentMultiplyAdd<half> {
+ /// The shape of the instruction.
+ typedef Shape<1, 1, 1, 1> InstructionShape;
+ /// The type for A.
+ typedef half ScalarA;
+ /// The type for B.
+ typedef half ScalarB;
+ /// The type for C and D.
+ typedef half ScalarC;
+
+ /// Ctor.
+ CUTLASS_DEVICE FragmentMultiplyAdd() {}
+
+ /// Multiply : d = a*b.
+ template <typename Fragment_>
+ CUTLASS_DEVICE void multiply(half a, Fragment_ const& b, Fragment_& d) {
+#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
+ // The input.
+ __half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
+ // The output.
+ __half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
+
+ // Assemble a half2 from a.
+ __half2 const a_half2 = __half2half2(a);
+
+ for (int i = 0; i < Fragment_::kElements / 2; ++i) {
+ d_half2[i] = __hmul2(a_half2, b_half2[i]);
+ }
+#endif
+ }
+
+ /// Multiply : d = a*b + c.
+ template <typename Fragment_>
+ CUTLASS_DEVICE void multiply_add(half a, Fragment_ const& b, Fragment_ const& c, Fragment_& d) {
+#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
+ // The inputs.
+ __half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
+ __half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);
+ // The output.
+ __half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
+
+ // Assemble a half2 from a.
+ __half2 const a_half2 = __half2half2(a);
+
+ for (int i = 0; i < Fragment_::kElements / 2; ++i) {
+ d_half2[i] = __hfma2(a_half2, b_half2[i], c_half2[i]);
+ }
+#endif
+ }
+};
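+
+// Note (assumption drawn from the loops above): this specialization consumes elements as
+// __half2 pairs, so Fragment_::kElements is expected to be even, and the __hmul2/__hfma2
+// path is only emitted for __CUDA_ARCH__ >= 530. Illustrative call:
+//
+//   FragmentMultiplyAdd<half> mad;
+//   mad.multiply_add(__float2half(1.0f), b_frag, c_frag, d_frag);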
+
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace gemm
+} // namespace cutlass
diff --git a/cutlass/gemm/block_loader.h b/cutlass/gemm/block_loader.h
deleted file mode 100644
index 0c5b3faa..00000000
--- a/cutlass/gemm/block_loader.h
+++ /dev/null
@@ -1,162 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- * * Redistributions of source code must retain the above copyright
- * notice, this list of conditions and the following disclaimer.
- * * Redistributions in binary form must reproduce the above copyright
- * notice, this list of conditions and the following disclaimer in the
- * documentation and/or other materials provided with the distribution.
- * * Neither the name of the NVIDIA CORPORATION nor the
- * names of its contributors may be used to endorse or promote products
- * derived from this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- ******************************************************************************/
-
-#pragma once
-
-/**
- * \file
- * block-wide tile-loading abstractions
- */
-
-#include "../util/util.h"
-
-namespace cutlass {
-namespace gemm {
-
-
-/******************************************************************************
- * load_algorithm
- ******************************************************************************/
-
-/**
- * \brief Enumeration of matrix loading algorithms
- */
-struct load_algorithm
-{
- /// \brief Enumerants. See corresponding tag types.
- enum kind_t
- {
- CongruousCopy = 0,
- CrosswiseCopy = 1,
- };
-
- /**
- * \brief Generic tag
- */
- template <kind_t Kind>
- struct any_tag : nv_std::integral_constant<kind_t, Kind> {};
-
- /**
- * \brief Copy from a global matrix that is row-major in relation
- * to the local row-major tile
- */
- typedef any_tag<CongruousCopy> contiguous_tag_t;
-
- /**
- * \brief Copy from a global matrix that is column-major in relation
- * to the local row-major tile
- */
- typedef any_tag<CrosswiseCopy> crosswise_tag_t;
-
-};
-
-
-/******************************************************************************
- * block_loader
- ******************************************************************************/
-
-/**
- * \brief A three-phase data loading abstraction (prefetch, commit, and
- * advance) for iterating over ranges of block-wide matrix tiles.
- *
- * Each iteration sequence produces a KxL (height-by-width) block-wide tile of
- * value_t in shared memory. The layout of the shared
- * block-wide tile is a row-major (L-major) tiling of dp_vector_t items, which are
- * themselves column-major (K-major) vectors of value_t. Its dimensions are:
- * K = BlockDpVectorsK * (sizeof(dp_vector_t) / sizeof(value_t)
- * L = BlockDpVectorsL
- *
- * NB: This generic class is not directly constructible. Architecture- and
- * algorithm-specific template specializations will provide the API
- * functionality prescribed here.
- *
- */
-template <
- int BlockThreads, ///< Number of threads in each thread block (blockDim.x)
- int BlockDpVectorsK, ///< Extent of block-wide tile in dp_vector_t along the K-axis (height)
- int BlockDpVectorsL, ///< Extent of block-wide tile in dp_vector_t along the L-axis (width)
- typename value_t, ///< Input matrix value type
- int LeadingDimAlignBytes, ///< Byte alignment of input matrix leading dimension
- bool AllowRaggedTiles, ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
- typename dp_vector_t, ///< Dot-product vector type along the K-axis
- load_algorithm::kind_t LoadAlgorithm> ///< Algorithm for loading a shared tile of KxL matrix data
-struct block_loader
-{
- //-------------------------------------------------------------------------
- // Constructor API
- //-------------------------------------------------------------------------
-
- /// Constructor
- block_loader(
- value_t *d_matrix, ///< Pointer to input matrix
- int matrix_values_l, ///< Extent of the input matrix in value_t along the L-axis
- int matrix_values_stride_k, ///< Distance in value_t within pitched-linear memory between successive coordinates along the K-axis
- int matrix_values_stride_l, ///< Distance in value_t within pitched-linear memory between successive coordinates along the L-axis
- int2 block_begin_item_coords, ///< Thread block's starting value_t coordinates (l, k) within the input matrix
- int block_end_item_k); ///< Thread block's ending coordinate (k) within the input matrix (one-past)
-
- //-------------------------------------------------------------------------
- // Loader API
- //-------------------------------------------------------------------------
-
- /**
- * Request the current block-wide tile
- */
- void request();
-
-
- /**
- * Advance the loader to the next block-wide tile in the K-axis
- */
- void next();
-
-
- /**
- * Commit the previously-requested block-wide tile to shared memory
- *
- * NB: To facilitate padding for avoiding shared memory bank conflicts, we
- * allow the row stride _BlockDpVectorsL to be arbitrarily bigger than the
- * tile width BlockDpVectorsL.
- */
- template <int _BlockDpVectorsL>
- void commit(
- dp_vector_t (&scratch_tile)[BlockDpVectorsK][_BlockDpVectorsL]);
-
-};
-
-
-} // namespace gemm
-} // namespace cutlass
-
-
-/******************************************************************************
- * Tail-include specializations that adhere to the block_loader API
- ******************************************************************************/
-
-#include "block_loader_crosswise.h"
-#include "block_loader_congruous_dp1.h"
-#include "block_loader_congruous_idp4.h"
diff --git a/cutlass/gemm/block_loader_congruous_dp1.h b/cutlass/gemm/block_loader_congruous_dp1.h
deleted file mode 100644
index 80dca26d..00000000
--- a/cutlass/gemm/block_loader_congruous_dp1.h
+++ /dev/null
@@ -1,406 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- * * Redistributions of source code must retain the above copyright
- * notice, this list of conditions and the following disclaimer.
- * * Redistributions in binary form must reproduce the above copyright
- * notice, this list of conditions and the following disclaimer in the
- * documentation and/or other materials provided with the distribution.
- * * Neither the name of the NVIDIA CORPORATION nor the
- * names of its contributors may be used to endorse or promote products
- * derived from this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- ******************************************************************************/
-
-#pragma once
-
-/**
- * \file
- * Tile-loading abstraction for thread blocks
- */
-
-#include "../util/util.h"
-
-namespace cutlass {
-namespace gemm {
-
-
-/******************************************************************************
- * block_loader (CongruousCopy + dp1 specialization)
- ******************************************************************************/
-
-/**
- * \brief A three-phase data loading abstraction (prefetch, commit, and
- * advance) for iterating over ranges of block-wide matrix tiles.
- * (CongruousCopy + dp1 specialization)
- *
- * Each iteration sequence produces a KxL (height-by-width) block-wide tile of
- * value_t in shared memory. The layout of the shared block-wide tile is
- * a row-major (L-major) tiling of singleton "dp1" dp_vector_t items, where
- * dp_vector_t == value_t. Its dimensions are:
- * K = BlockDpVectorsK
- * L = BlockDpVectorsL
- *
- * The data is copied from a corresponding tile of global matrix data whose
- * layout of value_t is also L-major. This constitutes a CongruousCopy
- * between the L-major global tile and the L-major shared tile.
- *
- * NB: Because they are "dp1" singletons, the K-major orientation of
- * dp_vector_t in shared memory is irrelevant, and the L-major global and
- * shared tile layouts are perfectly congruous. As a result, we can increase
- * the granularity of data transfer via vectorization of loads and stores
- * without any intermediate {dis|re}assembly.
- *
- * NB: Consecutive threads within a block are mapped in L-major
- * fashion across a first-set of LDG-vectors of dp_vector_t (value_t) within
- * their global tile. Successive sets of LDG-vectors are then strip-mined
- * as necessary down the K-axis. These discontiguous LDG-vectors comprise the
- * thread's "slice" of the block-wide tile.
- */
-template <
- int BlockThreads, ///< Number of threads in each thread block (blockDim.x)
- int BlockDpVectorsK, ///< Extent of block-wide tile in dp_vector_t along the K-axis (height)
- int BlockDpVectorsL, ///< Extent of block-wide tile in dp_vector_t along the L-axis (width)
- typename value_t, ///< Input matrix value type
- int LeadingDimAlignBytes, ///< Byte alignment of input matrix leading dimension
- bool AllowRaggedTiles ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
->
-struct block_loader<
- BlockThreads,
- BlockDpVectorsK,
- BlockDpVectorsL,
- value_t,
- LeadingDimAlignBytes,
- AllowRaggedTiles,
- value_t, ///< Dot-product vector type along the K-axis (dp1 specialization)
- load_algorithm::CongruousCopy> ///< Algorithm for loading a shared tile of KxL matrix data (CongruousCopy specialization)
-{
- //-------------------------------------------------------------------------
- // Constants and types
- //-------------------------------------------------------------------------
-
- /// Dot-product vector type along the K-axis
- typedef value_t dp_vector_t;
-
- enum
- {
- /// Number of value_t in a dp_vector_t
- DpVectorItems = divide_assert::value,
-
- /// Number of dp_vector_t in a block-wide tile
- BlockDpVectors = BlockDpVectorsK * BlockDpVectorsL,
-
- /// Number of dp_vector_t in a thread-tile
- ThreadDpVectors = divide_assert::value,
- };
-
- /// Data movement type, coarsened by LeadingDimAlignBytes, capped by the
- /// smaller of either ThreadDpVectors or BlockDpVectorsL
- typedef io_vector<
- dp_vector_t,
- __NV_STD_MIN(ThreadDpVectors, BlockDpVectorsL),
- LeadingDimAlignBytes>
- ldg_vector_t;
-
- enum
- {
- /// Number of dp_vector_t per ldg_vector_t
- LdgVectorDpVectors = ldg_vector_t::VectorItems,
-
- /// Number of value_t per ldg_vector_t
- LdgVectorItems = LdgVectorDpVectors * DpVectorItems,
-
-
-
- /// Total number of ldg_vector_t within each block-wide tile
- BlockLdgVectors = divide_assert::value,
-
- /// Extent of the block-wide tile in ldg_vector_t along L-axis
- BlockLdgVectorsL = divide_assert::value,
-
- /// Extent of the block-wide tile in ldg_vector_t along K-axis
- BlockLdgVectorsK = BlockDpVectorsK,
-
-
-
- /// Number of ldg_vector_t within each thread-tile
- ThreadLdgVectors = divide_assert::value,
-
- /// Extent of the thread tile in ldg_vector_t along L-axis
- ThreadLdgVectorsL = __NV_STD_MAX(1, (BlockLdgVectorsL / BlockThreads)),
-
- /// Extent of the thread tile in ldg_vector_t along K-axis
- ThreadLdgVectorsK = divide_assert::value,
-
-
-
- /// Number of ldg_vector_t within each stripmine-tile
- StripmineLdgVectors = BlockThreads,
-
- /// Extent of the stripmine tile in ldg_vector_t along L-axis
- StripmineLdgVectorsL = __NV_STD_MIN(BlockLdgVectorsL, StripmineLdgVectors),
-
- /// Extent of the stripmine tile in ldg_vector_t along K-axis
- StripmineLdgVectorsK = divide_assert::value,
-
-
-
- /// Alignment in dp_vector_t along L needed for committing prefetch
- AlignmentDpVectorsL = LdgVectorDpVectors,
- };
-
- /// Predicate bit vector
- typedef uint64_t predicate_mask_t;
-
-
- //-------------------------------------------------------------------------
- // Assert assumptions
- //-------------------------------------------------------------------------
-
- static_assert(
- (ThreadLdgVectors <= sizeof(predicate_mask_t) * 8),
- "Predicate mask type does not contain enough bits for encoding load predicates");
-
-
- //-------------------------------------------------------------------------
- // Members
- //-------------------------------------------------------------------------
-
- /// Input pointer to matrix in ldg_vector_t
- ldg_vector_t *d_matrix_ldgvecs;
-
- /// Extent of the input matrix in ldg_vector_t along the L-axis
- int matrix_ldgvecs_l;
-
- /// Thread block's ending ldg_vector_t coordinate (k) within the input matrix (one-past)
- int block_end_ldgvec_k;
-
- /// Predicate bits for guarding ldg_vector_t loads within "whole-k" block-wide tiles
- predicate_mask_t guard;
-
- /// Predicate bits for guarding ldg_vector_t loads within the final block-wide "residue" tile
- predicate_mask_t residue_guard;
-
- /// Iteration span in "whole-k" block-wide tiles
- int wholek_tiles_remaining;
-
- /// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the K-axis
- int matrix_ldgvec_stride_k;
-
- /// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the L-axis
- int matrix_ldgvec_stride_l;
-
- /// ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
- int2 block_thread_ldgvec_coords;
-
- /// Thread-wide tile of prefetch data
- ldg_vector_t thread_tile[ThreadLdgVectorsK][ThreadLdgVectorsL];
-
-
- //-------------------------------------------------------------------------
- // Constructor API
- //-------------------------------------------------------------------------
-
- /// Constructor
- inline __device__
- block_loader(
- value_t *d_matrix_items, ///< Input pointer to matrix in value_t
- int matrix_items_l, ///< Extent of the input matrix in value_t along the L-axis
- int matrix_items_stride_k, ///< Distance in value_t within pitched-linear memory between successive coordinates along the K-axis
- int matrix_items_stride_l, ///< Distance in value_t within pitched-linear memory between successive coordinates along the L-axis
- int2 matrix_block_item_coords, ///< value_t coordinates (l, k) of first block-wide tile within the input matrix
- int block_end_item_k) ///< Thread block's ending coordinate (k) within the input matrix (one-past)
- :
- block_end_ldgvec_k(block_end_item_k),
- guard(0),
- residue_guard(0)
- {
- matrix_ldgvecs_l = matrix_items_l / LdgVectorItems;
- matrix_ldgvec_stride_k = matrix_items_stride_k / LdgVectorItems,
- matrix_ldgvec_stride_l = matrix_items_stride_l;
-
- // ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
- block_thread_ldgvec_coords = make_int2(
- threadIdx.x % BlockLdgVectorsL, // l-coordinate
- threadIdx.x / BlockLdgVectorsL); // k-coordinate
-
- // ldg_vector_t coordinates (l, k) of first block-wide tile within the input matrix
- int2 matrix_block_ldgvec_coords = make_int2(
- matrix_block_item_coords.x / LdgVectorItems, // l-coordinate
- matrix_block_item_coords.y); // k-coordinate
-
- // Iteration span in ldg_vector_t
- int span_ldgvec_k = (block_end_item_k - matrix_block_item_coords.y);
-
-
-
- // ldg_vector_t coordinates (l, k) of first thread-tile tile within the input matrix
- int2 matrix_thread_ldgvec_coords = make_int2(
- block_thread_ldgvec_coords.x + matrix_block_ldgvec_coords.x,
- block_thread_ldgvec_coords.y + matrix_block_ldgvec_coords.y);
-
- // Iteration range in "whole-k" block-wide tiles
- wholek_tiles_remaining = span_ldgvec_k / BlockLdgVectorsK;
-
- // Extent of final residue-tile in ldg_vector_t along K-axis
- int residue_ldgvecs_k = span_ldgvec_k % BlockLdgVectorsK;
-
- // Initialize I/O predicates
- if (AllowRaggedTiles)
- {
- // Outer thread-tile ldg_vector_t iteration (K-axis)
- #pragma unroll
- for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
- {
- int block_ldgvec_k = block_thread_ldgvec_coords.y + (thread_ldgvec_k * StripmineLdgVectorsK);
-
- // Whether block_ldgvec_coords.y is valid in the final residue tile
- predicate_mask_t valid_k = (block_ldgvec_k < residue_ldgvecs_k);
-
- // Inner thread-tile ldg_vector_t iteration (L-axis)
- #pragma unroll
- for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
- {
- int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
-
- // Whether block_ldgvec_coords.x is valid any block-wide tile
- predicate_mask_t valid_l = (matrix_block_ldgvec_coords.x + block_ldgvec_l < matrix_ldgvecs_l);
-
- // Linear index of ldg_vector_t load
- int ldgvec_idx = thread_ldgvec_l + (thread_ldgvec_k * ThreadLdgVectorsL);
-
- // Set predicate guard bits
- guard |= (valid_l << ldgvec_idx);
- residue_guard |= ((valid_l & valid_k) << ldgvec_idx);
- }
- }
-
- // Promote residue-guard to primary-guard if no full tiles remain
- if (!wholek_tiles_remaining)
- {
- guard = residue_guard;
- }
- }
-
- // Update the input pointer to be matrix_thread_ldgvec_coords
- this->d_matrix_ldgvecs =
- reinterpret_cast<ldg_vector_t*>(d_matrix_items) +
- (matrix_thread_ldgvec_coords.y * matrix_ldgvec_stride_k) +
- (matrix_thread_ldgvec_coords.x * matrix_ldgvec_stride_l);
- }
-
-
- //-------------------------------------------------------------------------
- // Loader API
- //-------------------------------------------------------------------------
-
- /**
- * Request the current block-wide tile
- */
- inline __device__
- void request()
- {
- // Outer thread-tile ldg_vector_t iteration (K-axis)
- #pragma unroll
- for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
- {
- // Inner thread-tile ldg_vector_t iteration (L-axis)
- #pragma unroll
- for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
- {
- // Linear index of ldg_vector_t load
- int ldgvec_idx = (thread_ldgvec_k * ThreadLdgVectorsL) + thread_ldgvec_l;
-
- // Unpack predicate guard
- predicate_mask_t valid = ((guard >> ldgvec_idx) & 1);
-
- if (!AllowRaggedTiles || valid)
- {
- // Perform load
- thread_tile[thread_ldgvec_k][thread_ldgvec_l].load(
- d_matrix_ldgvecs +
- (thread_ldgvec_k * StripmineLdgVectorsK * matrix_ldgvec_stride_k) +
- (thread_ldgvec_l * StripmineLdgVectorsL * matrix_ldgvec_stride_l));
- }
- else
- {
- // Zero-initialize
- #pragma unroll
- for (int dpvec = 0; dpvec < LdgVectorDpVectors; ++dpvec)
- thread_tile[thread_ldgvec_k][thread_ldgvec_l].buff[dpvec] = 0;
- }
- }
- }
- }
-
-
- /**
- * Advance the loader to the next block-wide tile in the K-axis
- */
- inline __device__
- void next()
- {
- d_matrix_ldgvecs += (matrix_ldgvec_stride_k * BlockLdgVectorsK);
-
- if (AllowRaggedTiles)
- {
- --wholek_tiles_remaining;
-
- // Promote residue-guard to primary-guard if no full tiles remain
- if (!wholek_tiles_remaining)
- {
- guard = residue_guard;
- }
- }
- }
-
-
- /**
- * Commit the previously-requested block-wide tile to shared memory
- *
- * NB: To facilitate padding for avoiding shared memory bank conflicts, we
- * allow the row stride SmemDpVectorsL to be arbitrarily bigger than the
- * tile width BlockDpVectorsL.
- */
- template <int SmemDpVectorsL>
- inline __device__
- void commit(
- dp_vector_t (&scratch_tile)[BlockDpVectorsK][SmemDpVectorsL])
- {
- static_assert(SmemDpVectorsL >= BlockDpVectorsL, "Row stride must be >= tile width.");
-
- // Outer thread-tile ldg_vector_t iteration (K-axis)
- #pragma unroll
- for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
- {
- int block_ldgvec_k = block_thread_ldgvec_coords.y + (thread_ldgvec_k * StripmineLdgVectorsK);
-
- // Inner thread-tile ldg_vector_t iteration (L-axis)
- #pragma unroll
- for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
- {
- int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
-
- thread_tile[thread_ldgvec_k][thread_ldgvec_l].store(
- &scratch_tile[block_ldgvec_k][block_ldgvec_l * LdgVectorDpVectors]);
- }
- }
- }
-};
-
-
-} // namespace gemm
-} // namespace cutlass
diff --git a/cutlass/gemm/block_loader_congruous_idp4.h b/cutlass/gemm/block_loader_congruous_idp4.h
deleted file mode 100644
index 686da1db..00000000
--- a/cutlass/gemm/block_loader_congruous_idp4.h
+++ /dev/null
@@ -1,544 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- * * Redistributions of source code must retain the above copyright
- * notice, this list of conditions and the following disclaimer.
- * * Redistributions in binary form must reproduce the above copyright
- * notice, this list of conditions and the following disclaimer in the
- * documentation and/or other materials provided with the distribution.
- * * Neither the name of the NVIDIA CORPORATION nor the
- * names of its contributors may be used to endorse or promote products
- * derived from this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- ******************************************************************************/
-
-#pragma once
-
-/**
- * \file
- * Tile-loading abstraction for thread blocks
- */
-
-#include "../util/util.h"
-
-namespace cutlass {
-namespace gemm {
-
-
-/******************************************************************************
- * block_loader (CongruousCopy + idp4 specialization)
- ******************************************************************************/
-
-/**
- * \brief A three-phase data loading abstraction (prefetch, commit, and
- * advance) for iterating over ranges of block-wide matrix tiles.
- * (CongruousCopy + idp4 specialization)
- *
- * Each iteration sequence produces a KxL (height-by-width) block-wide tile of
- * value_t in shared memory. The layout of the shared block-wide tile is
- * a row-major (L-major) tiling of int32_t dp_vector_t, which are themselves
- * column-major (K-major) vectors of int8_t value_t. Its dimensions are:
- * K = BlockDpVectorsK * (sizeof(dp_vector_t) / sizeof(value_t)
- * L = BlockDpVectorsL
- *
- * The data is copied from a corresponding tile of global matrix data whose
- * layout of value_t is also L-major. This constitutes a CongruousCopy between
- * the L-major global tile and the L-major shared tile.
- *
- * NB: The K-major value_t in shared dp_vector_t are imperfectly congruous
- * with the L-major value_t in global memory. As a result, the granularity
- * of data transfer is a "dp-square" of (DpVectorItems * DpVectorItems) values
- * that must be transposed from L-oriented dp_vector_t to K-oriented
- * dp_vector_t prior to commitment.
- *
- * NB: Consecutive threads within a block are mapped in L-major
- * fashion across a first-set of squares within their global tile. Successive
- * sets of squares are then strip-mined as necessary down the K-axis. These
- * discontiguous squares comprise the thread's "slice" of the block-wide tile.
- */
-template <
- int BlockThreads, ///< Number of threads in each thread block (blockDim.x)
- int _BlockDpVectorsK, ///< Extent of block-wide tile in dp_vector_t along the K-axis (height)
- int _BlockDpVectorsL, ///< Extent of block-wide tile in dp_vector_t along the L-axis (width)
- int LeadingDimAlignBytes, ///< Byte alignment of input matrix leading dimension
- bool AllowRaggedTiles ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
->
-struct block_loader<
- BlockThreads,
- _BlockDpVectorsK,
- _BlockDpVectorsL,
- int8_t, ///< Input matrix value type (idp4 specialization)
- LeadingDimAlignBytes,
- AllowRaggedTiles,
- int32_t, ///< Dot-product vector type along the K-axis (idp4 specialization)
- load_algorithm::CongruousCopy> ///< Algorithm for loading a shared tile of KxL matrix data (CrosswiseCopy specialization)
-{
- //-------------------------------------------------------------------------
- // Constants and types
- //-------------------------------------------------------------------------
-
- /// Input matrix value type
- typedef int8_t value_t;
-
- /// Dot-product vector type along the K-axis
- typedef int32_t dp_vector_t;
-
- enum
- {
- /// Number of value_t in a dp_vector_t
- DpVectorItems = divide_assert::value,
-
- /// Number of dp_vector_t in a block-wide tile
- BlockDpVectors = _BlockDpVectorsK * _BlockDpVectorsL,
-
- /// Number of dp_vector_t in a thread-tile
- ThreadDpVectors = divide_assert::value,
-
- /// Number of dp_vector_t in a dp-square
- SquareDpVectors = DpVectorItems,
-
- /// Number of dp-square tiles in a thread-tile
- ThreadSquares = divide_assert::value,
-
- /// Extent of block-wide tile in transposed dp_vector_t along the K-axis (height)
- BlockTransDpVectorsK = _BlockDpVectorsK * DpVectorItems,
-
- /// Extent of block-wide tile in transposed dp_vector_t along the L-axis (height)
- BlockTransDpVectorsL = divide_assert<_BlockDpVectorsL, DpVectorItems>::value,
-
-
- };
-
- /// Load-from-global data movement type, coarsened by LeadingDimAlignBytes, capped by the
- /// smaller of either ThreadSquares or BlockTransDpVectorsL
- typedef io_vector<
- dp_vector_t,
- __NV_STD_MIN(ThreadSquares, BlockTransDpVectorsL),
- LeadingDimAlignBytes>
- ldg_vector_t;
-
- /// Store-to-shared data movement type equivalent to a dp-square
- typedef io_vector<
- dp_vector_t,
- SquareDpVectors>
- sts_vector_t;
-
- enum
- {
- /// Number of dp_vector_t per ldg_vector_t
- LdgVectorDpVectors = ldg_vector_t::VectorItems,
-
- /// Number of value_t per ldg_vector_t
- LdgVectorItems = LdgVectorDpVectors * DpVectorItems,
-
-
-
- /// Total number of ldg_vector_t within each block-wide tile
- BlockLdgVectors = divide_assert::value,
-
- /// Extent of the block-wide tile in ldg_vector_t along L-axis
- BlockLdgVectorsL = divide_assert::value,
-
- /// Extent of the block-wide tile in ldg_vector_t along K-axis
- BlockLdgVectorsK = BlockTransDpVectorsK,
-
-
-
- /// Number of ldg_vector_t within each thread-tile
- ThreadLdgVectors = divide_assert::value,
-
- /// Extent of the thread tile in ldg_vector_t along L-axis
- ThreadLdgVectorsL = __NV_STD_MAX(1, (BlockLdgVectorsL / BlockThreads)),
-
- /// Extent of the thread tile in ldg_vector_t along K-axis
- ThreadLdgVectorsK = divide_assert::value,
-
- /// Extent of the thread tile in dp-square tiles along K-axis
- ThreadSquaresK = divide_assert::value,
-
-
-
- /// Number of ldg_vector_t within each stripmine-tile
- StripmineLdgVectors = BlockThreads * SquareDpVectors,
-
- /// Extent of the stripmine tile in ldg_vector_t along L-axis
- StripmineLdgVectorsL = __NV_STD_MIN(BlockLdgVectorsL, BlockThreads),
-
- /// Extent of the stripmine tile in ldg_vector_t along K-axis
- StripmineLdgVectorsK = divide_assert::value,
-
- /// Extent of the stripmine tile in dp-square tiles along K-axis
- StripmineSquaresK = divide_assert::value,
-
-
-
- /// Alignment in dp_vector_t along L needed for committing prefetch
- AlignmentDpVectorsL = LdgVectorDpVectors,
- };
-
- /// Predicate mask type
- typedef uint32_t predicate_mask_t;
-
-
- //-------------------------------------------------------------------------
- // Assert assumptions
- //-------------------------------------------------------------------------
-
- static_assert((LeadingDimAlignBytes >= 4) && (LeadingDimAlignBytes % 4 == 0),
- "Alignment for matrix operands to IGEMM must be a multiple of 4 bytes.");
-
- static_assert(
- (ThreadLdgVectors <= sizeof(predicate_mask_t) * 8),
- "Predicate mask type does not contain enough bits for encoding load predicates");
-
-
- //-------------------------------------------------------------------------
- // Members
- //-------------------------------------------------------------------------
-
- /// Input pointer to matrix in ldg_vector_t
- ldg_vector_t *d_matrix_ldgvecs;
-
- /// Extent of the input matrix in ldg_vector_t along the L-axis
- int matrix_ldgvecs_l;
-
- /// Thread block's ending ldg_vector_t coordinate (k) within the input matrix (one-past)
- int block_end_ldgvec_k;
-
- /// Predicate bits for guarding ldg_vector_t loads within "whole-k" block-wide tiles
- predicate_mask_t guard;
-
- /// Predicate bits for guarding ldg_vector_t loads within the final block-wide "residue" tile
- predicate_mask_t residue_guard;
-
- /// Iteration span in "whole-k" block-wide tiles
- int wholek_tiles_remaining;
-
- /// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the K-axis
- int matrix_ldgvec_stride_k;
-
- /// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the L-axis
- int matrix_ldgvec_stride_l;
-
- /// ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
- int2 block_thread_ldgvec_coords;
-
- /// Thread-wide tile of prefetch data
- ldg_vector_t thread_tile[ThreadSquaresK][SquareDpVectors][ThreadLdgVectorsL];
-
-
-
- //-------------------------------------------------------------------------
- // Utility methods
- //-------------------------------------------------------------------------
-
-
- /**
- * \brief Byte-permute. Pick four arbitrary bytes from two 32-bit registers, and reassemble them into a 32-bit destination register. For SM2.0 or later.
- *
- * \par
- * The bytes in the two source registers \p a and \p b are numbered from 0 to 7:
- * {\p b, \p a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}. For each of the four bytes
- * {b3, b2, b1, b0} selected in the return value, a 4-bit selector is defined within
- * the four lower "nibbles" of \p index: {\p index } = {n7, n6, n5, n4, n3, n2, n1, n0}
- *
- * \par Snippet
- * The code snippet below illustrates byte-permute.
- * \par
- * \code
- * #include
- *
- * __global__ void ExampleKernel(...)
- * {
- * int a = 0x03020100;
- * int b = 0x07060504;
- * int index = 0x00007531;
- *
- * int selected = prmt(a, b, index); // 0x07050301
- *
- * \endcode
- *
- */
- inline __device__
- int32_t prmt(int32_t a, int32_t b, unsigned int index)
- {
- int ret;
- asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ret) : "r"(a), "r"(b), "r"(index));
- return ret;
- }
-
-
- /**
- * Convert a "dp-square" from L-major to K-major
- */
- inline __device__
- void transpose_dp_square(dp_vector_t (&dp_square)[SquareDpVectors])
- {
- // Transpose dp_vector_t squares
- int32_t y = prmt(dp_square[0], dp_square[1], 0x00007362);
- int32_t w = prmt(dp_square[2], dp_square[3], 0x00007362);
- int32_t x = prmt(dp_square[0], dp_square[1], 0x00005140);
- int32_t z = prmt(dp_square[2], dp_square[3], 0x00005140);
-
- dp_square[0] = prmt(x, z, 0x00005410);
- dp_square[1] = prmt(x, z, 0x00007632);
- dp_square[2] = prmt(y, w, 0x00005410);
- dp_square[3] = prmt(y, w, 0x00007632);
- }
-
- //-------------------------------------------------------------------------
- // Constructor API
- //-------------------------------------------------------------------------
-
- /// Constructor
- inline __device__
- block_loader(
- value_t *d_matrix_items, ///< Input pointer to matrix in value_t
- int matrix_items_l, ///< Extent of the input matrix in value_t along the L-axis
- int matrix_items_stride_k, ///< Distance in value_t within pitched-linear memory between successive coordinates along the K-axis
- int matrix_items_stride_l, ///< Distance in value_t within pitched-linear memory between successive coordinates along the L-axis
- int2 matrix_block_item_coords, ///< value_t coordinates (l, k) of first block-wide tile within the input matrix
- int block_end_item_k) ///< Thread block's ending coordinate (k) within the input matrix (one-past)
- :
- block_end_ldgvec_k(block_end_item_k),
- guard(0),
- residue_guard(0)
- {
- matrix_ldgvecs_l = matrix_items_l / LdgVectorItems;
- matrix_ldgvec_stride_k = matrix_items_stride_k / LdgVectorItems,
- matrix_ldgvec_stride_l = matrix_items_stride_l;
-
- // ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
- block_thread_ldgvec_coords = make_int2(
- threadIdx.x % BlockLdgVectorsL, // l-coordinate
- (threadIdx.x / BlockLdgVectorsL) * SquareDpVectors); // k-coordinate
-
- // ldg_vector_t coordinates (l, k) of first block-wide tile within the input matrix
- int2 matrix_block_ldgvec_coords = make_int2(
- matrix_block_item_coords.x / LdgVectorItems, // l-coordinate
- matrix_block_item_coords.y); // k-coordinate
-
- // Iteration span in ldg_vector_t
- int span_ldgvec_k = (block_end_item_k - matrix_block_item_coords.y);
-
-
-
- // ldg_vector_t coordinates (l, k) of first thread-tile tile within the input matrix
- int2 matrix_thread_ldgvec_coords = make_int2(
- block_thread_ldgvec_coords.x + matrix_block_ldgvec_coords.x,
- block_thread_ldgvec_coords.y + matrix_block_ldgvec_coords.y);
-
- // Iteration range in "whole-k" block-wide tiles
- wholek_tiles_remaining = span_ldgvec_k / BlockLdgVectorsK;
-
- // Extent of final residue-tile in ldg_vector_t along K-axis
- int residue_ldgvecs_k = span_ldgvec_k % BlockLdgVectorsK;
-
- // Initialize I/O predicates
- if (AllowRaggedTiles)
- {
- // Iterate through rows of squares in thread tile
- #pragma unroll
- for (int thread_square_k = 0; thread_square_k < ThreadSquaresK; ++thread_square_k)
- {
- // Iterate through rows of dp_vector_t in each square
- #pragma unroll
- for (int square_dpvec = 0; square_dpvec < SquareDpVectors; ++square_dpvec)
- {
- // ldg_vector_t K-coordinate in block-wide tile (K-axis strip-mining of ldg_vector_t within block-tile)
- int block_ldgvec_k =
- block_thread_ldgvec_coords.y +
- (thread_square_k * StripmineLdgVectorsK) +
- square_dpvec;
-
- // Whether block_ldgvec_coords.y is valid in the final residue tile
- predicate_mask_t valid_k = (block_ldgvec_k < residue_ldgvecs_k);
-
- // L-axis strip-mining of block-tile
- #pragma unroll
- for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
- {
- // ldg_vector_t L-coordinate in block-wide tile (L-axis strip-mining of ldg_vector_t within block-tile)
- int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
-
- // Whether block_ldgvec_coords.x is valid any block-wide tile
- predicate_mask_t valid_l = (matrix_block_ldgvec_coords.x + block_ldgvec_l < matrix_ldgvecs_l);
-
- // Linear index of ldg_vector_t load
- int ldgvec_idx =
- (thread_square_k * SquareDpVectors * ThreadLdgVectorsL) +
- (square_dpvec * ThreadLdgVectorsL) +
- thread_ldgvec_l;
-
- // Set predicate guard bits
- guard |= (valid_l << ldgvec_idx);
- residue_guard |= ((valid_l & valid_k) << ldgvec_idx);
- }
- }
- }
-
- // Promote residue-guard to primary-guard if no full tiles remain
- if (!wholek_tiles_remaining)
- {
- guard = residue_guard;
- }
- }
-
- // Update the input pointer to be matrix_thread_ldgvec_coords
- this->d_matrix_ldgvecs =
- reinterpret_cast<ldg_vector_t*>(d_matrix_items) +
- (matrix_thread_ldgvec_coords.y * matrix_ldgvec_stride_k) +
- (matrix_thread_ldgvec_coords.x * matrix_ldgvec_stride_l);
- }
-
-
- //-------------------------------------------------------------------------
- // Loader API
- //-------------------------------------------------------------------------
-
- /**
- * Request the current block-wide tile
- */
- inline __device__
- void request()
- {
- // Each thread iterates through the ldg_vector_t in its thread tile
-
- // Iterate through rows of squares in thread tile
- #pragma unroll
- for (int thread_square_k = 0; thread_square_k < ThreadSquaresK; ++thread_square_k)
- {
- // Iterate through rows of dp_vector_t in each square
- #pragma unroll
- for (int square_dpvec = 0; square_dpvec < SquareDpVectors; ++square_dpvec)
- {
- // Iterate through ldg_vector_t in each row
- #pragma unroll
- for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
- {
- // Linear index of ldg_vector_t load
- int ldgvec_idx =
- (thread_square_k * SquareDpVectors * ThreadLdgVectorsL) +
- (square_dpvec * ThreadLdgVectorsL) +
- thread_ldgvec_l;
-
- // Unpack predicate guard
- predicate_mask_t valid = ((guard >> ldgvec_idx) & 1);
-
- if (!AllowRaggedTiles || valid)
- {
- // Perform load
- thread_tile[thread_square_k][square_dpvec][thread_ldgvec_l].load(
- d_matrix_ldgvecs +
- (((thread_square_k * StripmineLdgVectorsK) + square_dpvec) * matrix_ldgvec_stride_k) +
- (thread_ldgvec_l * StripmineLdgVectorsL * matrix_ldgvec_stride_l));
- }
- else
- {
- // Zero-initialize
- #pragma unroll
- for (int dpvec = 0; dpvec < LdgVectorDpVectors; ++dpvec)
- thread_tile[thread_square_k][square_dpvec][thread_ldgvec_l].buff[dpvec] = 0;
- }
- }
- }
- }
- }
-
-
- /**
- * Advance the loader to the next block-wide tile in the K-axis
- */
- inline __device__
- void next()
- {
- d_matrix_ldgvecs += (matrix_ldgvec_stride_k * BlockLdgVectorsK);
-
- if (AllowRaggedTiles)
- {
- --wholek_tiles_remaining;
-
- // Promote residue-guard to primary-guard if no full tiles remain
- if (!wholek_tiles_remaining)
- {
- guard = residue_guard;
- }
- }
- }
-
-
- /**
- * Commit the previously-requested block-wide tile to shared memory
- *
- * NB: To facilitate padding for avoiding shared memory bank conflicts, we
- * allow the row stride SmemDpVectorsL to be arbitrarily bigger than the
- * tile width BlockDpVectorsL.
- */
- template <int SmemDpVectorsL>
- inline __device__
- void commit(
- dp_vector_t (&scratch_tile)[_BlockDpVectorsK][SmemDpVectorsL])
- {
- static_assert(SmemDpVectorsL >= _BlockDpVectorsL, "Row stride must be >= tile width.");
-
- // Square K-coordinate of thread tile in block-wide tile
- int block_thread_square_k = block_thread_ldgvec_coords.y / SquareDpVectors;
-
- // Iterate through rows of squares in thread tile
- #pragma unroll
- for (int thread_square_k = 0; thread_square_k < ThreadSquaresK; ++thread_square_k)
- {
- // Square K-coordinate in block-wide tile (K-axis strip-mining of squares within block-tile)
- int block_square_k = block_thread_square_k + (thread_square_k * StripmineSquaresK);
-
- // Iterate through ldg_vector_t in each row
- #pragma unroll
- for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
- {
- // ldg_vector_t L-coordinate in block-wide tile (L-axis strip-mining of ldg_vector_t within block-tile)
- int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
-
- // Iterate through squares in each ldg_vector_t
- #pragma unroll
- for (int ldgvec_dpvec_l = 0; ldgvec_dpvec_l < LdgVectorDpVectors; ++ldgvec_dpvec_l)
- {
- // Square L-coordinate in block-wide tile (L-axis raking of square-slices within ldg_vector_t)
- int block_square_l = (block_ldgvec_l * LdgVectorDpVectors) + ldgvec_dpvec_l;
-
- // Assemble square of L-major dp_vector_t from stack of slices
- sts_vector_t square;
-
- // Iterate through rows of dp_vector_t in each square
- #pragma unroll
- for (int square_dpvec = 0; square_dpvec < SquareDpVectors; ++square_dpvec)
- {
- square.buff[square_dpvec] = thread_tile[thread_square_k][square_dpvec][thread_ldgvec_l].buff[ldgvec_dpvec_l];
- }
-
- // Un-transpose square from L-major to K-major
- transpose_dp_square(square.buff);
-
- // Store dp-square
- square.store(&scratch_tile[block_square_k][block_square_l * SquareDpVectors]);
- }
- }
- }
- }
-};
-
-
-} // namespace gemm
-} // namespace cutlass
diff --git a/cutlass/gemm/block_loader_crosswise.h b/cutlass/gemm/block_loader_crosswise.h
deleted file mode 100644
index 13660e67..00000000
--- a/cutlass/gemm/block_loader_crosswise.h
+++ /dev/null
@@ -1,411 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- * * Redistributions of source code must retain the above copyright
- * notice, this list of conditions and the following disclaimer.
- * * Redistributions in binary form must reproduce the above copyright
- * notice, this list of conditions and the following disclaimer in the
- * documentation and/or other materials provided with the distribution.
- * * Neither the name of the NVIDIA CORPORATION nor the
- * names of its contributors may be used to endorse or promote products
- * derived from this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- ******************************************************************************/
-
-#pragma once
-
-/**
- * \file
- * Tile-loading abstraction for thread blocks
- */
-
-#include "../util/util.h"
-
-namespace cutlass {
-namespace gemm {
-
-
-/******************************************************************************
- * block_loader (CrosswiseCopy specialization)
- ******************************************************************************/
-
-/**
- * \brief A three-phase data loading abstraction (prefetch, commit, and
- * advance) for iterating over ranges of block-wide matrix tiles.
- * (CrosswiseCopy specialization)
- *
- * Each iteration sequence produces a KxL (height-by-width) block-wide tile of
- * value_t in shared memory. The layout of the shared block-wide tile is
- * a row-major (L-major) tiling of dp_vector_t items, which are themselves
- * column-major (K-major) vectors of value_t. Its dimensions are:
- * K = BlockDpVectorsK * (sizeof(dp_vector_t) / sizeof(value_t)
- * L = BlockDpVectorsL
- *
- * The data is copied from a corresponding tile of global matrix data whose
- * layout of value_t is K-major. This constitutes a CrosswiseCopy between
- * the K-major global tile and the L-major shared tile.
- *
- * NB: The orientation of dp_vector_t components in shared memory is congruous
- * with the global matrix data, so we can use dp_vector_t as the minimum
- * granularity of data transfer without any intermediate {dis|re}assembly
- * of its value_t components. However, the global and shared memory layouts
- * of dp_vector_t items are cross-wise with respect to each other, so any
- * further LDG-vectorization of dp_vector_t data requires intermediate
- * disassembly into dp_vector_t components to be stored individually into
- * the shared tile.
- *
- * NB: Consecutive threads within a block are mapped in K-major
- * fashion down a first set of LDG-vectors of dp_vector_t within their global
- * tile. Successive sets of LDG-vectors are then strip-mined as necessary
- * across the L-axis. These discontiguous LDG-vectors comprise the thread's
- * "slice" of the block-wide tile.
- */
-template <
- int BlockThreads, ///< Number of threads in each thread block (blockDim.x)
- int BlockDpVectorsK, ///< Extent of block-wide tile in dp_vector_t along the K-axis (height)
- int BlockDpVectorsL, ///< Extent of block-wide tile in dp_vector_t along the L-axis (width)
- typename value_t, ///< Input matrix value type
- int LeadingDimAlignBytes, ///< Byte alignment of input matrix leading dimension
- bool AllowRaggedTiles, ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
- typename dp_vector_t> ///< Dot-product vector type along the K-axis
-struct block_loader<
- BlockThreads,
- BlockDpVectorsK,
- BlockDpVectorsL,
- value_t,
- LeadingDimAlignBytes,
- AllowRaggedTiles,
- dp_vector_t,
- load_algorithm::CrosswiseCopy> ///< Algorithm for loading a shared tile of KxL matrix data (CrosswiseCopy specialization)
-{
- //-------------------------------------------------------------------------
- // Constants and types
- //-------------------------------------------------------------------------
-
- enum
- {
- /// Number of value_t in a dp_vector_t
- DpVectorItems = divide_assert::value,
-
- /// Number of dp_vector_t in a block-wide tile
- BlockDpVectors = BlockDpVectorsK * BlockDpVectorsL,
-
- /// Number of dp_vector_t in a thread-tile
- ThreadDpVectors = divide_assert::value,
- };
-
- /// Data movement type, coarsened by LeadingDimAlignBytes, capped by the
- /// smaller of either ThreadDpVectors or BlockDpVectorsK
- typedef io_vector<
- dp_vector_t,
- __NV_STD_MIN(ThreadDpVectors, BlockDpVectorsK),
- LeadingDimAlignBytes>
- ldg_vector_t;
-
- enum
- {
- /// Number of dp_vector_t per ldg_vector_t
- LdgVectorDpVectors = ldg_vector_t::VectorItems,
-
- /// Number of value_t per ldg_vector_t
- LdgVectorItems = LdgVectorDpVectors * DpVectorItems,
-
-
-
- /// Total number of ldg_vector_t within each block-wide tile
- BlockLdgVectors = divide_assert::value,
-
- /// Extent of the block-wide tile in ldg_vector_t along K-axis
- BlockLdgVectorsK = divide_assert::value,
-
- /// Extent of the block-wide tile in ldg_vector_t along L-axis
- BlockLdgVectorsL = BlockDpVectorsL,
-
-
-
- /// Number of ldg_vector_t within each thread-tile
- ThreadLdgVectors = divide_assert<BlockLdgVectors, BlockThreads>::value,
-
- /// Extent of the thread tile in ldg_vector_t along K-axis
- ThreadLdgVectorsK = __NV_STD_MAX(1, (BlockLdgVectorsK / BlockThreads)),
-
- /// Extent of the thread tile in ldg_vector_t along L-axis
- ThreadLdgVectorsL = divide_assert<ThreadLdgVectors, ThreadLdgVectorsK>::value,
-
-
-
- /// Number of ldg_vector_t within each stripmine-tile
- StripmineLdgVectors = BlockThreads,
-
- /// Extent of the stripmine tile in ldg_vector_t along K-axis
- StripmineLdgVectorsK = __NV_STD_MIN(BlockLdgVectorsK, StripmineLdgVectors),
-
- /// Extent of the stripmine tile in ldg_vector_t along L-axis
- StripmineLdgVectorsL = divide_assert<StripmineLdgVectors, StripmineLdgVectorsK>::value,
-
-
-
- /// Alignment in dp_vector_t along L needed for committing prefetch
- AlignmentDpVectorsL = 1,
- };
-
- /// Predicate bit vector
- typedef uint64_t predicate_mask_t;
-
-
- //-------------------------------------------------------------------------
- // Assert assumptions
- //-------------------------------------------------------------------------
-
- static_assert(
- (ThreadLdgVectors <= sizeof(predicate_mask_t) * 8),
- "Predicate mask type does not contain enough bits for encoding load predicates");
-
-
- //-------------------------------------------------------------------------
- // Members
- //-------------------------------------------------------------------------
-
- /// Input pointer to matrix in ldg_vector_t
- ldg_vector_t *d_matrix_ldgvecs;
-
- /// Extent of the input matrix in ldg_vector_t along the L-axis
- int matrix_ldgvecs_l;
-
- /// Thread block's ending ldg_vector_t coordinate (k) within the input matrix (one-past)
- int block_end_ldgvec_k;
-
- /// Predicate bits for guarding ldg_vector_t loads within "whole-k" block-wide tiles
- predicate_mask_t guard;
-
- /// Predicate bits for guarding ldg_vector_t loads within the final block-wide "residue" tile
- predicate_mask_t residue_guard;
-
- /// Iteration span in "whole-k" block-wide tiles
- int wholek_tiles_remaining;
-
- /// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the K-axis
- int matrix_ldgvec_stride_k;
-
- /// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the L-axis
- int matrix_ldgvec_stride_l;
-
- /// ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
- int2 block_thread_ldgvec_coords;
-
- /// Thread-wide tile of prefetch data
- ldg_vector_t thread_tile[ThreadLdgVectorsK][ThreadLdgVectorsL];
-
-
- //-------------------------------------------------------------------------
- // Constructor API
- //-------------------------------------------------------------------------
-
- /// Constructor
- inline __device__
- block_loader(
- value_t *d_matrix_items, ///< Input pointer to matrix in value_t
- int matrix_items_l, ///< Extent of the input matrix in value_t along the L-axis
- int matrix_items_stride_k, ///< Distance in value_t within pitched-linear memory between successive coordinates along the K-axis
- int matrix_items_stride_l, ///< Distance in value_t within pitched-linear memory between successive coordinates along the L-axis
- int2 matrix_block_item_coords, ///< value_t coordinates (l, k) of first block-wide tile within the input matrix
- int block_end_item_k) ///< Thread block's ending coordinate (k) within the input matrix (one-past)
- :
- block_end_ldgvec_k(block_end_item_k),
- guard(0),
- residue_guard(0)
- {
- matrix_ldgvecs_l = matrix_items_l;
- matrix_ldgvec_stride_k = matrix_items_stride_k;
- matrix_ldgvec_stride_l = (matrix_items_stride_l / LdgVectorItems);
-
- // ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
- block_thread_ldgvec_coords = make_int2(
- (threadIdx.x / BlockLdgVectorsK), // l-coordinate
- (threadIdx.x % BlockLdgVectorsK)); // k-coordinate
-
- // ldg_vector_t coordinates (l, k) of first block-wide tile within the input matrix
- int2 matrix_block_ldgvec_coords = make_int2(
- matrix_block_item_coords.x, // l-coordinate
- matrix_block_item_coords.y / LdgVectorItems); // k-coordinate
-
- // Iteration span in ldg_vector_t
- int span_ldgvec_k = (block_end_item_k - matrix_block_item_coords.y) / LdgVectorItems;
-
-
-
- // ldg_vector_t coordinates (l, k) of first thread-tile within the input matrix
- int2 matrix_thread_ldgvec_coords = make_int2(
- block_thread_ldgvec_coords.x + matrix_block_ldgvec_coords.x,
- block_thread_ldgvec_coords.y + matrix_block_ldgvec_coords.y);
-
- // Iteration range in "whole-k" block-wide tiles
- wholek_tiles_remaining = span_ldgvec_k / BlockLdgVectorsK;
-
- // Extent of final residue-tile in ldg_vector_t along K-axis
- int residue_ldgvecs_k = span_ldgvec_k % BlockLdgVectorsK;
-
- // Initialize I/O predicates
- if (AllowRaggedTiles)
- {
- // Outer thread-tile ldg_vector_t iteration (K-axis)
- #pragma unroll
- for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
- {
- int block_ldgvec_k = block_thread_ldgvec_coords.y + (thread_ldgvec_k * StripmineLdgVectorsK);
-
- // Whether block_ldgvec_coords.y is valid in the final residue tile
- predicate_mask_t valid_k = (block_ldgvec_k < residue_ldgvecs_k);
-
- // Inner thread-tile ldg_vector_t iteration (L-axis)
- #pragma unroll
- for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
- {
- int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
-
- // Whether block_ldgvec_coords.x is valid in any block-wide tile
- predicate_mask_t valid_l = (matrix_block_ldgvec_coords.x + block_ldgvec_l < matrix_ldgvecs_l);
-
- // Linear index of ldg_vector_t load
- int ldgvec_idx = thread_ldgvec_l + (thread_ldgvec_k * ThreadLdgVectorsL);
-
- // Set predicate guard bits
- guard |= (valid_l << ldgvec_idx);
- residue_guard |= ((valid_l & valid_k) << ldgvec_idx);
- }
- }
-
- // Promote residue-guard to primary-guard if no full tiles remain
- if (!wholek_tiles_remaining)
- {
- guard = residue_guard;
- }
- }
-
- // Update the input pointer to be matrix_thread_ldgvec_coords
- this->d_matrix_ldgvecs =
- reinterpret_cast<ldg_vector_t*>(d_matrix_items) +
- (matrix_thread_ldgvec_coords.y * matrix_ldgvec_stride_k) +
- (matrix_thread_ldgvec_coords.x * matrix_ldgvec_stride_l);
- }
-
-
- //-------------------------------------------------------------------------
- // Loader API
- //-------------------------------------------------------------------------
-
- /**
- * Request the current block-wide tile
- */
- inline __device__
- void request()
- {
- // Outer thread-tile ldg_vector_t iteration (K-axis)
- #pragma unroll
- for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
- {
- // Inner thread-tile ldg_vector_t iteration (L-axis)
- #pragma unroll
- for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
- {
- // Linear index of ldg_vector_t load
- int ldgvec_idx = (thread_ldgvec_k * ThreadLdgVectorsL) + thread_ldgvec_l;
-
- // Unpack predicate guard
- predicate_mask_t valid = ((guard >> ldgvec_idx) & 1);
-
- if (!AllowRaggedTiles || valid)
- {
- // Perform load
- thread_tile[thread_ldgvec_k][thread_ldgvec_l].load(
- d_matrix_ldgvecs +
- (thread_ldgvec_k * StripmineLdgVectorsK * matrix_ldgvec_stride_k) +
- (thread_ldgvec_l * StripmineLdgVectorsL * matrix_ldgvec_stride_l));
- }
- else
- {
- // Zero-initialize
- #pragma unroll
- for (int dpvec = 0; dpvec < LdgVectorDpVectors; ++dpvec)
- thread_tile[thread_ldgvec_k][thread_ldgvec_l].buff[dpvec] = 0;
- }
- }
- }
- }
-
-
- /**
- * Advance the loader to the next block-wide tile in the K-axis
- */
- inline __device__
- void next()
- {
- d_matrix_ldgvecs += (matrix_ldgvec_stride_k * BlockLdgVectorsK);
-
- if (AllowRaggedTiles)
- {
- --wholek_tiles_remaining;
-
- // Promote residue-guard to primary-guard if no full tiles remain
- if (!wholek_tiles_remaining)
- {
- guard = residue_guard;
- }
- }
- }
-
-
- /**
- * Commit the previously-requested block-wide tile to shared memory
- *
- * NB: To facilitate padding for avoiding shared memory bank conflicts, we
- * allow the row stride SmemDpVectorsL to be arbitrarily bigger than the
- * tile width BlockDpVectorsL.
- */
- template <int SmemDpVectorsL>
- inline __device__
- void commit(
- dp_vector_t (&scratch_tile)[BlockDpVectorsK][SmemDpVectorsL])
- {
- static_assert(SmemDpVectorsL >= BlockDpVectorsL, "Row stride must be >= tile width.");
-
- // Outer thread-tile ldg_vector_t iteration (K-axis)
- #pragma unroll
- for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
- {
- int block_ldgvec_k = block_thread_ldgvec_coords.y + (thread_ldgvec_k * StripmineLdgVectorsK);
-
- // Inner thread-tile ldg_vector_t iteration (L-axis)
- #pragma unroll
- for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
- {
- int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
-
- // Write column of dp_vector_t
- #pragma unroll
- for (int dpvec = 0; dpvec < LdgVectorDpVectors; ++dpvec)
- {
- scratch_tile[(block_ldgvec_k * LdgVectorDpVectors) + dpvec][block_ldgvec_l] =
- thread_tile[thread_ldgvec_k][thread_ldgvec_l].buff[dpvec];
- }
- }
- }
- }
-};
-
-
-} // namespace gemm
-} // namespace cutlass
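// ---------------------------------------------------------------------------
// Illustrative host-side sketch (not part of the deleted file above): how the
// derived constants of the CrosswiseCopy loader relate, assuming an
// IDP4A-style parameterization in which dp_vector_t packs four value_t and the
// thread block has 64 threads. All names and numbers here are assumptions made
// purely for illustration.
// ---------------------------------------------------------------------------
#include <cstdio>

int main()
{
    constexpr int BlockThreads    = 64;   // blockDim.x
    constexpr int BlockDpVectorsK = 8;    // tile extent in dp_vector_t along K
    constexpr int BlockDpVectorsL = 64;   // tile extent in dp_vector_t along L
    constexpr int DpVectorItems   = 4;    // value_t per dp_vector_t (e.g., 4 x int8)

    // Shared tile dimensions in value_t, matching the comment block above
    constexpr int TileK = BlockDpVectorsK * DpVectorItems;   // 32 rows of value_t
    constexpr int TileL = BlockDpVectorsL;                    // 64 columns

    // Each thread owns an equal slice of the block-wide tile
    constexpr int BlockDpVectors  = BlockDpVectorsK * BlockDpVectorsL;   // 512
    constexpr int ThreadDpVectors = BlockDpVectors / BlockThreads;       // 8

    static_assert(BlockDpVectors % BlockThreads == 0, "tile must divide evenly");
    std::printf("tile: %d x %d value_t, %d dp_vector_t per thread\n",
                TileK, TileL, ThreadDpVectors);
    return 0;
}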
diff --git a/cutlass/gemm/block_loader_wmma.h b/cutlass/gemm/block_loader_wmma.h
deleted file mode 100644
index 5b586a1b..00000000
--- a/cutlass/gemm/block_loader_wmma.h
+++ /dev/null
@@ -1,322 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- * * Redistributions of source code must retain the above copyright
- * notice, this list of conditions and the following disclaimer.
- * * Redistributions in binary form must reproduce the above copyright
- * notice, this list of conditions and the following disclaimer in the
- * documentation and/or other materials provided with the distribution.
- * * Neither the name of the NVIDIA CORPORATION nor the
- * names of its contributors may be used to endorse or promote products
- * derived from this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- ******************************************************************************/
-
-#pragma once
-
-/**
- * \file
- * Tile-loading abstraction for thread blocks
- */
-
-#include "../util/util.h"
-
-namespace cutlass {
-namespace gemm {
-
-
-/**
- * Block-wide tile loader supporting congruous mapping of data between source and
- * destination addressable storage. Typically, this will be used to load a
- * block-wide tile from global memory into shared memory.
- *
- * The caller may specify an alignment guarantee (MatrixAlignBytes) for the input
- * pointer, allowing memory operations to be performed on vectors. This increases the
- * efficiency of memory operations and reduces the number of guard predicates needed.
- *
- */
-template <
- bool congruous, ///< Indicates whether the "GEMM K" dimension refers to strided matrix dimension
- int BlockThreads, ///< Number of threads participating in the streaming operation
- int BlockItemsL, ///< Extent of block-wide tile in value_t along the L-axis (width)
- int BlockItemsK, ///< Extent of block-wide tile in value_t along the K-axis (height)
- typename value_t, ///< Input matrix value type
- int MatrixAlignBytes, ///< Byte alignment of input matrix
- bool AllowRaggedTiles ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
->
-struct block_loader_wmma
-{
- //-------------------------------------------------------------------------
- // Constants and types
- //-------------------------------------------------------------------------
-
- /// Predicate bit vector
- typedef uint64_t predicate_mask_t;
-
- /// Data movement type, coarsened by MatrixAlignBytes
- typedef io_vector<
- value_t,
- divide_assert<MatrixAlignBytes, sizeof(value_t)>::value,
- MatrixAlignBytes>
- ldg_vector_t;
-
- enum
- {
- /// Number of items per ldg_vector_t
- LdgVectorItems = ldg_vector_t::VectorItems,
-
- /// Total number of ldg_vector_t within the block-wide tile
- BlockLdgVectors = divide_assert<(BlockItemsL * BlockItemsK), LdgVectorItems>::value,
-
- /// Extent of the block-wide tile in ldg_vector_t along K-axis
- BlockLdgVectorsK = BlockItemsK,
-
- /// Extent of the block-wide tile in ldg_vector_t along L-axis
- BlockLdgVectorsL = divide_assert<BlockItemsL, LdgVectorItems>::value,
-
- /// Number of ldg_vector_t within each thread tile
- ThreadLdgVectors = divide_assert<BlockLdgVectors, BlockThreads>::value,
-
- /// Extent of the thread tile in ldg_vector_t along the L-axis
- ThreadLdgVectorsL = __NV_STD_MAX(1, BlockLdgVectorsL / BlockThreads),
-
- /// Block-wide strip-mining distance between ldg_vector_t along the K-axis
- BlockLdgVectorStrideK = __NV_STD_MAX(1, BlockThreads / BlockLdgVectorsL),
-
- /// Extent of the thread tile in ldg_vector_t along the K-axis
- ThreadLdgVectorsK = divide_assert<BlockLdgVectorsK, BlockLdgVectorStrideK>::value,
- };
-
-
- //-------------------------------------------------------------------------
- // Assert assumptions
- //-------------------------------------------------------------------------
-
- /// Define assertions
- static_assert(ThreadLdgVectorsL * ThreadLdgVectorsK == ThreadLdgVectors,
- "Number of vectors must be fully covered by the thread's 2D vector tile.");
-
- /// Predicate masks must be large enough to guard every vector load
- static_assert(sizeof(predicate_mask_t) * 8 >= ThreadLdgVectorsL * ThreadLdgVectorsK,
- "Predicate bit vector must be large enough to guard every vector load.");
-
- //-------------------------------------------------------------------------
- // Members
- //-------------------------------------------------------------------------
-
- /// pointer to tile in global memory
- const ldg_vector_t *ptr;
-
- /// stride of the matrix in the K-axis
- int matrix_values_stride_k;
-
- /// Guard predicate
- predicate_mask_t guard;
-
- /// Guard for the last request iteration
- predicate_mask_t residue_guard;
-
- /// Number of 'whole' request iterations before encountering the residue
- int request_iterations;
-
- /// fetch registers
- ldg_vector_t fetch[ThreadLdgVectors];
-
- /// Thread's base offset from the start of a block-wide tile
- int thread_offset_l;
-
- /// Thread's base offset from the start of a block-wide tile
- int thread_offset_k;
-
-
- //-------------------------------------------------------------------------
- // Constructor API
- //-------------------------------------------------------------------------
-
- /// Constructor
- inline __device__
- block_loader_wmma(
- const value_t *d_matrix, ///< Pointer to input matrix
- int matrix_values_l, ///< Extent of the input matrix in value_t along the L-axis
- int start_l, ///< Starting location in tile
- int dim_k, ///< Inner dimension of tile, used for computing guard predicates
- int _matrix_values_stride_k, ///< Stride of the matrix along the K-axis
- int start_k, ///< Tile's starting location
- int2 block_begin_item_coords) ///< Thread block's starting value_t coordinates (l, k) within the input matrix
- :
- ptr(reinterpret_cast<const ldg_vector_t *>(d_matrix)),
- matrix_values_stride_k(_matrix_values_stride_k / LdgVectorItems),
- guard(0),
- residue_guard(0)
- {
- // Compute block's starting coordinates in units of vectors
- int block_base_l = block_begin_item_coords.x / LdgVectorItems;
- int block_base_k = block_begin_item_coords.y;
-
- // Compute a thread tiling of the block-wide tile
- int tid = threadIdx.x;
- thread_offset_l = tid % BlockLdgVectorsL;
- thread_offset_k = tid / BlockLdgVectorsL;
-
- // Add the block and thread offsets to the source pointer
- ptr += (block_base_l + thread_offset_l) +
- (block_base_k + thread_offset_k) * matrix_values_stride_k;
-
- // When AllowRaggedTiles support is enabled, compute a bit vector of guard
- // predicates
- if (AllowRaggedTiles)
- {
- if (congruous)
- {
- request_iterations = (dim_k - start_k) / BlockItemsK;
- }
- else
- {
- request_iterations = (matrix_values_l - start_l) / BlockItemsL;
- }
-
- #pragma unroll
- for (int k_idx = 0; k_idx < ThreadLdgVectorsK; ++k_idx)
- {
- #pragma unroll
- for (int l_idx = 0; l_idx < ThreadLdgVectorsL; ++l_idx)
- {
- int item = l_idx + k_idx * ThreadLdgVectorsL;
-
- // Global vector L and K indices
- int vec_l = l_idx * BlockThreads;
- int vec_k = k_idx * BlockLdgVectorStrideK;
-
- predicate_mask_t pred;
- predicate_mask_t residue_pred;
-
- if (congruous)
- {
- pred = (((block_base_l + thread_offset_l + vec_l) * LdgVectorItems < matrix_values_l) ? 1 : 0);
- residue_pred = ((block_base_k + thread_offset_k + vec_k < (dim_k % BlockItemsK)) ? 1 : 0);
- }
- else
- {
- pred = ((block_base_k + thread_offset_k + vec_k < dim_k) ? 1 : 0);
- residue_pred = (((block_base_l + thread_offset_l + vec_l) * LdgVectorItems < (matrix_values_l % BlockItemsL)) ? 1 : 0);
- }
-
- // Update the guard and residue_guard word with predicate bits
- guard |= (pred << item);
- residue_guard |= (residue_pred << item);
- }
- }
-
- // If there are zero full request iterations, compute the intersection
- // with the residue guard.
- if (!request_iterations)
- {
- guard &= residue_guard;
- }
- }
- }
-
-
-
- /**
- * Request the current block-wide tile from source memory
- */
- inline __device__
- void request()
- {
- #pragma unroll
- for (int k_idx = 0; k_idx < ThreadLdgVectorsK; ++k_idx)
- {
- #pragma unroll
- for (int l_idx = 0; l_idx < ThreadLdgVectorsL; ++l_idx)
- {
- int load_idx = l_idx + (k_idx * ThreadLdgVectorsL);
- bool pred = !AllowRaggedTiles || (guard & (predicate_mask_t(1) << load_idx));
- if (pred)
- {
- fetch[load_idx].load(
- ptr +
- (k_idx * BlockLdgVectorStrideK * matrix_values_stride_k) + (l_idx * BlockThreads));
- }
- else
- {
- #pragma unroll
- for (int elem_idx = 0; elem_idx < LdgVectorItems; ++elem_idx)
- {
- fetch[load_idx].buff[elem_idx] = 0;
- }
- }
- }
- }
- }
-
-
- /// Advance to the next block-wide tile
- inline __device__
- void next()
- {
- if (congruous)
- {
- ptr += BlockItemsK * matrix_values_stride_k;
- }
- else
- {
- ptr += BlockLdgVectorsL;
- }
-
- // Track number of full iterations to intersect with the residue guard predicates.
- if (AllowRaggedTiles)
- {
- --request_iterations;
- if (!request_iterations)
- {
- guard &= residue_guard;
- }
- }
- }
-
-
- /// Commit the values to the scratch tile to destination memory.
- template <int SmemStride>
- inline __device__
- void commit(value_t *scratch_tile)
- {
- static_assert(SmemStride % LdgVectorItems == 0,
- "SMEM stride must be divisible by the size of vector loads");
-
- ldg_vector_t *smem_ptr = reinterpret_cast<ldg_vector_t *>(scratch_tile);
- smem_ptr += thread_offset_l + thread_offset_k * SmemStride / LdgVectorItems;
-
- #pragma unroll
- for (int k_idx = 0; k_idx < ThreadLdgVectorsK; ++k_idx)
- {
- #pragma unroll
- for (int l_idx = 0; l_idx < ThreadLdgVectorsL; ++l_idx)
- {
- int load_idx = l_idx + (k_idx * ThreadLdgVectorsL);
-
- fetch[load_idx].store(smem_ptr +
- (k_idx * BlockLdgVectorStrideK * SmemStride / LdgVectorItems) +
- (l_idx * BlockThreads));
- }
- }
- }
-};
-
-
-} // namespace gemm
-} // namespace cutlass
-
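// ---------------------------------------------------------------------------
// Illustrative host-side sketch (not part of the deleted file above) of the
// guard-predicate scheme the loaders use when AllowRaggedTiles is enabled: one
// bit per vector load is packed into a 64-bit mask, and a load whose bit is
// clear is skipped and zero-filled instead. Sizes below are assumptions for
// illustration only.
// ---------------------------------------------------------------------------
#include <cstdint>
#include <cstdio>

int main()
{
    using predicate_mask_t = uint64_t;

    constexpr int ThreadVectorsK = 2;   // vector loads per thread along K
    constexpr int ThreadVectorsL = 4;   // vector loads per thread along L
    static_assert(ThreadVectorsK * ThreadVectorsL <= 64,
                  "mask must hold one bit per vector load");

    // Pretend the matrix ends partway through the tile: only the first three
    // L-positions and the first K-position are in bounds.
    predicate_mask_t guard = 0;
    for (int k = 0; k < ThreadVectorsK; ++k)
        for (int l = 0; l < ThreadVectorsL; ++l)
        {
            bool valid = (l < 3) && (k < 1);
            int idx = l + k * ThreadVectorsL;               // linear load index
            guard |= (predicate_mask_t(valid) << idx);
        }

    // The request() loop then tests one bit per load
    for (int idx = 0; idx < ThreadVectorsK * ThreadVectorsL; ++idx)
    {
        bool do_load = (guard >> idx) & 1;
        std::printf("load %d: %s\n", idx, do_load ? "LDG" : "zero-fill");
    }
    return 0;
}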
diff --git a/cutlass/gemm/block_task.h b/cutlass/gemm/block_task.h
deleted file mode 100644
index 3940fb6e..00000000
--- a/cutlass/gemm/block_task.h
+++ /dev/null
@@ -1,677 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- * * Redistributions of source code must retain the above copyright
- * notice, this list of conditions and the following disclaimer.
- * * Redistributions in binary form must reproduce the above copyright
- * notice, this list of conditions and the following disclaimer in the
- * documentation and/or other materials provided with the distribution.
- * * Neither the name of the NVIDIA CORPORATION nor the
- * names of its contributors may be used to endorse or promote products
- * derived from this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- ******************************************************************************/
-
-#pragma once
-
-/**
- * \file
- * A block-wide task abstraction for computing device-wide GEMM
- */
-
-#include <stdint.h>
-
-#include "../util/util.h"
-
-#include "grid_raster.h"
-#include "block_loader.h"
-#include "k_split_control.h"
-#include "thread_accumulator.h"
-
-namespace cutlass {
-namespace gemm {
-
-
-/******************************************************************************
- * block_task_policy
- ******************************************************************************/
-
-/**
- * \brief Parameterizable tuning policy for \p block_task
- *
- * Once parameterized, \p block_task_policy provides the member constant
- * \p BlockThreads indicating the required thread block size
- */
-template <
- int _BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
- int _BlockItemsX, ///< Width in columns of a block-wide tile in matrix C
- int _BlockItemsK, ///< Extent of block-wide A|B tiles in value_t along the K-axis
- int _ThreadItemsY, ///< Height in rows of a thread tile in C
- int _ThreadItemsX, ///< Width in columns of a thread tile in C
- bool _UseDoubleScratchTiles, ///< Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
- grid_raster_strategy::kind_t _RasterStrategy> ///< Strategy for enumerating \p block_task within an input matrix
-struct block_task_policy
-{
- enum
- {
- /// Height in rows of a block-wide tile in matrix C
- BlockItemsY = _BlockItemsY,
-
- /// Width in columns of a block-wide tile in matrix C
- BlockItemsX = _BlockItemsX,
-
- /// Height in rows of a thread tile in C
- ThreadItemsY = _ThreadItemsY,
-
- /// Width in columns of a thread tile in C
- ThreadItemsX = _ThreadItemsX,
-
- /// Extent of block-wide A|B tiles in value_t along the K-axis
- BlockItemsK = _BlockItemsK,
-
- /// Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
- UseDoubleScratchTiles = _UseDoubleScratchTiles,
-
- /// Number of threads in each thread block (blockDim.x)
- BlockThreads = divide_assert<
- (BlockItemsY * BlockItemsX),
- (ThreadItemsY * ThreadItemsX)>::value,
- };
-
- /// Strategy for enumerating \p block_task within an input matrix
- static const grid_raster_strategy::kind_t RasterStrategy = _RasterStrategy;
-};
-
-
-/******************************************************************************
- * block_task
- ******************************************************************************/
-
-/**
- * \brief A block-wide task abstraction for computing device-wide GEMM
- *
- * Each thread_block is assigned a unique tile of output matrix C to compute by
- * consuming the corresponding stripes of the input matrices A and B.
- */
-template <
- typename block_task_policy_t, ///< Parameterization of block_task_policy
- typename value_t, ///< Multiplicand value type (matrices A and B)
- typename accum_t, ///< Accumulator value type (matrix C and scalars)
- matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
- int LdgAlignA, ///< Alignment (in bytes) for A operand
- matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B
- int LdgAlignB, ///< Alignment (in bytes) for B operand
- typename epilogue_op_t, ///< Epilogue operation applied to GEMM
- int LdgAlignC, ///< Alignment (in bytes) for C operand
- bool AllowRaggedTiles ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
->
-struct block_task
-{
- //-------------------------------------------------------------------------
- // Constants and types
- //-------------------------------------------------------------------------
-
- enum
- {
- /// Number of threads in each thread block (blockDim.x)
- BlockThreads = block_task_policy_t::BlockThreads,
-
- /// Extent of thread tile in value_t along M-axis
- ThreadItemsY = block_task_policy_t::ThreadItemsY,
-
- /// Extent of thread tile in value_t along N-axis
- ThreadItemsX = block_task_policy_t::ThreadItemsX,
- };
-
- /// Accumulator type
- typedef thread_accumulator<
- ThreadItemsY,
- ThreadItemsX,
- value_t,
- accum_t>
- thread_accumulator_t;
-
- /// Dot-product vector type along the K-axis (e.g., uchar4 when using IDP4A)
- typedef typename thread_accumulator_t::dp_vector_t dp_vector_t;
-
- enum
- {
- /// Whether this is a small, latency-bound tile
- IsSmallTile = (ThreadItemsY < 4) && (ThreadItemsX < 4),
-
- /// Number of value_t in dp_vector_t
- DpVectorItems = divide_assert<sizeof(dp_vector_t), sizeof(value_t)>::value,
-
- /// Extent of block-wide C-tile in accum_t (and A-tiles in value_t) along M-axis (height)
- BlockItemsY = block_task_policy_t::BlockItemsY,
-
- /// Extent of block-wide C-tile in accum_t (and B-tiles in value_t) along N-axis (width)
- BlockItemsX = block_task_policy_t::BlockItemsX,
-
- /// Extent of block-wide A|B tiles in value_t along the K-axis
- BlockItemsK = block_task_policy_t::BlockItemsK,
-
- /// Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
- UseDoubleScratchTiles = block_task_policy_t::UseDoubleScratchTiles,
-
- /// Extent of block-wide A|B tiles in dp_vector_t along the K-axis
- BlockDpVectorsK = divide_assert<BlockItemsK, DpVectorItems>::value,
-
- /// Number of dp_vector_t along M-axis that can be read in a single LDS from the shared A-tile (up to 128b if more than one value_t)
- LdsVectorDpVectorsA = __NV_STD_MIN(
- ThreadItemsY,
- __NV_STD_MAX(1, (128 / (__NV_STD_MAX(sizeof(dp_vector_t), sizeof(accum_t)) * 8)))),
-
- /// Number of dp_vector_t along N-axis that can be read in a single LDS from the shared B-tile (up to 128b if more than one value_t)
- LdsVectorDpVectorsB = __NV_STD_MIN(
- ThreadItemsX,
- __NV_STD_MAX(1, (128 / (__NV_STD_MAX(sizeof(dp_vector_t), sizeof(accum_t)) * 8)))),
-
- /// Number of strip-mined LDS vector reads from shared A-tile
- ThreadLdsVectorsA = divide_assert<ThreadItemsY, LdsVectorDpVectorsA>::value,
-
- /// Number of strip-mined LDS vector reads from shared B-tile
- ThreadLdsVectorsB = divide_assert<ThreadItemsX, LdsVectorDpVectorsB>::value,
-
- /// Number of elements in one LDG/STG vector of C-tile
- ThreadLdgVectorSizeC = __NV_STD_MIN(LdgAlignC, 16) / (sizeof(accum_t)),
-
- /// Number of threads in warp
- WarpThreads = 32,
-
- /// Extent of warp in threads along the M-axis
- WarpThreadsY = (BlockItemsY > BlockItemsX) ? 8 : 4,
-
- /// Extent of warp in threads along the N-axis
- WarpThreadsX = divide_assert<WarpThreads, WarpThreadsY>::value,
-
- /// Extent of warp-wide tile in items along the M-axis
- WarpItemsY = WarpThreadsY * ThreadItemsY,
-
- /// Extent of warp-wide tile in items along the N-axis
- WarpItemsX = WarpThreadsX * ThreadItemsX,
-
- /// Extent of block in warps along M-axis
- BlockWarpsY = divide_assert<BlockItemsY, WarpItemsY>::value,
-
- /// Extent of block in warps along N-axis
- BlockWarpsX = divide_assert<BlockItemsX, WarpItemsX>::value,
- };
-
- /// Load-from-shared data movement type for A-tile, coarsened by LdsVectorDpVectorsA
- typedef io_vector<dp_vector_t, LdsVectorDpVectorsA> lds_vector_a_t;
-
- /// Load-from-shared data movement type for B-tile, coarsened by LdsVectorDpVectorsB
- typedef io_vector<dp_vector_t, LdsVectorDpVectorsB> lds_vector_b_t;
-
- /// Thread block rasterization helper type
- typedef grid_raster<
- BlockItemsY,
- BlockItemsX,
- TransformA,
- TransformB,
- block_task_policy_t::RasterStrategy>
- grid_raster_t;
-
-
- /// Tile loader type for matrix A
- typedef block_loader<
- BlockThreads, // BlockThreads
- BlockDpVectorsK, // BlockDpVectorsK
- BlockItemsY, // BlockItemsL
- value_t, // value_t
- LdgAlignA, // MatrixAlignBytes
- AllowRaggedTiles, // AllowRaggedTiles
- dp_vector_t, // dp_vector_t
- (TransformA == matrix_transform_t::NonTranspose) ? // LoadAlgorithm
- load_algorithm::CongruousCopy :
- load_algorithm::CrosswiseCopy>
- block_loader_a_t;
-
-
- /// Tile loader type for matrix B
- typedef block_loader<
- BlockThreads, // BlockThreads
- BlockDpVectorsK, // BlockDpVectorsK
- BlockItemsX, // BlockItemsL
- value_t, // value_t
- LdgAlignB, // MatrixAlignBytes
- AllowRaggedTiles, // AllowRaggedTiles
- dp_vector_t, // dp_vector_t
- (TransformB == matrix_transform_t::NonTranspose) ? // LoadAlgorithm
- load_algorithm::CrosswiseCopy :
- load_algorithm::CongruousCopy>
- block_loader_b_t;
-
-
- enum
- {
- /// Number of value_t to pad the end of each row of the shared A-tile
- PadItemsA = (TransformA == matrix_transform_t::NonTranspose) ?
- __NV_STD_MAX(LdsVectorDpVectorsA, block_loader_a_t::AlignmentDpVectorsL) :
- LdsVectorDpVectorsA,
-
- /// Number of value_t to pad the end of each row of the shared B-tile
- PadItemsB = (TransformB == matrix_transform_t::NonTranspose) ?
- LdsVectorDpVectorsB :
- __NV_STD_MAX(LdsVectorDpVectorsB, block_loader_b_t::AlignmentDpVectorsL),
- };
-
-
- /// Shared memory layout for a prefetch page
- struct page_storage_t
- {
- /// Tile of A
- dp_vector_t __align__(16) block_a[BlockDpVectorsK][BlockItemsY + PadItemsA];
-
- /// Tile of B
- dp_vector_t __align__(16) block_b[BlockDpVectorsK][BlockItemsX + PadItemsB];
- };
-
-
- /// Shared memory layout for scratch storage
- struct scratch_storage_t
- {
- /// Prefetch pages
- page_storage_t pages[UseDoubleScratchTiles ? 2 : 1];
-
- /// Accumulator shared scratch
- typename thread_accumulator_t::scratch_storage_t accum_scratch;
- };
-
-
- //-------------------------------------------------------------------------
- // Assert assumptions
- //-------------------------------------------------------------------------
-
- // Ensure we have at least two unrolled innermost loop iterations (one to prefetch
- // the next global tile and then one to prefetch the first strip of it from shared)
- static_assert ((BlockDpVectorsK >= 2), "BlockDpVectorsK must be >= 2.");
-
-
- //-------------------------------------------------------------------------
- // Members
- //-------------------------------------------------------------------------
-
- /// Scratch storage reference
- scratch_storage_t *scratch;
-
- /// Which page of scratch tiles we're currently reading from
- int page_idx;
-
- /// Pointer to matrix C
- accum_t *d_c;
-
- /// Epilogue operation applied to update matrix C
- epilogue_op_t epilogue_op;
-
- /// Matrix height in rows of trans_op(A) and C
- int dim_m;
-
- /// Matrix width in columns of trans_op(B) and C
- int dim_n;
-
- /// Control for inter-block k-splitting
- k_split_control k_split;
-
- /// Thread block's base value_t coordinates (m, n) in matrix C
- grid_raster_t grid_raster;
-
- /// Thread block's current coordinate (k) within A|B matrices
- int block_item_coords_k;
-
- /// Thread block's ending coordinate (k) within A|B matrices (one-past)
- int block_end_item_k;
-
- /// Warp's coordinates (x, y) in thread block
- int2 block_warp_coords;
-
- /// Thread's coordinates (x, y) in warp
- int2 warp_thread_coords;
-
- /// Thread's base item offset within strip of A tile
- int thread_strip_offset_a;
-
- /// Thread's base item offset within strip of B tile
- int thread_strip_offset_b;
-
- /// Thread's active-k/prefetch-k slices from shared A tile
- lds_vector_a_t local_slices_a[2][ThreadLdsVectorsA];
-
- /// Thread's active-k/prefetch-k slices from shared B tile
- lds_vector_b_t local_slices_b[2][ThreadLdsVectorsB];
-
- /// A tile loader
- block_loader_a_t loader_a;
-
- /// B tile loader
- block_loader_b_t loader_b;
-
- /// C tile accumulator
- thread_accumulator_t accumulator;
-
-
- //-------------------------------------------------------------------------
- // Coordinate system helpers
- //-------------------------------------------------------------------------
-
- /// Compute the warp's coordinates (x, y) in thread block
- inline __device__
- int2 warp_coords()
- {
- int warp_id = threadIdx.x / WarpThreads;
- return make_int2(
- warp_id % BlockWarpsX,
- warp_id / BlockWarpsX);
- }
-
-
- /// Compute the thread's lane-coordinates (x, y) in warp
- inline __device__
- int2 thread_coords()
- {
- int lane_id = threadIdx.x % WarpThreads;
-
- // Maxwell+ mapping of threads within a 2D warp for maximal LDS bandwidth
- return make_int2(
- lane_id / WarpThreadsY,
- lane_id % WarpThreadsY);
- }
-
-
- //-------------------------------------------------------------------------
- // Constructor API
- //-------------------------------------------------------------------------
-
- /// Constructor
- inline __device__
- block_task(
- scratch_storage_t *scratch,
- value_t *d_a,
- value_t *d_b,
- accum_t *d_c,
- epilogue_op_t epilogue_op,
- int dim_m,
- int dim_n,
- int dim_k,
- k_split_control k_split)
- :
- scratch(scratch),
- page_idx(0),
- d_c(d_c),
- epilogue_op(epilogue_op),
- dim_m(dim_m),
- dim_n(dim_n),
- k_split(k_split),
- block_item_coords_k(k_split.block_begin_item_k()),
- block_end_item_k(k_split.block_end_item_k(dim_k)),
- block_warp_coords(warp_coords()),
- warp_thread_coords(thread_coords()),
- thread_strip_offset_a((warp_thread_coords.y * LdsVectorDpVectorsA) + (block_warp_coords.y * WarpItemsY)),
- thread_strip_offset_b((warp_thread_coords.x * LdsVectorDpVectorsB) + (block_warp_coords.x * WarpItemsX)),
-
- loader_a(
- d_a, // d_matrix
- dim_m, // matrix_values_l
- (TransformA == matrix_transform_t::NonTranspose) ? dim_m : 1, // matrix_values_stride_k
- (TransformA == matrix_transform_t::NonTranspose) ? 1 : dim_k, // matrix_values_stride_l
- make_int2( // block_begin_item_coords
- grid_raster.block_item_coords.y,
- block_item_coords_k),
- block_end_item_k), // block_end_item_k
-
- loader_b(
- d_b, // d_matrix
- dim_n, // matrix_values_l
- (TransformB == matrix_transform_t::NonTranspose) ? 1 : dim_n, // matrix_values_stride_k
- (TransformB == matrix_transform_t::NonTranspose) ? dim_k : 1, // matrix_values_stride_l
- make_int2( // block_begin_item_coords
- grid_raster.block_item_coords.x,
- block_item_coords_k),
- block_end_item_k), // block_end_item_k
-
- accumulator(scratch->accum_scratch)
- {}
-
-
- //-------------------------------------------------------------------------
- // Prefetching utility methods
- //-------------------------------------------------------------------------
-
- /**
- * Request the calling thread's slices of the shared tiles at depth \p tile_offset_k
- */
- inline __device__ void request_local_prefetch(
- lds_vector_a_t (&slice_a)[ThreadLdsVectorsA], ///< Slice from A
- lds_vector_b_t (&slice_b)[ThreadLdsVectorsB], ///< Slice from B
- int tile_offset_k)
- {
- // Load B strip
- for (int i = 0; i < ThreadLdsVectorsB; ++i)
- {
- slice_b[i].load(
- &scratch->pages[page_idx].block_b[tile_offset_k][thread_strip_offset_b + (i * WarpThreadsX * LdsVectorDpVectorsB)]);
- }
-
- // Load A strip
- for (int i = 0; i < ThreadLdsVectorsA; ++i)
- {
- slice_a[i].load(
- &scratch->pages[page_idx].block_a[tile_offset_k][thread_strip_offset_a + (i * WarpThreadsY * LdsVectorDpVectorsA)]);
- }
- }
-
-
- //-------------------------------------------------------------------------
- // Epilogue
- //-------------------------------------------------------------------------
-
- /**
- * Performs the GEMM epilogue:
- * - Applies the scalar multipliers and addends to the accumulators
- * - Write the result to the output matrix
- */
- __forceinline__ __device__
- void epilogue()
- {
- // Wait for predecessor thread block(s) to produce block-wide tile of
- // exclusive partial-sums
- k_split.wait();
-
- // Configure epilogue as to whether the thread block is a secondary
- // accumulator in an inter-block k-splitting scheme
- if (k_split.is_secondary_accumulator())
- epilogue_op.set_secondary_accumulator();
-
- // Whether the addend from C needs loading
- bool must_init_addend = epilogue_op.must_init_addend();
-
- #pragma unroll
- for (int x = 0; x < ThreadItemsX; ++x)
- {
- #pragma unroll
- for (int y = 0; y < ThreadItemsY; y += LdsVectorDpVectorsA)
- {
- int thread_strip_b = x / LdsVectorDpVectorsB;
- int thread_strip_a = y / LdsVectorDpVectorsA;
-
- int thread_item_coords_tile_x = thread_strip_offset_b + (thread_strip_b * WarpThreadsX * LdsVectorDpVectorsB) + (x % LdsVectorDpVectorsB);
- int thread_item_coords_tile_y = thread_strip_offset_a + (thread_strip_a * WarpThreadsY * LdsVectorDpVectorsA) + (y % LdsVectorDpVectorsA);
-
- int c_idx = (grid_raster.block_item_coords.x + thread_item_coords_tile_x) * dim_m +
- grid_raster.block_item_coords.y + thread_item_coords_tile_y;
-
- accum_t *my_c = d_c + c_idx;
-
- #pragma unroll
- for (int i = 0; i < LdsVectorDpVectorsA; ++i)
- {
- accum_t c_slice = accum_t(0);
- accum_t *c_ptr = my_c + i;
-
- if ((grid_raster.block_item_coords.x + thread_item_coords_tile_x) < dim_n &&
- (grid_raster.block_item_coords.y + thread_item_coords_tile_y + i) < dim_m)
- {
- if (must_init_addend)
- {
- ldg_cg(c_slice, c_ptr);
- }
-
- c_slice = epilogue_op(accumulator.get(x, y + i), c_slice, c_idx + i);
-
- stg_cg(c_ptr, c_slice);
- }
- }
- }
- }
-
- // Signal k-split successor thread_block that we have produced our block-wide
- // tile of inclusive partial-sums
- k_split.signal();
- }
-
-
- //-------------------------------------------------------------------------
- // Tile consumption
- //-------------------------------------------------------------------------
-
- /**
- * Consume a tile of A and B each
- */
- template <bool DoGlobalPrefetch>
- __forceinline__ __device__
- void consume_tile()
- {
- // Unroll BlockDpVectorsK iterations of outer-product accumulations
- #pragma unroll
- for (int tile_offset_k = 0; tile_offset_k < BlockDpVectorsK; tile_offset_k += 1)
- {
- // Last strip commits global prefetch for next tile
- if ((tile_offset_k == BlockDpVectorsK - 1) && DoGlobalPrefetch)
- {
- // If not using two pages of scratch tiles, protect the above prefetch loads from the committing writes below
- if (!UseDoubleScratchTiles)
- __syncthreads();
-
- // If using two pages of scratch tiles, switch to next page before writing
- if (UseDoubleScratchTiles)
- {
- page_idx = (page_idx ? 0 : 1);
- }
-
- // Commit global prefetch data to scratch page
- loader_a.commit(scratch->pages[page_idx].block_a);
- loader_b.commit(scratch->pages[page_idx].block_b);
-
- __syncthreads();
- }
-
- // Request local prefetch for next strip
- request_local_prefetch(
- local_slices_a[(tile_offset_k + 1) % 2],
- local_slices_b[(tile_offset_k + 1) % 2],
- (tile_offset_k + 1) % BlockDpVectorsK);
-
- // Request global prefetch for next tile on first strip
- if ((tile_offset_k == 0) && DoGlobalPrefetch)
- {
- loader_b.request();
- loader_b.next();
- loader_a.request();
- loader_a.next();
- }
-
- // Cast strip-mined loads to contiguous array of dp_vector_t
- typedef dp_vector_t thread_tile_a_t[ThreadLdsVectorsA * LdsVectorDpVectorsA];
- typedef dp_vector_t thread_tile_b_t[ThreadLdsVectorsB * LdsVectorDpVectorsB];
- thread_tile_a_t &thread_tile_a = reinterpret_cast<thread_tile_a_t &>(local_slices_a[(tile_offset_k) % 2]);
- thread_tile_b_t &thread_tile_b = reinterpret_cast<thread_tile_b_t &>(local_slices_b[(tile_offset_k) % 2]);
-
- // Accumulate this dp-stripe product
- accumulator.multiply_accumulate(thread_tile_a, thread_tile_b);
- }
- }
-
-
- //-------------------------------------------------------------------------
- // GEMM API
- //-------------------------------------------------------------------------
-
- /**
- * Compute GEMM
- */
- __forceinline__ __device__
- void run()
- {
- // Quit if the thread block is fully out-of-bounds
- if (grid_raster.is_block_oob(dim_m, dim_n))
- {
- asm volatile("exit;");
- }
-
- // Request global prefetch of first tile
- loader_a.request();
- loader_a.next();
- loader_b.request();
- loader_b.next();
-
- // Commit global prefetch of first tile to shared memory
- loader_a.commit(scratch->pages[page_idx].block_a);
- loader_b.commit(scratch->pages[page_idx].block_b);
-
- // Advance to next A,B tiles in K-axis
- block_item_coords_k += BlockItemsK;
-
- // Synchronize shared tiles and prepared accumulator
- __syncthreads();
-
- // Initialize thread's slice of accumulators
- accumulator.init();
-
- // Request first iteration of local prefetch strips
- request_local_prefetch(
- local_slices_a[0],
- local_slices_b[0],
- 0);
-
- //
- // Main loop
- //
-
- // Consume tiles in A and B along the K-axis (all but last tile)
- #pragma unroll 1
- while (block_item_coords_k < block_end_item_k)
- {
- consume_tile<true>();
-
- // Advance to next A,B tiles in K-axis
- block_item_coords_k += BlockItemsK;
- }
-
- // Consume last tile
- consume_tile<false>();
-
- //
- // Epilogue
- //
-
- epilogue();
- }
-};
-
-
-} // namespace gemm
-} // namespace cutlass
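// ---------------------------------------------------------------------------
// Illustrative host-side sketch (not part of the deleted file above) of the
// thread-to-tile rasterization block_task uses: threads group into 32-lane
// warps, warps tile the block-wide C tile, and lanes tile each warp tile.
// The tile shapes below are assumptions chosen only to make the arithmetic
// concrete.
// ---------------------------------------------------------------------------
#include <cstdio>

int main()
{
    constexpr int WarpThreads  = 32;
    constexpr int BlockItemsY  = 128, BlockItemsX  = 64;   // block-wide C tile
    constexpr int ThreadItemsY = 8,   ThreadItemsX = 8;    // per-thread C tile
    constexpr int BlockThreads =
        (BlockItemsY * BlockItemsX) / (ThreadItemsY * ThreadItemsX);        // 128

    // Warp footprint: taller than wide when the block tile is taller than wide
    constexpr int WarpThreadsY = (BlockItemsY > BlockItemsX) ? 8 : 4;        // 8
    constexpr int WarpThreadsX = WarpThreads / WarpThreadsY;                 // 4
    constexpr int BlockWarpsX  = BlockItemsX / (WarpThreadsX * ThreadItemsX); // 2

    for (int tid = 0; tid < BlockThreads; tid += 37)   // sample a few threads
    {
        int warp_id = tid / WarpThreads;
        int lane_id = tid % WarpThreads;
        int warp_x  = warp_id % BlockWarpsX;            // warp coords in block
        int warp_y  = warp_id / BlockWarpsX;
        int lane_x  = lane_id / WarpThreadsY;           // lane coords in warp
        int lane_y  = lane_id % WarpThreadsY;
        std::printf("tid %3d -> warp (%d,%d), lane (%d,%d)\n",
                    tid, warp_x, warp_y, lane_x, lane_y);
    }
    return 0;
}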
diff --git a/cutlass/gemm/block_task_wmma.h b/cutlass/gemm/block_task_wmma.h
deleted file mode 100644
index 03d2d222..00000000
--- a/cutlass/gemm/block_task_wmma.h
+++ /dev/null
@@ -1,767 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- * * Redistributions of source code must retain the above copyright
- * notice, this list of conditions and the following disclaimer.
- * * Redistributions in binary form must reproduce the above copyright
- * notice, this list of conditions and the following disclaimer in the
- * documentation and/or other materials provided with the distribution.
- * * Neither the name of the NVIDIA CORPORATION nor the
- * names of its contributors may be used to endorse or promote products
- * derived from this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- ******************************************************************************/
-
-/**
- * \file
- * A block-wide task abstraction for computing device-wide GEMM
- */
-
-#pragma once
-
-// Compiler guard conditional to avoid compilation errors on versions of CUDA that
-// do not support the WMMA API.
-#if defined (WMMA)
-
-#include <stdint.h>
-
-#include "../util/util.h"
-
-#include "grid_raster.h"
-#include "block_loader.h"
-#include "block_loader_wmma.h"
-#include "wmma_accumulator.h"
-
-
-namespace cutlass {
-namespace gemm {
-
-
-/******************************************************************************
- * block_task_wmma_policy
- ******************************************************************************/
-
-/**
- * \brief Parameterizable tuning policy for block-wide WMMA GEMM tasks
- *
- * Once parameterized, \p block_task_policy provides the member constant
- * \p BlockThreads indicating the required thread block size
- */
-template <
- int _BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
- int _BlockItemsX, ///< Width in columns of a block-wide tile in matrix C
- int _BlockItemsK, ///< Extent of block-wide A|B tiles in value_t along the K-axis
- int _WarpItemsY, ///< Height in rows of a Warp tile's accumulators
- int _WarpItemsX, ///< Width in columns of a Warp tile's accumulators
- int _WmmaItemsY, ///< Height in rows of a discrete WMMA block's accumulators
- int _WmmaItemsX, ///< Width in columns of a discrete WMMA block's accumulators
- int _WmmaItemsK, ///< Depth of each discrete WMMA block
- bool _UseDoubleScratchTiles, ///< Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
- grid_raster_strategy::kind_t _RasterStrategy> ///< Strategy for enumerating \p block_task within an input matrix
-struct block_task_wmma_policy
-{
- /// Strategy for enumerating \p block_task within an input matrix
- static const grid_raster_strategy::kind_t RasterStrategy = _RasterStrategy;
-
- enum
- {
- /// Height in rows of a block-wide tile in matrix C
- BlockItemsY = _BlockItemsY,
-
- /// Width in columns of a block-wide tile in matrix C
- BlockItemsX = _BlockItemsX,
-
- /// Extent of block-wide A|B tiles in value_t along the K-axis
- BlockItemsK = _BlockItemsK,
-
- /// Width in columns of a Warp tile's accumulators
- WarpItemsX = _WarpItemsX,
-
- /// Height in rows of a Warp tile's accumulators
- WarpItemsY = _WarpItemsY,
-
- /// Width in columns of a discrete WMMA block's accumulators
- WmmaItemsX = _WmmaItemsX,
-
- /// Height in rows of a discrete WMMA block's accumulators
- WmmaItemsY = _WmmaItemsY,
-
- /// Depth of each discrete WMMA block
- WmmaItemsK = _WmmaItemsK,
-
- /// Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
- UseDoubleScratchTiles = _UseDoubleScratchTiles,
-
-
- //
- // Derived quantities
- //
-
- /// Machine warp size
- WarpThreads = 32,
-
- /// Number of WMMA operations in the height dimension
- WmmaBlocksY = divide_assert<WarpItemsY, WmmaItemsY>::value,
-
- /// Number of WMMA operations in the width dimension
- WmmaBlocksX = divide_assert<WarpItemsX, WmmaItemsX>::value,
-
- /// Number of warps in each thread block
- BlockWarps = divide_assert<BlockItemsY * BlockItemsX, WarpItemsY * WarpItemsX>::value,
-
- /// Number of threads in each thread block (blockDim.x)
- BlockThreads = BlockWarps * WarpThreads,
- };
-};
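// ---------------------------------------------------------------------------
// Illustrative host-side sketch (not part of the deleted file above) of the
// derived quantities in block_task_wmma_policy, assuming a 64x64 block tile
// built from 32x32 warp tiles and 16x16 WMMA fragments. The numbers are
// assumptions for illustration only.
// ---------------------------------------------------------------------------
#include <cstdio>

int main()
{
    constexpr int WarpThreads = 32;
    constexpr int BlockItemsY = 64, BlockItemsX = 64;
    constexpr int WarpItemsY  = 32, WarpItemsX  = 32;
    constexpr int WmmaItemsY  = 16, WmmaItemsX  = 16;

    // Each warp tile decomposes into a grid of discrete WMMA operations
    constexpr int WmmaBlocksY = WarpItemsY / WmmaItemsY;   // 2
    constexpr int WmmaBlocksX = WarpItemsX / WmmaItemsX;   // 2

    // Warps tile the block-wide C tile; one warp per warp tile
    constexpr int BlockWarps =
        (BlockItemsY * BlockItemsX) / (WarpItemsY * WarpItemsX);             // 4
    constexpr int BlockThreads = BlockWarps * WarpThreads;                    // 128

    std::printf("%d warps, %d threads, %d x %d WMMA ops per warp tile\n",
                BlockWarps, BlockThreads, WmmaBlocksY, WmmaBlocksX);
    return 0;
}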
-
-
-/******************************************************************************
- * block_task_wmma
- ******************************************************************************/
-
-/**
- * \brief A block-wide task abstraction for computing device-wide GEMM
- *
- * Each thread_block is assigned a unique tile of output matrix C to compute by
- * consuming the corresponding stripes of the input matrices A and B.
- */
-template <
- typename block_task_policy_t, ///< Parameterization of block_task_policy
- typename value_t, ///< Multiplicand value type (matrices A and B)
- typename accum_t, ///< Accumulator value type (matrix C and scalars)
- matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
- int LdgAlignA, ///< Alignment (in bytes) for A operand
- matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B
- int LdgAlignB, ///< Alignment (in bytes) for B operand
- typename epilogue_op_t, ///< Epilogue operation to update matrix C
- int LdgAlignC, ///< Alignment (in bytes) for C operand
- bool AllowRaggedTiles ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
->
-struct block_task_wmma
-{
- //-------------------------------------------------------------------------
- // Constants and types
- //-------------------------------------------------------------------------
-
- enum
- {
- /// Number of threads in each thread block (blockDim.x)
- BlockThreads = block_task_policy_t::BlockThreads,
-
- /// Extent of block-wide C-tile in accum_t (and A-tiles in value_t) along M-axis (height)
- BlockItemsY = block_task_policy_t::BlockItemsY,
-
- /// Extent of block-wide C-tile in accum_t (and B-tiles in value_t) along N-axis (width)
- BlockItemsX = block_task_policy_t::BlockItemsX,
-
- /// Extent of block-wide A|B tiles in value_t along the K-axis
- BlockItemsK = block_task_policy_t::BlockItemsK,
-
- /// Extent of warp C-tile in accum_t (and A-tiles in value_t) along M-axis (height)
- WarpItemsY = block_task_policy_t::WarpItemsY,
-
- /// Extent of warp C-tile in accum_t (and B-tiles in value_t) along N-axis (width)
- WarpItemsX = block_task_policy_t::WarpItemsX,
-
- /// Extent of a discrete WMMA block's C-tile in accum_t (and A-tiles in value_t) along M-axis (height)
- WmmaItemsY = block_task_policy_t::WmmaItemsY,
-
- /// Extent of a discrete WMMA block's C-tile in accum_t (and B-tiles in value_t) along N-axis (width)
- WmmaItemsX = block_task_policy_t::WmmaItemsX,
-
- /// Extent of a discrete WMMA block's A|B-tiles in value_t along the K-axis
- WmmaItemsK = block_task_policy_t::WmmaItemsK,
-
- /// Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
- UseDoubleScratchTiles = block_task_policy_t::UseDoubleScratchTiles,
-
- /// Number of threads in warp
- WarpThreads = block_task_policy_t::WarpThreads,
-
- /// Number of warps participating
- BlockWarps = block_task_policy_t::BlockWarps,
-
- /// Extent of block in warps along M-axis
- BlockWarpsY = divide_assert<BlockItemsY, WarpItemsY>::value,
-
- /// Extent of block in warps along N-axis
- BlockWarpsX = divide_assert<BlockItemsX, WarpItemsX>::value,
-
- /// Number of MMA unrolls
- WmmaUnrollCount = divide_assert<BlockItemsK, WmmaItemsK>::value,
-
- /// True if the A matrix layout is column major (K is the strided dimension)
- IsLayoutCongruousA = (TransformA == matrix_transform_t::NonTranspose),
-
- /// True if the B matrix layout is row major (K is the strided dimension)
- IsLayoutCongruousB = (TransformB == matrix_transform_t::Transpose),
-
- };
-
- /// WMMA may support unique types for A and B, so plan ahead for this
- typedef value_t value_a_t;
-
- /// WMMA may support unique types for A and B, so plan ahead for this
- typedef value_t value_b_t;
-
- /// WMMA accumulator type
- typedef wmma_accumulator<
- WarpItemsY,
- WarpItemsX,
- WmmaItemsY,
- WmmaItemsX,
- WmmaItemsK,
- value_a_t,
- value_b_t,
- accum_t,
- TransformA,
- TransformB>
- accumulator_t;
-
- /// Thread block rasterization helper type
- typedef grid_raster<
- BlockItemsY,
- BlockItemsX,
- TransformA,
- TransformB,
- block_task_policy_t::RasterStrategy>
- grid_raster_t;
-
- /// Tile loader type for matrix A
- typedef block_loader_wmma<
- IsLayoutCongruousA,
- BlockThreads,
- (IsLayoutCongruousA ? BlockItemsY : BlockItemsK),
- (IsLayoutCongruousA ? BlockItemsK : BlockItemsY),
- value_a_t,
- LdgAlignA,
- AllowRaggedTiles>
- block_loader_a_t;
-
- /// Tile loader type for matrix A
- typedef block_loader_wmma<
- IsLayoutCongruousB,
- BlockThreads,
- (IsLayoutCongruousB ? BlockItemsX : BlockItemsK),
- (IsLayoutCongruousB ? BlockItemsK : BlockItemsX),
- value_b_t,
- LdgAlignB,
- AllowRaggedTiles>
- block_loader_b_t;
-
- /// Type alias for matrix A fragment type
- typedef typename accumulator_t::fragment_a_t fragment_a_t;
-
- /// Type alias for matrix B fragment type
- typedef typename accumulator_t::fragment_b_t fragment_b_t;
-
- enum
- {
- /// Number of fragments from A matrix
- WmmaBlocksY = accumulator_t::WmmaBlocksY,
-
- /// Number of fragments from B matrix
- WmmaBlocksX = accumulator_t::WmmaBlocksX,
-
- /// Number of value_t to pad the outer dimension of the shared A-tile
- PadItemsA = 16,
-
- /// Number of value_t to pad the outer dimension of the shared B-tile
- PadItemsB = 16,
-
- /// Leading dimension of A matrix tile
- LdmSmemA = (IsLayoutCongruousA ? BlockItemsY : BlockItemsK) + PadItemsA,
-
- /// Strided dimension of A matrix tile
- StridedSmemA = (IsLayoutCongruousA ? BlockItemsK : BlockItemsY),
-
- /// Leading dimension of B matrix tile
- LdmSmemB = (IsLayoutCongruousB ? BlockItemsX : BlockItemsK) + PadItemsB,
-
- StridedSmemB = (IsLayoutCongruousB ? BlockItemsK : BlockItemsX),
- };
-
- /// Shared memory layout for a prefetch page
- struct page_storage_t
- {
- /// Tile of A
- value_a_t __align__(16) block_a[StridedSmemA][LdmSmemA];
-
- /// Tile of B
- value_b_t __align__(16) block_b[StridedSmemB][LdmSmemB];
- };
-
- /// Shared memory layout for scratch storage
- struct scratch_storage_t
- {
- union
- {
- /// Prefetch pages
- uninitialized<page_storage_t> pages[UseDoubleScratchTiles ? 2 : 1];
-
- /// Scratch storage for warps
- accum_t epilogue[BlockWarps][WmmaItemsX * WmmaItemsY];
- };
- };
-
- //-------------------------------------------------------------------------
- // Assert assumptions
- //-------------------------------------------------------------------------
-
- // Ensure we have at least two unrolled innermost loop iterations (one to prefetch
- // the next global tile and then one to prefetch the first strip of it from shared)
- static_assert ((BlockItemsK >= 2), "BlockItemsK must be >= 2.");
-
- //-------------------------------------------------------------------------
- // Members
- //-------------------------------------------------------------------------
-
- /// Scratch storage reference
- scratch_storage_t *scratch;
-
- /// Which page of scratch tiles we're currently reading from
- int page_idx;
-
- /// Pointer to matrix C
- accum_t *d_c;
-
- /// Epilogue operation applied to update matrix C
- epilogue_op_t epilogue_op;
-
- /// Matrix height in rows of trans_op(A) and C
- int dim_m;
-
- /// Matrix width in columns of trans_op(B) and C
- int dim_n;
-
- /// Control for inter-block k-splitting
- k_split_control k_split;
-
- /// Thread block's base value_t coordinates (m, n) in matrix C
- grid_raster_t grid_raster;
-
- /// Thread block's current coordinate (k) within A|B matrices
- int block_item_coords_k;
-
- /// Thread block's ending coordinate (k) within A|B matrices (one-past)
- int block_end_item_k;
-
- /// Warp's coordinates (x, y) in thread block
- int2 block_warp_item_coords;
-
- /// A tile loader
- block_loader_a_t loader_a;
-
- /// B tile loader
- block_loader_b_t loader_b;
-
- /// Thread's active-k/prefetch-k slices from shared A tile
- fragment_a_t local_slices_a[2][WmmaBlocksY];
-
- /// Thread's active-k/prefetch-k slices from shared B tile
- fragment_b_t local_slices_b[2][WmmaBlocksX];
-
- /// Accumulator tile
- accumulator_t accumulator;
-
-
- //-------------------------------------------------------------------------
- // Coordinate system helpers
- //-------------------------------------------------------------------------
-
- /// Compute the warp's item-coordinates (x, y) in thread block
- inline __device__
- int2 warp_item_coords()
- {
- int warp_id = threadIdx.x / WarpThreads;
-
- return make_int2(
- (warp_id / BlockWarpsY) * WarpItemsX,
- (warp_id % BlockWarpsY) * WarpItemsY);
- }
-
- /// Compute the thread block's base item-coordinates in matrix A
- inline __device__
- int2 a_block_item_coords()
- {
- if (TransformA == matrix_transform_t::NonTranspose)
- {
- return make_int2(grid_raster.block_item_coords.y, block_item_coords_k);
- }
- else
- {
- return make_int2(block_item_coords_k, grid_raster.block_item_coords.y);
- }
- }
-
- /// Compute the thread block's base item-coordinates in matrix B
- inline __device__
- int2 b_block_item_coords()
- {
- if (TransformB == matrix_transform_t::Transpose)
- {
- return make_int2(grid_raster.block_item_coords.x, block_item_coords_k);
- }
- else
- {
- return make_int2(block_item_coords_k, grid_raster.block_item_coords.x);
- }
- }
-
- //-------------------------------------------------------------------------
- // Constructor API
- //-------------------------------------------------------------------------
-
- /// Constructor
- inline __device__
- block_task_wmma(
- scratch_storage_t *scratch,
- value_t *d_a,
- value_t *d_b,
- accum_t *d_c,
- epilogue_op_t epilogue_op,
- int dim_m,
- int dim_n,
- int dim_k,
- k_split_control k_split)
- :
- scratch(scratch),
- page_idx(0),
- d_c(d_c),
- epilogue_op(epilogue_op),
- dim_m(dim_m),
- dim_n(dim_n),
- k_split(k_split),
- block_item_coords_k(k_split.block_begin_item_k()),
- block_end_item_k(k_split.block_end_item_k(dim_k)),
- block_warp_item_coords(warp_item_coords()),
-
- loader_a(
- reinterpret_cast(d_a),
- (IsLayoutCongruousA ? dim_m : block_end_item_k),
- (IsLayoutCongruousA ? 0 : block_item_coords_k),
- (IsLayoutCongruousA ? block_end_item_k : dim_m),
- (IsLayoutCongruousA ? dim_m : dim_k),
- (IsLayoutCongruousA ? block_item_coords_k : 0),
- a_block_item_coords()),
-
- loader_b(
- reinterpret_cast(d_b),
- (IsLayoutCongruousB ? dim_n : block_end_item_k),
- (IsLayoutCongruousB ? 0 : block_item_coords_k),
- (IsLayoutCongruousB ? block_end_item_k : dim_n),
- (IsLayoutCongruousB ? dim_n : dim_k),
- (IsLayoutCongruousB ? block_item_coords_k : 0),
- b_block_item_coords())
- {}
-
-
- //-------------------------------------------------------------------------
- // Prefetching utility methods
- //-------------------------------------------------------------------------
-
- /**
- * Request the calling thread's slices of the shared tiles at depth \p tile_offset_k
- */
- inline __device__ void request_local_prefetch(
- fragment_a_t local_slices_a[WmmaBlocksY], ///< Slice from A
- fragment_b_t local_slices_b[WmmaBlocksX], ///< Slice from B
- int tile_offset_k)
- {
- value_b_t const *smem_A_base = &scratch->pages[page_idx].alias().block_a[0][0];
- value_b_t const *smem_B_base = &scratch->pages[page_idx].alias().block_b[0][0];
-
- int constexpr kstride_a = (IsLayoutCongruousA ? LdmSmemA : 1);
- int constexpr lstride_a = (IsLayoutCongruousA ? 1 : LdmSmemA);
-
- int constexpr kstride_b = (IsLayoutCongruousB ? LdmSmemB : 1);
- int constexpr lstride_b = (IsLayoutCongruousB ? 1 : LdmSmemB);
-
- // Load B strip
- #pragma unroll
- for (int i = 0; i < WmmaBlocksX; ++i)
- {
- value_b_t const *smem_B_ptr =
- &smem_B_base[tile_offset_k * kstride_b + (block_warp_item_coords.x + WmmaItemsX * i) * lstride_b];
-
- nvcuda::wmma::load_matrix_sync(local_slices_b[i], smem_B_ptr, LdmSmemB);
- }
-
- // Load A strip
- #pragma unroll
- for (int i = 0; i < WmmaBlocksY; ++i)
- {
- value_a_t const *smem_A_ptr =
- &smem_A_base[tile_offset_k * kstride_a + (block_warp_item_coords.y + WmmaItemsY * i) * lstride_a];
-
- nvcuda::wmma::load_matrix_sync(local_slices_a[i], smem_A_ptr, LdmSmemA);
- }
- }
-
-
- //-------------------------------------------------------------------------
- // Epilogue
- //-------------------------------------------------------------------------
-
- /**
- * Performs the GEMM epilogue:
- * - Applies the scalar multipliers and addends to the accumulators
- * - Write the result to the output matrix
- */
- inline __device__ void epilogue()
- {
- // Wait for predecessor thread block(s) to produce partial-sums
- k_split.wait();
-
- // Configure epilogue as to whether the thread block is a secondary
- // accumulator in an inter-block k-splitting scheme
- if (k_split.is_secondary_accumulator())
- epilogue_op.set_secondary_accumulator();
-
- // Whether or not the addend from C needs loading
- bool must_init_addend = epilogue_op.must_init_addend();
-
- int warp_base_x = grid_raster.block_item_coords.x + block_warp_item_coords.x;
- int warp_base_y = grid_raster.block_item_coords.y + block_warp_item_coords.y;
-
- int constexpr SmemStride = WmmaItemsY;
-
- int warp_id = threadIdx.x / 32;
-
- // Compute shape of one accumulator read/modify/write operation
- int constexpr ItemsY = (WmmaItemsY);
- int constexpr ItemsX = (32 / ItemsY);
- int constexpr IterationsX = WmmaItemsX / ItemsX;
-
- // Compute a rasterization of warp lanes across the WMMA tile.
- int lane_id = (threadIdx.x % 32);
- int lane_read_x = (lane_id / ItemsY);
- int lane_read_y = (lane_id % ItemsY);
-
- accum_t *smem_scratch = scratch->epilogue[warp_id];
- accum_t const *smem_read_ptr = smem_scratch + lane_read_y + lane_read_x * SmemStride;
-
- #pragma unroll
- for (int xb = 0; xb < WmmaBlocksX; ++xb)
- {
- #pragma unroll
- for (int yb = 0; yb < WmmaBlocksY; ++yb)
- {
- // Store accumulator tile to SMEM
- nvcuda::wmma::store_matrix_sync(
- smem_scratch,
- accumulator.accumulators[xb][yb],
- SmemStride,
- matrix_layout::kind);
-
- // Synchronize threads within the warp
- __syncthreads();
-
- // Compute lane coordinates so that each thread efficiently accesses SMEM.
- int c_x = (warp_base_x + (xb) * WmmaItemsX + lane_read_x);
- int c_y = (warp_base_y + (yb) * WmmaItemsY + lane_read_y);
-
- // Compute guard predicate by comparing against problem dimensions.
- bool pred = c_y < dim_m;
-
- // Compute output pointer from lane coordinates
- int c_index = c_x * dim_m + c_y;
- accum_t *c_ptr = reinterpret_cast(d_c) + c_x * dim_m + c_y;
-
- // Iterate over columns of output tile. Load from SMEM, compute epilogue operation,
- // and stream output to global memory
- #pragma unroll
- for (int item_x = 0; item_x < IterationsX; ++item_x)
- {
- accum_t accum = smem_read_ptr[item_x * ItemsX * SmemStride];
- accum_t c_element = 0;
-
- // Filter against problem dimensions as the warp iterates across the columns of
- // output.
- pred = (pred && ((c_x + item_x * ItemsX) < dim_n));
-
- if (must_init_addend && pred)
- {
- // NB: inline PTX to utilize strong operations for inter-block synchronization.
- // The following is equivalent to:
- //
- // c_element = c_ptr[0];
- asm volatile ("ld.global.cg.f32 %0, [%1];\n" : "=f"(c_element) : "l"(c_ptr));
- }
-
- c_element = epilogue_op(accum, c_element, c_index);
-
- if (pred)
- {
- // NB: inline PTX to utilize strong operations for inter-block synchronization.
- // The following is equivalent to:
- //
- // c_ptr[0] = c_element;
-
- asm volatile ("st.global.cg.f32 [%0], %1;\n" : : "l"(c_ptr), "f"(c_element));
- }
-
- // Increment output pointer
- c_ptr += dim_m * ItemsX;
- c_index += dim_m * ItemsX;
- }
- __syncthreads();
- }
- }
-
- // Signal k-split successor thread_block
- k_split.signal();
- }
-
- //-------------------------------------------------------------------------
- // Tile consumption
- //-------------------------------------------------------------------------
-
- /**
- * Consume a tile of A and B each
- */
- template <bool DoGlobalPrefetch>
- inline __device__
- void consume_tile()
- {
- // Request global prefetch for next tile on first strip
- if (DoGlobalPrefetch)
- {
- loader_b.request();
- loader_b.next();
- loader_a.request();
- loader_a.next();
- }
-
- // Unroll WmmaUnrollCount iterations of outer-product accumulations
- #pragma unroll
- for (int iteration = 0; iteration < WmmaUnrollCount; ++iteration)
- {
- int tile_offset_k = iteration * WmmaItemsK;
-
- // Active load-from-shared index
- int active_lds_idx = __NV_STD_MIN(WmmaUnrollCount - 1, (iteration) % 2);
-
- // Next load-from-shared index
- int next_lds_idx = __NV_STD_MIN(WmmaUnrollCount - 1, (iteration + 1) % 2);
-
- // The last unrolled iteration commits the global fetches
- if ((iteration == WmmaUnrollCount - 1) && DoGlobalPrefetch)
- {
- // If not using two pages of scratch tiles, protect the above prefetch loads from
- // the committing writes below
- if (!UseDoubleScratchTiles)
- {
- __syncthreads();
- }
- else
- {
- page_idx = (page_idx ? 0 : 1);
- }
-
- // Commit global prefetch data to scratch page
- loader_a.template commit(&scratch->pages[page_idx].alias().block_a[0][0]);
- loader_b.template commit(&scratch->pages[page_idx].alias().block_b[0][0]);
-
- __syncthreads();
- }
-
- // Accumulate this dp-stripe product
- accumulator.multiply_accumulate(
- local_slices_a[active_lds_idx],
- local_slices_b[active_lds_idx]);
-
- // Request local prefetch for next strip
- request_local_prefetch(
- local_slices_a[next_lds_idx],
- local_slices_b[next_lds_idx],
- (tile_offset_k + WmmaItemsK) % BlockItemsK);
- }
- }
-
- //-------------------------------------------------------------------------
- // GEMM API
- //-------------------------------------------------------------------------
-
- /**
- * Compute GEMM
- */
- inline __device__
- void run()
- {
- // Quit if the thread block is fully out-of-bounds
- if (grid_raster.is_block_oob(dim_m, dim_n))
- {
- asm volatile("exit;");
- }
-
- // Request global prefetch of first tile
- loader_a.request();
- loader_a.next();
- loader_b.request();
- loader_b.next();
-
- // Commit global prefetch of first tile to shared memory
- loader_a.template commit(&scratch->pages[page_idx].alias().block_a[0][0]);
- loader_b.template commit(&scratch->pages[page_idx].alias().block_b[0][0]);
-
- // Advance to next A,B tiles in K-axis
- block_item_coords_k += BlockItemsK;
-
- // Synchronize shared tiles and prepared accumulator
- __syncthreads();
-
- // Initialize thread's slice of accumulators
- accumulator.init();
-
- // Request first iteration of local prefetch strips
- request_local_prefetch(
- local_slices_a[0],
- local_slices_b[0],
- 0);
-
- //
- // Main loop
- //
-
- // Consume tiles in A and B along the K-axis (all but last tile)
- #pragma unroll 1
- while (block_item_coords_k < block_end_item_k)
- {
- consume_tile<true>();
-
- // Advance to next A,B tiles in K-axis
- block_item_coords_k += BlockItemsK;
- }
-
- consume_tile<false>();
-
- //
- // Epilogue
- //
-
- // prevent overwriting SMEM until all warps have finished loading data
- __syncthreads();
-
- // store accumulator tile to global memory
- epilogue();
- }
-};
-
-} // namespace gemm
-} // namespace cutlass
-
-#endif
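
Editorial note: taken together, run() and consume_tile() in the deleted file above implement a two-slot software pipeline; while the warp multiplies the strip fetched on the previous step, the next strip is loaded from shared memory, and while a shared tile is being consumed the next global tile is prefetched and committed behind a barrier. The buffer rotation reduces to the index pattern sketched below (a standalone host-side illustration, not CUTLASS code):

    #include <cstdio>
    #include <vector>

    int main()
    {
        // Two-slot rotation, mirroring local_slices_[ab][2][...] above: slot i % 2 is
        // consumed while slot (i + 1) % 2 is refilled for the next iteration.
        std::vector<int> input = {1, 2, 3, 4, 5, 6, 7, 8};
        int buffer[2] = {0, 0};
        int sum = 0;

        buffer[0] = input[0];  // prologue: prefetch the first element
        for (std::size_t i = 0; i < input.size(); ++i)
        {
            if (i + 1 < input.size())
            {
                buffer[(i + 1) % 2] = input[i + 1];  // prefetch next while current is live
            }
            sum += buffer[i % 2];                    // consume the active slot
        }
        std::printf("sum = %d\n", sum);              // prints 36
        return 0;
    }
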
diff --git a/cutlass/gemm/clear_accumulators.h b/cutlass/gemm/clear_accumulators.h
new file mode 100644
index 00000000..12e1f579
--- /dev/null
+++ b/cutlass/gemm/clear_accumulators.h
@@ -0,0 +1,55 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Defines abstractions for efficiently clearing accumulator tiles.
+*/
+#pragma once
+
+#include <cutlass/vector.h>
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Scalar_, int kLanes_ = 1>
+struct ClearAccumulators {
+ /// The shared storage.
+ struct SharedStorage {};
+
+ /// Ctor.
+ CUTLASS_DEVICE ClearAccumulators(SharedStorage& shared_storage) {}
+
+ /// Clear the fragment.
+ template <typename Fragment_>
+ CUTLASS_DEVICE void clear(Fragment_& fragment) {
+ fragment.clear();
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace gemm
+} // namespace cutlass
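
Editorial note: a brief usage sketch of the new ClearAccumulators helper follows. The Fragment type, its element count, and the single explicit <Scalar_> argument are assumptions based on the surrounding CUTLASS 1.0 headers (cutlass/fragment.h is not part of this hunk):

    // Hypothetical device-side usage (sketch only; names outside this hunk are assumed).
    __device__ void zero_accumulators()
    {
        __shared__ typename cutlass::gemm::ClearAccumulators<float>::SharedStorage storage;

        cutlass::gemm::ClearAccumulators<float> clearer(storage);
        cutlass::Fragment<float, 8> accum;  // assumed accumulator fragment type
        clearer.clear(accum);               // forwards to Fragment::clear(), zeroing every element
    }
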
diff --git a/cutlass/gemm/dgemm_traits.h b/cutlass/gemm/dgemm_traits.h
new file mode 100644
index 00000000..0bbc2210
--- /dev/null
+++ b/cutlass/gemm/dgemm_traits.h
@@ -0,0 +1,127 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Defines structural traits of double-precision GEMM.
+*/
+#pragma once
+
+#include <cutlass/gemm/gemm.h>
+#include <cutlass/gemm/gemm_epilogue.h>
+#include <cutlass/gemm/gemm_epilogue_traits.h>
+#include <cutlass/gemm/gemm_global_tile.h>
+#include <cutlass/gemm/gemm_shared_tile.h>
+#include <cutlass/gemm/gemm_traits.h>
+#include <cutlass/gemm/thread_multiply_add.h>
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+ /// The tile size for the GEMM KxNxM.
+ typename OutputTile_,
+ /// The number of accumulators per thread.
+ typename AccumulatorsPerThread_,
+ /// The number of scalars per LDG for A.
+ int kScalarsPerLdgA_ = 1,
+ /// The number of scalars per LDG for B.
+ int kScalarsPerLdgB_ = 1>
+struct DgemmConfig
+ : public GemmConfig<
+ /// The scalar type for A.
+ double,
+ /// The scalar type for B.
+ double,
+ /// The scalar type for C.
+ double,
+ /// The scalar type for D.
+ double,
+ /// The tile size for the GEMM KxNxM.
+ OutputTile_,
+ /// The functor to do the math in the main loop.
+ ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, double, double, double>,
+ /// The number of scalars per LDG for A.
+ kScalarsPerLdgA_,
+ /// The number of scalars per STS for A.
+ kScalarsPerLdgA_,
+ /// The number of scalars per LDS for A.
+ 2,
+ /// The number of scalars per LDG for B.
+ kScalarsPerLdgB_,
+ /// The number of scalars per STS for B.
+ kScalarsPerLdgB_,
+ /// The number of scalars per LDS for B.
+ 2,
+ /// The number of scalars per LDG for C and STG for D.
+ 1,
+ /// The number of scalars per STS for D.
+ 2,
+ /// The number of scalars per LDS for D.
+ 1,
+ /// The number of stages in shared memory.
+ 2> {};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+ /// The layout for A.
+ MatrixLayout::Kind kLayoutA_,
+ /// The layout for B.
+ MatrixLayout::Kind kLayoutB_,
+ /// The output tile.
+ typename OutputTile_ = Shape<8, 64, 128>,
+ /// The functor to use in the epilogue.
+ typename EpilogueFunctor_ = LinearScaling<double>,
+ /// The number of accumulators per thread.
+ typename AccumulatorsPerThread_ = Shape<8, 8, 8>,
+ /// The number of doubles loaded in one LDG for A.
+ int kScalarsPerLdgA_ = 1,
+ /// The number of doubles loaded in one LDG for B.
+ int kScalarsPerLdgB_ = 1,
+ /// The index.
+ typename Index_ = int,
+ /// The DGEMM config.
+ typename GemmConfig_ =
+ DgemmConfig<OutputTile_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_>,
+ /// The traits class for the epilogue.
+ typename GemmEpilogueTraits_ =
+ SimplifiedGemmEpilogueTraits<GemmConfig_, EpilogueFunctor_, Index_> >
+struct DgemmTraits : public SimplifiedGemmTraits<
+ // The layout for A.
+ kLayoutA_,
+ // The layout for B.
+ kLayoutB_,
+ // The config.
+ GemmConfig_,
+ // The epilogue.
+ GemmEpilogue<GemmEpilogueTraits_>,
+ // The index.
+ Index_> {};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace gemm
+} // namespace cutlass
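
Editorial note: as a usage sketch, a column-major double-precision GEMM would be instantiated from these traits roughly as shown below. The Gemm<Traits> entry point, Params::initialize(), launch(), and the MatrixLayout enumerators are assumed from cutlass/gemm/gemm.h and cutlass/matrix_traits.h, which this hunk does not show:

    // Hypothetical host-side driver (sketch only; API outside this hunk is assumed).
    typedef cutlass::gemm::DgemmTraits<cutlass::MatrixLayout::kColumnMajor,
                                       cutlass::MatrixLayout::kColumnMajor> Traits;
    typedef cutlass::gemm::Gemm<Traits> Gemm;

    Gemm::Params params;
    params.initialize(m, n, k, alpha, d_a, lda, d_b, ldb, beta, d_c, ldc, d_c, ldc);
    Gemm::launch(params);
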
diff --git a/cutlass/gemm/dispatch.h b/cutlass/gemm/dispatch.h
deleted file mode 100644
index ae1819a6..00000000
--- a/cutlass/gemm/dispatch.h
+++ /dev/null
@@ -1,542 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- * * Redistributions of source code must retain the above copyright
- * notice, this list of conditions and the following disclaimer.
- * * Redistributions in binary form must reproduce the above copyright
- * notice, this list of conditions and the following disclaimer in the
- * documentation and/or other materials provided with the distribution.
- * * Neither the name of the NVIDIA CORPORATION nor the
- * names of its contributors may be used to endorse or promote products
- * derived from this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- ******************************************************************************/
-
-#pragma once
-
-/**
- * \file
- * GEMM kernel entrypoint and dispatch stub
- */
-
-#include <stdint.h>
-
-#include "../util/util.h"
-#include "block_task.h"
-#include "block_task_wmma.h"
-#include "grid_raster.h"
-#include "dispatch_policies.h"
-#include "k_split_control.h"
-
-namespace cutlass {
-namespace gemm {
-
-
-/******************************************************************************
- * param_pack
- ******************************************************************************/
-
-/**
- * Parameter-pack structure
- *
- * Kernel launch latency is reduced when kernel arguments are wrapped into
- * a single parameter
- */
-template <
- typename value_t,
- typename accum_t,
- typename epilogue_op_t>
-struct param_pack
-{
- int m; ///< Height in rows of op(A) and C
- int n; ///< Width in columns of op(B) and C
- int k; ///< Width in columns of op(A) and height in rows of op(B)
- k_split_control k_split; ///< Abstraction for controlling inter-block k-splitting
- value_t *d_a; ///< Pointer to matrix A array values
- value_t *d_b; ///< Pointer to matrix B array values
- accum_t *d_c; ///< Pointer to matrix C array values
- epilogue_op_t epilogue_op;
-
- param_pack(
- int m, ///< Height in rows of op(A) and C
- int n, ///< Width in columns of op(B) and C
- int k, ///< Width in columns of op(A) and height in rows of op(B)
- k_split_control k_split, ///< Abstraction for controlling inter-block k-splitting
- epilogue_op_t op, ///< Epilogue operation to update matrix C
- value_t *d_a, ///< Pointer to matrix A array values
- value_t *d_b, ///< Pointer to matrix B array values
- accum_t *d_c) ///< Pointer to matrix C array values
- :
- m(m),
- n(n),
- k(k),
- k_split(k_split),
- epilogue_op(op),
- d_a(d_a),
- d_b(d_b),
- d_c(d_c)
- {}
-
-};
-
-
-/******************************************************************************
- * Conditionally select the appropriate GEMM threadblock task
- ******************************************************************************/
-
-/// Conditional selection for block task
-template <
- math_operation_class_t math_op, ///< Indicates which class of math operation to select
- typename block_task_policy_t, ///< Parameterization of block_task_policy
- typename value_t, ///< Multiplicand value type (matrices A and B)
- typename accum_t, ///< Accumulator value type (matrix C and scalars)
- matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
- int LdgAlignA, ///< Alignment (in bytes) for A operand
- matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B
- int LdgAlignB, ///< Alignment (in bytes) for B operand
- typename epilogue_op_t, ///< Epilogue operation applied to GEMM
- int LdgAlignC, ///< Alignment (in bytes) for C operand
- bool AllowRaggedTiles ///< Whether GEMM supports matrix sizes other than multiple of BlockItems{XY}
->
-struct gemm_block_task;
-
-/// Scalar math operations
-template <
- typename block_task_policy_t, ///< Parameterization of block_task_policy
- typename value_t, ///< Multiplicand value type (matrices A and B)
- typename accum_t, ///< Accumulator value type (matrix C and scalars)
- matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
- int LdgAlignA, ///< Alignment (in bytes) for A operand
- matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B
- int LdgAlignB, ///< Alignment (in bytes) for B operand
- typename epilogue_op_t, ///< Epilogue operation applied to GEMM
- int LdgAlignC, ///< Alignment (in bytes) for C operand
- bool AllowRaggedTiles ///< Whether GEMM supports matrix sizes other than multiple of BlockItems{XY}
->
-struct gemm_block_task<
- math_operation_class_t::scalar,
- block_task_policy_t,
- value_t,
- accum_t,
- TransformA,
- LdgAlignA,
- TransformB,
- LdgAlignB,
- epilogue_op_t,
- LdgAlignC,
- AllowRaggedTiles
->
-{
- // Parameterize task type
- typedef block_task<
- block_task_policy_t,
- value_t,
- accum_t,
- TransformA,
- LdgAlignA,
- TransformB,
- LdgAlignB,
- epilogue_op_t,
- LdgAlignC,
- AllowRaggedTiles> type;
-};
-
-/// Matrix math operations
-template <
- typename block_task_policy_t, ///< Parameterization of block_task_policy
- typename value_t, ///< Multiplicand value type (matrices A and B)
- typename accum_t, ///< Accumulator value type (matrix C and scalars)
- matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
- int LdgAlignA, ///< Alignment (in bytes) for A operand
- matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B
- int LdgAlignB, ///< Alignment (in bytes) for B operand
- typename epilogue_op_t, ///< Epilogue operation applied to GEMM
- int LdgAlignC, ///< Alignment (in bytes) for C operand
- bool AllowRaggedTiles ///< Whether GEMM supports matrix sizes other than multiple of BlockItems{XY}
->
-struct gemm_block_task<
- math_operation_class_t::matrix,
- block_task_policy_t,
- value_t,
- accum_t,
- TransformA,
- LdgAlignA,
- TransformB,
- LdgAlignB,
- epilogue_op_t,
- LdgAlignC,
- AllowRaggedTiles>
-{
-
-#if defined(WMMA) // conditional compilation with WMMA headers
-
- // Parameterize task type
- typedef block_task_wmma<
- block_task_policy_t,
- value_t,
- accum_t,
- TransformA,
- LdgAlignA,
- TransformB,
- LdgAlignB,
- epilogue_op_t,
- LdgAlignC,
- AllowRaggedTiles> type;
-
-#endif
-};
-
-/******************************************************************************
- * GEMM kernel entrypoint
- ******************************************************************************/
-
-/**
- * GEMM kernel
- *
- * NB: Not sure why NVVM is doing stuff with "__launch_bounds__" instead of just
- * passing it along to PTXAS, but it is currently resulting in less optimal codegen
- */
-template <
- math_operation_class_t math_op, ///< Indicates which class of math operation to select
- typename block_task_policy_t, ///< Parameterization of block_task_policy
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- int LdgAlignA, ///< Alignment of A matrix elements in bytes
- matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B
- int LdgAlignB, ///< Alignment of B matrix elements in bytes
- typename value_t, ///< Multiplicand value type (matrices A and B)
- typename accum_t, ///< Accumulator value type (matrix C and scalars)
- typename epilogue_op_t, ///< Epilogue operation applied to update matrix C
- int LdgAlignC, ///< Alignment of C elements in bytes
- bool AllowRaggedTiles> ///< Boolean to indicate whether AllowRaggedTiles handling is enabled
-__global__ void kernel(param_pack<value_t, accum_t, epilogue_op_t> pack)
-{
- // Parameterize task type
- typedef typename gemm_block_task<
- math_op,
- block_task_policy_t,
- value_t,
- accum_t,
- TransformA,
- LdgAlignA,
- TransformB,
- LdgAlignB,
- epilogue_op_t,
- LdgAlignC,
- AllowRaggedTiles>::type block_task_t;
-
- // Declare statically-allocated shared storage
- __shared__ typename block_task_t::scratch_storage_t smem;
-
- // Construct and run the task
- block_task_t(
- &smem,
- pack.d_a,
- pack.d_b,
- pack.d_c,
- pack.epilogue_op,
- pack.m,
- pack.n,
- pack.k,
- pack.k_split).run();
-}
-
-
-/******************************************************************************
- * Launch configuration description returned to the caller
- ******************************************************************************/
-
-/// Return details about the launch configuration to the caller
-struct launch_configuration
-{
- //
- // Data members
- //
-
- /// cudaError_t resulting from grid launch
- cudaError_t result;
-
- /// Extent of a thread block's partition along the GEMM K-axis
- int split_k;
-
- /// Kernel grid extents in thread blocks
- dim3 grid;
-
- /// Thread block extents in threads
- dim3 block;
-
- //
- // Methods
- //
-
- /// Constructor
- launch_configuration():
- result(cudaSuccess),
- split_k(0),
- grid(0, 0, 0),
- block(0, 0, 0) {
-
- }
-
- /// Conversion from cudaError_t
- launch_configuration(cudaError_t result):
- result(result),
- split_k(1),
- grid(0, 0, 0),
- block(0, 0, 0) {
-
- }
-
- /// Launch configuration for Cutlass kernels
- launch_configuration(
- cudaError_t result,
- int split_k,
- dim3 grid,
- dim3 block
- ):
- result(result),
- split_k(split_k),
- grid(grid),
- block(block) {
-
- }
-};
-
-
-/******************************************************************************
- * Dispatch stub
- ******************************************************************************/
-
-/**
- * GEMM dispatch stub
- *
- * This function also serves as the autotuning entrypoint to evaluate different
- * tuning parameterizations of kernel.
- */
-template <
- math_operation_class_t math_op, ///< Indicates which class of math operation to select
- typename block_task_policy_t, ///< Parameterization of block_task_policy
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- int LdgAlignA, ///< Alignment of A matrix elements in bytes
- matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B
- int LdgAlignB, ///< Alignment of B matrix elements in bytes
- typename value_t, ///< Multiplicand value type (matrices A and B)
- typename accum_t, ///< Accumulator value type (matrix C and scalars)
- typename epilogue_op_t, ///< Epilogue operation
- int LdgAlignC, ///< Alignment of C matrix elements in bytes
- bool AllowRaggedTiles, ///< Boolean to indicate whether AllowRaggedTiles handling is enabled
- typename kernel_ptr_t> ///< GEMM kernel function pointer type
-launch_configuration dispatch(
- kernel_ptr_t kernel_ptr, ///< GEMM kernel function pointer
- int m, ///< Height in rows of op(A) and C
- int n, ///< Width in columns of op(B) and C
- int k, ///< Width in columns of op(A) and height in rows of op(B)
- epilogue_op_t epilogue_op, ///< Epilogue operation to update matrix C
- value_t *d_a, ///< Device pointer to matrix A array values
- value_t *d_b, ///< Device pointer to matrix B array values
- accum_t *d_c, ///< Device pointer to matrix C array values
- cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream0.
- bool debug_synchronous = true) ///< Whether or not to synchronize the stream after every kernel launch
- /// to check for errors. Also causes launch configurations to be printed
- /// to the console if DEBUG is defined. Default is \p false.
-{
- // Thread block rasterization type
- typedef grid_raster<
- block_task_policy_t::BlockItemsY,
- block_task_policy_t::BlockItemsX,
- TransformA,
- TransformB,
- block_task_policy_t::RasterStrategy>
- grid_raster_t;
-
- launch_configuration config;
-
- // Compute block dims
- config.block = dim3(block_task_policy_t::BlockThreads);
-
- // Compute shared memory
- int dynamic_smem_bytes = 0;
-
- // Compute occupancy
- int max_sm_occupancy;
- if (CUDA_PERROR_DEBUG(config.result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
- &max_sm_occupancy,
- kernel_ptr,
- config.block.x * config.block.y,
- dynamic_smem_bytes)))
- {
- return config;
- }
-
- // Compute grid extents
- config.grid = grid_raster_t::grid_dims(m, n);
-
- // Get SM count
- int sm_count;
- if (CUDA_PERROR_DEBUG(config.result = get_sm_count(sm_count)))
- return config;
-
- // Get k-split flag storage (TODO: make a pool)
- int *d_flags;
- if (CUDA_PERROR_DEBUG(config.result = cudaGetSymbolAddress((void**) &d_flags, d_flags_split_k)))
- return config;
-
- // Construct k-split coordinator
- k_split_control k_split(
- d_flags,
- sm_count,
- max_sm_occupancy,
- k,
- block_task_policy_t::BlockItemsK,
- config.block,
- config.grid); // in,out
-
- config.split_k = k_split.split_k;
-
- // Log kernel configuration
- if (debug_synchronous)
- {
- // Compute tiling efficiency
- float block_tiling_efficiency = float(block_task_policy_t::BlockItemsY * block_task_policy_t::BlockItemsX) /
- float(block_task_policy_t::BlockItemsY + block_task_policy_t::BlockItemsX);
-
- float tiling_efficiency = block_tiling_efficiency;
-
- float wave_efficiency = k_split.get_wave_efficiency(
- sm_count, max_sm_occupancy, config.block, config.grid);
-
- CUDA_LOG_DEBUG("Final wave_efficiency %.4f, tiling_efficiency %.4f\n",
- wave_efficiency, tiling_efficiency);
-
- CUDA_LOG_DEBUG("Invoking kernel<<<(%d, %d, %d), (%d.y,%d.x), %d, %lld>>>(), %d SM occupancy, %d split_k\n",
- config.grid.x, config.grid.y, config.grid.z,
- config.block.y, config.block.x,
- dynamic_smem_bytes,
- (long long) stream,
- max_sm_occupancy,
- k_split.split_k);
- }
-
- // Construct parameter-pack
- param_pack<value_t, accum_t, epilogue_op_t> pack(
- m,
- n,
- k,
- k_split,
- epilogue_op,
- d_a,
- d_b,
- d_c);
-
- // Prepare k-split coordinator
- if (CUDA_PERROR_DEBUG(config.result = k_split.prepare(stream, debug_synchronous)))
- {
- return config;
- }
-
- // Invoke kernel
- kernel_ptr<<< config.grid, config.block, dynamic_smem_bytes, stream >>>(pack);
-
- // Check for failure to launch
- if (CUDA_PERROR_DEBUG(config.result = cudaPeekAtLastError()))
- return config;
-
- // Sync the stream if specified to flush runtime errors
- if (debug_synchronous && (CUDA_PERROR_DEBUG(config.result = cudaStreamSynchronize(stream))))
- return config;
-
- return config;
-}
-
-
-/******************************************************************************
- * GEMM
- ******************************************************************************/
-
-/**
- * Computes gemm on device matrices
- */
-template <
- tiling_strategy::kind_t TilingStrategy, ///< Tile-sizing classification
- math_operation_class_t math_op, ///< Indicates which class of math operation to select
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- int LdgAlignA, ///< Alignment (in bytes) of A operand
- matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B
- int LdgAlignB, ///< Alignment (in bytes) of B operand
- typename value_t, ///< Multiplicand value type (matrices A and B)
- typename accum_t, ///< Accumulator value type (matrix C and scalars)
- typename epilogue_op_t, ///< Epilogue operation to update matrix C
- int LdgAlignC> ///< Alignment (in bytes) of C operand
-launch_configuration device_gemm(
- int m, ///< Height in rows of op(A) and C
- int n, ///< Width in columns of op(B) and C
- int k, ///< Width in columns of op(A) and height in rows of op(B)
- epilogue_op_t epilogue_op, ///< Epilogue operation to update matrix C
- value_t *d_a, ///< Device pointer to matrix A array values
- value_t *d_b, ///< Device pointer to matrix B array values
- accum_t *d_c, ///< Device pointer to matrix C array values
- cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream0.
- bool debug_synchronous = false) ///< Whether or not to synchronize the stream after every kernel launch to
- /// check for errors. Also causes launch configurations to be printed to
- /// the console if DEBUG is defined. Default is \p false.
-{
- // Parameterize a task policy type
- // (TODO: use a policy dispatch mechanism based upon SM version)
- typedef gemm_policy<value_t, accum_t, TransformA, TransformB, TilingStrategy> block_task_policy_t;
-
- // Ragged-tile check
- if ((m % block_task_policy_t::BlockItemsY != 0) ||
- (n % block_task_policy_t::BlockItemsX != 0) ||
- (k % block_task_policy_t::BlockItemsK != 0))
- {
- // Needs ragged tile-handling
- static const bool AllowRaggedTiles = true;
-
- return dispatch<math_op, block_task_policy_t, TransformA, LdgAlignA, TransformB, LdgAlignB, value_t, accum_t, epilogue_op_t, LdgAlignC, AllowRaggedTiles>(
- kernel<math_op, block_task_policy_t, TransformA, LdgAlignA, TransformB, LdgAlignB, value_t, accum_t, epilogue_op_t, LdgAlignC, AllowRaggedTiles>,
- m,
- n,
- k,
- epilogue_op,
- d_a,
- d_b,
- d_c,
- stream,
- debug_synchronous);
- }
- else
- {
- // Does not need ragged tile-handling
- static const bool AllowRaggedTiles = false;
-
- return dispatch<math_op, block_task_policy_t, TransformA, LdgAlignA, TransformB, LdgAlignB, value_t, accum_t, epilogue_op_t, LdgAlignC, AllowRaggedTiles>(
- kernel<math_op, block_task_policy_t, TransformA, LdgAlignA, TransformB, LdgAlignB, value_t, accum_t, epilogue_op_t, LdgAlignC, AllowRaggedTiles>,
- m,
- n,
- k,
- epilogue_op,
- d_a,
- d_b,
- d_c,
- stream,
- debug_synchronous);
- }
-
-
-}
-
-
-} // namespace gemm
-} // namespace cutlass
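
Editorial note: for reference, the removed device_gemm entry point above was driven roughly as sketched below. `my_epilogue_t` stands in for whatever epilogue functor the caller supplies (the alpha/beta scaling functor lived in the also-removed epilogue_function.h), and the 16-byte alignments are illustrative, not prescribed by this diff:

    // Hypothetical call into the removed dispatch path (sketch only).
    my_epilogue_t epilogue(alpha, beta);

    cutlass::gemm::launch_configuration config =
        cutlass::gemm::device_gemm<
            cutlass::gemm::tiling_strategy::Large,            // tile-sizing class
            cutlass::gemm::math_operation_class_t::scalar,    // FFMA/DFMA path (not WMMA)
            cutlass::gemm::matrix_transform_t::NonTranspose, 16,
            cutlass::gemm::matrix_transform_t::NonTranspose, 16,
            float,                                            // value_t
            float,                                            // accum_t
            my_epilogue_t,                                    // epilogue operation
            16>(                                              // C alignment (bytes)
        m, n, k, epilogue, d_a, d_b, d_c, stream);

    if (config.result != cudaSuccess)
    {
        // handle or report the launch error
    }
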
diff --git a/cutlass/gemm/dispatch_policies.h b/cutlass/gemm/dispatch_policies.h
deleted file mode 100644
index 298f7be5..00000000
--- a/cutlass/gemm/dispatch_policies.h
+++ /dev/null
@@ -1,661 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- * * Redistributions of source code must retain the above copyright
- * notice, this list of conditions and the following disclaimer.
- * * Redistributions in binary form must reproduce the above copyright
- * notice, this list of conditions and the following disclaimer in the
- * documentation and/or other materials provided with the distribution.
- * * Neither the name of the NVIDIA CORPORATION nor the
- * names of its contributors may be used to endorse or promote products
- * derived from this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- ******************************************************************************/
-
-#pragma once
-
-/**
- * \file
- * Architecture-specific GEMM block_task policies
- */
-
-#include <stdint.h>
-
-#include "../util/util.h"
-#include "block_task.h"
-#include "grid_raster.h"
-
-namespace cutlass {
-namespace gemm {
-
-
-/******************************************************************************
- * tiling_strategy
- ******************************************************************************/
-
-/**
- * Enumeration of tile-sizing granularities
- */
-struct tiling_strategy : printable_t
-{
- /// \brief Enumerants
- enum kind_t
- {
- Unknown,
- Small,
- Medium,
- Large,
- Tall,
- Wide,
- Huge,
- };
-
- /// Enumerant value
- kind_t kind;
-
- /// Default constructor
- tiling_strategy() : kind(Unknown) {}
-
- /// Copy constructor
- tiling_strategy(const kind_t &other_kind) : kind(other_kind) {}
-
- /// Cast to kind_t
- operator kind_t() const { return kind; }
-
- /// Returns the instance as a string
- __host__ __device__ inline
- char const* to_string() const
- {
- switch (kind)
- {
- case Small: return "small";
- case Medium: return "medium";
- case Large: return "large";
- case Tall: return "tall";
- case Wide: return "wide";
- case Huge: return "huge";
- case Unknown:
- default: return "unknown";
- }
- }
-
- /// Insert the formatted instance into the output stream
- void print(std::ostream& out) const { out << to_string(); }
-};
-
-
-/******************************************************************************
- * GEMM
- ******************************************************************************/
-
-/**
- * GEMM task policy (primary template; specialized below by value type, accumulator type, and tiling strategy)
- */
-template <
- typename value_t,
- typename accum_t,
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B
- tiling_strategy::kind_t TilingStrategy> ///< Tile-sizing classification
-struct gemm_policy;
-
-
-/******************************************************************************
- * SGEMM
- ******************************************************************************/
-
-/**
- * GEMM task policy specialization for small sgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Small> :
- block_task_policy<
- 16, // _BlockItemsY
- 16, // _BlockItemsX
- 16, // _BlockItemsK
- 2, // _ThreadItemsY
- 2, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-
-/**
- * GEMM task policy specialization for medium sgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Medium> :
- block_task_policy<
- 32, // _BlockItemsY
- 32, // _BlockItemsX
- 8, // _BlockItemsK
- 4, // _ThreadItemsY
- 4, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-/**
- * GEMM task policy specialization for large sgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Large> :
- block_task_policy<
- 64, // _BlockItemsY
- 64, // _BlockItemsX
- 8, // _BlockItemsK
- 8, // _ThreadItemsY
- 8, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-/**
- * GEMM task policy specialization for tall sgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Tall> :
- block_task_policy<
- 128, // _BlockItemsY
- 32, // _BlockItemsX
- 8, // _BlockItemsK
- 8, // _ThreadItemsY
- 4, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-/**
- * GEMM task policy specialization for wide sgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Wide> :
- block_task_policy<
- 32, // _BlockItemsY
- 128, // _BlockItemsX
- 8, // _BlockItemsK
- 4, // _ThreadItemsY
- 8, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-/**
- * GEMM task policy specialization for huge sgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Huge> :
- block_task_policy<
- 128, // _BlockItemsY
- 128, // _BlockItemsX
- 8, // _BlockItemsK
- 8, // _ThreadItemsY
- 8, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-
-/******************************************************************************
- * DGEMM
- ******************************************************************************/
-
-/**
- * GEMM task policy specialization for small dgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Small> :
- block_task_policy<
- 16, // _BlockItemsY
- 16, // _BlockItemsX
- 16, // _BlockItemsK
- 2, // _ThreadItemsY
- 2, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-
-/**
- * GEMM task policy specialization for medium dgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Medium> :
- block_task_policy<
- 32, // _BlockItemsY
- 32, // _BlockItemsX
- 16, // _BlockItemsK
- 4, // _ThreadItemsY
- 4, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-/**
- * GEMM task policy specialization for large dgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Large> :
- block_task_policy<
- 64, // _BlockItemsY
- 64, // _BlockItemsX
- 8, // _BlockItemsK
- 4, // _ThreadItemsY
- 4, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-/**
- * GEMM task policy specialization for tall dgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Tall> :
- block_task_policy<
- 128, // _BlockItemsY
- 32, // _BlockItemsX
- 8, // _BlockItemsK
- 8, // _ThreadItemsY
- 4, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-/**
- * GEMM task policy specialization for wide dgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Wide> :
- block_task_policy<
- 32, // _BlockItemsY
- 128, // _BlockItemsX
- 8, // _BlockItemsK
- 4, // _ThreadItemsY
- 8, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-/**
- * GEMM task policy specialization for huge dgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Huge> :
- block_task_policy<
- 64, // _BlockItemsY
- 128, // _BlockItemsX
- 8, // _BlockItemsK
- 8, // _ThreadItemsY
- 8, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-
-/******************************************************************************
- * HGEMM
- ******************************************************************************/
-
-/**
- * GEMM task policy specialization for small hgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Small> :
- block_task_policy<
- 32, // _BlockItemsY
- 32, // _BlockItemsX
- 8, // _BlockItemsK
- 4, // _ThreadItemsY
- 4, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-/**
- * GEMM task policy specialization for medium hgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Medium> :
- block_task_policy<
- 32, // _BlockItemsY
- 32, // _BlockItemsX
- 16, // _BlockItemsK
- 8, // _ThreadItemsY
- 4, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-/**
- * GEMM task policy specialization for large hgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Large> :
- block_task_policy<
- 64, // _BlockItemsY
- 64, // _BlockItemsX
- 8, // _BlockItemsK
- 16, // _ThreadItemsY
- 8, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-/**
- * GEMM task policy specialization for tall hgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Tall> :
- block_task_policy<
- 128, // _BlockItemsY
- 32, // _BlockItemsX
- 8, // _BlockItemsK
- 16, // _ThreadItemsY
- 4, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-
-/**
- * GEMM task policy specialization for wide hgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Wide> :
- block_task_policy<
- 32, // _BlockItemsY
- 128, // _BlockItemsX
- 8, // _BlockItemsK
- 8, // _ThreadItemsY
- 8, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-/**
- * GEMM task policy specialization for huge hgemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Huge> :
- block_task_policy<
- 128, // _BlockItemsY
- 128, // _BlockItemsX
- 8, // _BlockItemsK
- 16, // _ThreadItemsY
- 8, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-
-/******************************************************************************
- * IGEMM
- ******************************************************************************/
-
-/**
- * GEMM task policy specialization for small igemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Small> :
- block_task_policy<
- 16, // _BlockItemsY
- 32, // _BlockItemsX
- 32, // _BlockItemsK
- 4, // _ThreadItemsY
- 4, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-
-/**
- * GEMM task policy specialization for medium igemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Medium> :
- block_task_policy<
- 32, // _BlockItemsY
- 32, // _BlockItemsX
- 32, // _BlockItemsK
- 4, // _ThreadItemsY
- 4, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-/**
- * GEMM task policy specialization for large igemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Large> :
- block_task_policy<
- 64, // _BlockItemsY
- 64, // _BlockItemsX
- 32, // _BlockItemsK
- 8, // _ThreadItemsY
- 4, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-/**
- * GEMM task policy specialization for tall igemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Tall> :
- block_task_policy<
- 128, // _BlockItemsY
- 64, // _BlockItemsX
- 64, // _BlockItemsK
- 8, // _ThreadItemsY
- 4, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-/**
- * GEMM task policy specialization for wide igemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Wide> :
- block_task_policy<
- 64, // _BlockItemsY
- 128, // _BlockItemsX
- 64, // _BlockItemsK
- 4, // _ThreadItemsY
- 8, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-/**
- * GEMM task policy specialization for huge igemm
- */
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Huge> :
- block_task_policy<
- 128, // _BlockItemsY
- 128, // _BlockItemsX
- 32, // _BlockItemsK
- 8, // _ThreadItemsY
- 8, // _ThreadItemsX
- false, // _UseDoubleScratchTiles
- grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-
-/******************************************************************************
- * WMMA GEMM
- ******************************************************************************/
-
-// WMMA is a preview feature in CUDA. Conditionally enable wmma_gemm policies.
-#if defined(WMMA)
-
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<half, float, TransformA, TransformB, tiling_strategy::Small> :
- gemm::block_task_wmma_policy<
- 16, // _BlockItemsY
- 16, // _BlockItemsX
- 16, // _BlockItemsK
- 16, // _WarpItemsY
- 16, // _WarpItemsX
- 16, // _WmmaItemsY
- 16, // _WmmaItemsX
- 16, // _WmmaItemsK
- false, // _UseDoubleScratchTiles
- gemm::grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy<half, float, TransformA, TransformB, tiling_strategy::Medium> :
- gemm::block_task_wmma_policy<
- 32, // _BlockItemsY
- 32, // _BlockItemsX
- 32, // _BlockItemsK
- 32, // _WarpItemsY
- 32, // _WarpItemsX
- 16, // _WmmaItemsY
- 16, // _WmmaItemsX
- 16, // _WmmaItemsK
- false, // _UseDoubleScratchTiles
- gemm::grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy< half, float, TransformA, TransformB, tiling_strategy::Large> :
- gemm::block_task_wmma_policy<
- 64, // _BlockItemsY
- 64, // _BlockItemsX
- 32, // _BlockItemsK
- 32, // _WarpItemsY
- 64, // _WarpItemsX
- 16, // _WmmaItemsY
- 16, // _WmmaItemsX
- 16, // _WmmaItemsK
- false, // _UseDoubleScratchTiles
- gemm::grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy< half, float, TransformA, TransformB, tiling_strategy::Tall> :
- gemm::block_task_wmma_policy<
- 128, // _BlockItemsY
- 64, // _BlockItemsX
- 64, // _BlockItemsK
- 32, // _WarpItemsY
- 64, // _WarpItemsX
- 16, // _WmmaItemsY
- 16, // _WmmaItemsX
- 16, // _WmmaItemsK
- false, // _UseDoubleScratchTiles
- gemm::grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy< half, float, TransformA, TransformB, tiling_strategy::Wide> :
- gemm::block_task_wmma_policy<
- 64, // _BlockItemsY
- 128, // _BlockItemsX
- 64, // _BlockItemsK
- 32, // _WarpItemsY
- 64, // _WarpItemsX
- 16, // _WmmaItemsY
- 16, // _WmmaItemsX
- 16, // _WmmaItemsK
- false, // _UseDoubleScratchTiles
- gemm::grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-template <
- matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
- matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
-struct gemm_policy< half, float, TransformA, TransformB, tiling_strategy::Huge> :
- gemm::block_task_wmma_policy<
- 128, // _BlockItemsY
- 128, // _BlockItemsX
- 64, // _BlockItemsK
- 32, // _WarpItemsY
- 64, // _WarpItemsX
- 16, // _WmmaItemsY
- 16, // _WmmaItemsX
- 16, // _WmmaItemsK
- false, // _UseDoubleScratchTiles
- gemm::grid_raster_strategy::Default> // _RasterStrategy
-{};
-
-#endif
-
-
-} // namespace gemm
-} // namespace cutlass
diff --git a/cutlass/gemm/dp_accummulate.h b/cutlass/gemm/dp_accummulate.h
deleted file mode 100644
index 6b5d4b6c..00000000
--- a/cutlass/gemm/dp_accummulate.h
+++ /dev/null
@@ -1,223 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- * * Redistributions of source code must retain the above copyright
- * notice, this list of conditions and the following disclaimer.
- * * Redistributions in binary form must reproduce the above copyright
- * notice, this list of conditions and the following disclaimer in the
- * documentation and/or other materials provided with the distribution.
- * * Neither the name of the NVIDIA CORPORATION nor the
- * names of its contributors may be used to endorse or promote products
- * derived from this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- ******************************************************************************/
-
-#pragma once
-
-/**
- * \file
- * Abstraction for exposing architecture-specific "dot-product-accumulate"
- * ISA operations
- */
-
-#include <stdint.h>
-
-#include "../util/util.h"
-
-
-namespace cutlass {
-namespace gemm {
-
-
-/******************************************************************************
- * dp_accummulate
- ******************************************************************************/
-
-
-/**
- * \brief Abstraction for exposing architecture-specific "dot-product-accumulate"
- * ISA operations
- *
- * Given two K-component vectors a and b having type value_t[K] and an addend c
- * of type accum_t, the "dot-product-accumulate" of type accum_t is computed
- * as d = a[0]*b[0] + a[1]*b[1] + ... + a[K-1]*b[K-1] + c.
- *
- * We use the notation "dpK" to connote a K-component dot-product-accumulate.
- * For example, "dp1" is a simple multiply-add.
- *
- * For a given pairing of value_t and accum_t types, the corresponding
- * dp_accummulate class will:
- *
- * - Define the member-type dp_vector_t as the appropriate K-component vector
- * type needed to leverage architecture-specific "dot-product accumulate"
- * ISA operations.
- * - Implement the corresponding dot-product operation between two dp_vector_t
- * inputs a and b.
- *
- */
-template <
- typename value_t, ///< Component value type
- typename accum_t> ///< Accumulator value type
-struct dp_accummulate;
-
-
-
-/// Default "dp1" dot-product-accumulate traits specialization for value_t->accum_t
-template <
- typename value_t, ///< Component value type
- typename accum_t> ///< Accumulator value type
-struct dp_accummulate
-{
- /// Single-component "dp1" dot-product vector type
- typedef value_t dp_vector_t;
-
-
- /// Compute "dp1" float->float
- inline __device__
- static void mad(
- float &d,
- const float &a,
- const float &b,
- const float &c)
- {
- asm volatile ( "fma.rn.f32 %0, %1, %2, %3;\n"
- : "=f"(d) : "f"(a), "f"(b), "f"(c));
- }
-
-
- /// Compute "dp1" double->double
- inline __device__
- static void mad(
- double &d,
- const double &a,
- const double &b,
- const double &c)
- {
- asm volatile ("fma.rn.f64 %0, %1, %2, %3;\n"
- : "=d"(d) : "d"(a), "d"(b), "d"(c));
- }
-
-
- /// Compute "dp1" int16_t->int32_t
- inline __device__
- static void mad(
- int32_t &d,
- const int16_t &a,
- const int16_t &b,
- const int32_t &c)
- {
- asm volatile ("mad.wide.s16 %0, %1, %2, %3;\n"
- : "=r"(d) : "h"(a), "h"(b), "r"(c));
- }
-
-
- /// Compute "dp1" uint16_t->uint32_t
- inline __device__
- static void mad(
- uint32_t &d,
- const uint16_t &a,
- const uint16_t &b,
- const uint32_t &c)
- {
- asm volatile ("mad.wide.u16 %0, %1, %2, %3;\n"
- : "=r"(d) : "h"(a), "h"(b), "r"(c));
- }
-
-
- /// Compute "dp1" int32_t->int32_t
- inline __device__
- static void mad(
- int32_t &d,
- const int32_t &a,
- const int32_t &b,
- const int32_t &c)
- {
- asm volatile ("mad.lo.s32 %0, %1, %2, %3;\n"
- : "=r"(d) : "r"(a), "r"(b), "r"(c));
- }
-
-
- /// Compute "dp1" uint32_t->uint32_t
- inline __device__
- static void mad(
- uint32_t &d,
- const uint32_t &a,
- const uint32_t &b,
- const uint32_t &c)
- {
- asm volatile ("mad.lo.u32 %0, %1, %2, %3;\n"
- : "=r"(d) : "r"(a), "r"(b), "r"(c));
- }
-
-};
-
-
-
-#if (CUTLASS_ARCH >= 610) // Specializations only enabled for Pascal SM610+
-
-
-/// "dp4" dot-product-accumulate traits specialization for int8_t->int32_t
-template <>
-struct dp_accummulate<
- int8_t, ///< Component value type
- int32_t> ///< Accumulator value type
-{
- /// Four-component signed "idp4"
- typedef int32_t dp_vector_t;
-
- /// Compute "dp4" int16_t->int32_t
- inline __device__
- static void mad(
- int32_t &d,
- const int32_t &a,
- const int32_t &b,
- const int32_t &c)
- {
- asm volatile ( "dp4a.s32.s32 %0, %1, %2, %3;\n"
- : "=r"(d) : "r"(a), "r"(b), "r"(c));
- }
-};
-
-
-/// "dp4" dot-product-accumulate traits specialization for uint8_t->uint32_t
-template <>
-struct dp_accummulate<
- uint8_t, ///< Component value type
- uint32_t> ///< Accumulator value type
-{
- /// Four-component unsigned "idp4"
- typedef uint32_t dp_vector_t;
-
- /// Compute "dp4" uint16_t->uint32_t
- inline __device__
- static void mad(
- uint32_t &d,
- const uint32_t &a,
- const uint32_t &b,
- const uint32_t &c)
- {
- asm volatile ( "dp4a.u32.u32 %0, %1, %2, %3;\n"
- : "=r"(d) : "r"(a), "r"(b), "r"(c));
- }
-};
-
-
-#endif // Specializations only enabled for Pascal SM610+
-
-
-} // namespace gemm
-} // namespace cutlass
-
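The int8 specialization deleted above packs four 8-bit lanes into a single 32-bit dp_vector_t and issues one dp4a instruction per packed pair, so one instruction replaces four separate multiply-adds. The following host-side reference (plain C++, not CUTLASS code) spells out the arithmetic that instruction performs.

// Host-side reference of the dp4a.s32.s32 path above: each 32-bit operand
// packs four signed 8-bit lanes; the lanes are multiplied pairwise and the
// products are summed into the 32-bit addend.
#include <cstdint>
#include <cstdio>

int32_t dp4a_reference(int32_t a, int32_t b, int32_t c) {
  int32_t d = c;
  for (int lane = 0; lane < 4; ++lane) {
    int8_t a_lane = static_cast<int8_t>((a >> (8 * lane)) & 0xff);
    int8_t b_lane = static_cast<int8_t>((b >> (8 * lane)) & 0xff);
    d += static_cast<int32_t>(a_lane) * static_cast<int32_t>(b_lane);
  }
  return d;
}

int main() {
  // a packs {1, 2, 3, 4}, b packs {5, 6, 7, 8} (lane 0 in the low byte).
  int32_t a = (4 << 24) | (3 << 16) | (2 << 8) | 1;
  int32_t b = (8 << 24) | (7 << 16) | (6 << 8) | 5;
  std::printf("%d\n", dp4a_reference(a, b, 10));  // 1*5 + 2*6 + 3*7 + 4*8 + 10 = 80
  return 0;
}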
diff --git a/cutlass/gemm/epilogue_function.h b/cutlass/gemm/epilogue_function.h
deleted file mode 100644
index 714dd817..00000000
--- a/cutlass/gemm/epilogue_function.h
+++ /dev/null
@@ -1,104 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- * * Redistributions of source code must retain the above copyright
- * notice, this list of conditions and the following disclaimer.
- * * Redistributions in binary form must reproduce the above copyright
- * notice, this list of conditions and the following disclaimer in the
- * documentation and/or other materials provided with the distribution.
- * * Neither the name of the NVIDIA CORPORATION nor the
- * names of its contributors may be used to endorse or promote products
- * derived from this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- ******************************************************************************/
-
-#pragma once
-
-/**
- * \file
- * Epilogue operation to compute final output
- */
-
-namespace cutlass {
-namespace gemm {
-
- //// Used by GEMM to compute the final result C <= alpha * accumulator + beta * C
- template <
- typename accum_t,
- typename output_t,
- typename scalar_t
- >
- class blas_scaled_epilogue
- {
- public:
-
- scalar_t alpha;
- scalar_t beta;
-
- inline __device__ __host__
- blas_scaled_epilogue(
- scalar_t alpha,
- scalar_t beta)
- :
- alpha(alpha),
- beta(beta)
- {}
-
-
- /// Epilogue operator
- inline __device__ __host__
- output_t operator()(
- accum_t accumulator,
- output_t c,
- size_t idx) const
- {
- return output_t(alpha * scalar_t(accumulator) + beta * scalar_t(c));
- }
-
-
- /// Epilogue operator
- inline __device__ __host__
- output_t operator()(
- accum_t accumulator,
- size_t idx) const
- {
- return output_t(alpha * scalar_t(accumulator));
- }
-
- /**
- * Configure epilogue as to whether the thread block is a secondary
- * accumulator in an inter-block k-splitting scheme
- */
- inline __device__
- void set_secondary_accumulator()
- {
- beta = scalar_t(1);
- }
-
-
- /// Return whether the beta-scaled addend needs initialization
- inline __device__
- bool must_init_addend()
- {
- return (beta != scalar_t(0));
- }
- };
-
-
-
-
-} // namespace gemm
-} // namespace cutlass
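The functor removed above applies the canonical BLAS update, D = alpha * accumulator + beta * C, one element at a time, with beta forced to 1 when the block acts as a secondary accumulator in a split-k scheme. A minimal host-side sketch of that contract (illustrative names, not the deleted class itself):

// Sketch of the elementwise epilogue: D = alpha * accumulator + beta * C.
#include <cstdio>

template <typename accum_t, typename output_t, typename scalar_t>
struct scaled_epilogue_sketch {
  scalar_t alpha;
  scalar_t beta;

  output_t operator()(accum_t accumulator, output_t c) const {
    return output_t(alpha * scalar_t(accumulator) + beta * scalar_t(c));
  }
};

int main() {
  scaled_epilogue_sketch<float, float, float> epilogue = {2.0f, 0.5f};
  // 2.0 * 3.0 + 0.5 * 4.0 = 8.0
  std::printf("%f\n", epilogue(3.0f, 4.0f));
  return 0;
}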
diff --git a/cutlass/gemm/gemm.h b/cutlass/gemm/gemm.h
new file mode 100644
index 00000000..0ca093ff
--- /dev/null
+++ b/cutlass/gemm/gemm.h
@@ -0,0 +1,319 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Implements a software-pipelined efficient GEMM.
+*/
+#pragma once
+
+#if !defined(__CUDACC_RTC__)
+#include
+#endif
+
+#include
+#include
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Gemm_>
+__global__ void gemm_kernel(typename Gemm_::Params params) {
+ // Declare shared memory.
+ __shared__ typename Gemm_::SharedStorage shared_storage;
+
+ // Construct the GEMM object.
+ Gemm_ gemm(params, shared_storage);
+ // Run GEMM.
+ gemm.multiply_add();
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Scalar_, typename Index_>
+struct GemmDesc {
+ /// The dimensions of the GEMM.
+ Index_ m, n, k;
+ /// The alpha/beta scaling values.
+ Scalar_ alpha, beta;
+ /// The source matrix A.
+ void const* d_a;
+ /// The stride for A.
+ Index_ lda;
+ /// The source matrix B.
+ void const* d_b;
+ /// The stride for B.
+ Index_ ldb;
+ /// The source matrix C.
+ void const* d_c;
+ /// The stride for C.
+ Index_ ldc;
+ /// The destination matrix D.
+ void* d_d;
+ /// The stride for D.
+ Index_ ldd;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename GemmTraits_>
+struct Gemm {
+ /// This class.
+ typedef Gemm This_;
+ /// The traits.
+ typedef GemmTraits_ Traits;
+ /// The shared storage.
+ typedef typename Traits::SharedStorage SharedStorage;
+
+ /// The scalar for A.
+ typedef typename Traits::ScalarA ScalarA;
+ /// The scalar for B.
+ typedef typename Traits::ScalarB ScalarB;
+ /// The scalar in the epilogue.
+ typedef typename Traits::Epilogue::Scalar ScalarEpilogue;
+ /// The scalar for C.
+ typedef typename Traits::Epilogue::ScalarC ScalarC;
+ /// The scalar for D.
+ typedef typename Traits::Epilogue::ScalarD ScalarD;
+ /// The index.
+ typedef typename Traits::Index Index;
+
+ /// The number of threads.
+ static int const kThreads = Traits::GemmConfig::kThreads;
+
+ /// The params.
+ struct Params : public Traits::Params {
+ CUTLASS_HOST_DEVICE int initialize(Index m,
+ Index n,
+ Index k,
+ ScalarEpilogue alpha,
+ ScalarA const* d_a,
+ Index lda,
+ ScalarB const* d_b,
+ Index ldb,
+ ScalarEpilogue beta,
+ ScalarC const* d_c,
+ Index ldc,
+ ScalarD* d_d,
+ Index ldd) {
+ GemmDesc<ScalarEpilogue, Index> desc;
+ desc.m = m;
+ desc.n = n;
+ desc.k = k;
+ desc.alpha = alpha;
+ desc.beta = beta;
+ desc.d_a = reinterpret_cast<void const*>(d_a);
+ desc.lda = lda;
+ desc.d_b = reinterpret_cast<void const*>(d_b);
+ desc.ldb = ldb;
+ desc.d_c = reinterpret_cast<void const*>(d_c);
+ desc.ldc = ldc;
+ desc.d_d = reinterpret_cast<void*>(d_d);
+ desc.ldd = ldd;
+ return Traits::Params::initialize(desc);
+ }
+ };
+
+#if !defined(__CUDACC_RTC__)
+ /// Launch the kernel.
+ static __host__ cudaError_t launch(Params const& params,
+ cudaStream_t stream = cudaStreamDefault) {
+ // Setup the grid.
+ dim3 grid;
+ grid.x = (params.m + Traits::OutputTile::kW - 1) / Traits::OutputTile::kW;
+ grid.y = (params.n + Traits::OutputTile::kH - 1) / Traits::OutputTile::kH;
+
+ // The number of threads.
+ dim3 block;
+ block.x = kThreads;
+
+ // Launch the kernel.
+ void const* params_ = reinterpret_cast<void const*>(&params);
+
+ return cudaLaunchKernel(reinterpret_cast<void*>(&gemm_kernel<This_>