Checkpointing CUTLASS 1.1 release.
@ -1,6 +1,22 @@
|
||||
# NVIDIA CUTLASS Changelog
|
||||
|
||||
## [1.0.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.0.1) (2018-06-11)
|
||||
|
||||
## 1.1.0 (2018-09-19)
|
||||
* Turing Features
|
||||
* WMMA GEMM targeting TensorCores - INT8, INT4, INT1
|
||||
* Batched Strided GEMM
|
||||
* Threadblock rasterization strategies
|
||||
* Improved performance for adverse problem sizes and data layouts
|
||||
* Extended CUTLASS Core comonents
|
||||
* Tensor views support arbitrary matrix and tensor layouts
|
||||
* Zip iterators for structuring multiple data streams
|
||||
* Enhanced CUTLASS utilities
|
||||
* Reference code for tensor operations in host and device code
|
||||
* Added HostMatrix<> for simplified matrix creation
|
||||
* Examples
|
||||
* Basic GEMM, tensor views, CUTLASS utilities, batched GEMM, WMMA GEMM
|
||||
|
||||
## 1.0.1 (2018-06-11)
|
||||
|
||||
* Intra-threadblock reduction added for small threadblock tile sizes
|
||||
* sgemm_64x128x16, sgemm_128x128x16, sgemm_128x64x16, sgemm_128x32x16, sgemm_64x64x16, sgemm_64x32x16
|
||||
@ -55,11 +55,21 @@ endif()
|
||||
find_package(CUDA)
|
||||
find_package(Doxygen QUIET)
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Configure CMake variables
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
find_library(CUBLAS_LIBRARY cublas HINTS
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib64
|
||||
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
|
||||
|
||||
# By default we want to build in Release mode to ensure that we're getting best performance
|
||||
if (NOT (CMAKE_BUILD_TYPE OR CONFIGURATION_TYPES))
|
||||
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose build level" FORCE)
|
||||
# We do support Debug or Release builds
|
||||
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release")
|
||||
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "RelWithDebInfo" "Release")
|
||||
endif()
|
||||
|
||||
if(WIN32)
|
||||
@ -68,27 +78,59 @@ if(WIN32)
|
||||
endif()
|
||||
|
||||
if (WIN32)
|
||||
# Enable more warnings and treat as errors
|
||||
string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX")
|
||||
# Enable more warnings and treat as errors
|
||||
string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX")
|
||||
|
||||
# Disable excess x86 floating point precision that can lead to results being labeled incorrectly
|
||||
string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict")
|
||||
# Disable warning on Unicode characters
|
||||
string(APPEND NVCC_FLAGS " -Xcompiler /wd4819")
|
||||
|
||||
# Verbose option
|
||||
if (${CUTLASS_NVCC_VERBOSE})
|
||||
string(APPEND NVCC_FLAGS " -v")
|
||||
endif()
|
||||
# Disable excess x86 floating point precision that can lead to results being labeled incorrectly
|
||||
string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict")
|
||||
|
||||
# Verbose option
|
||||
if (${CUTLASS_NVCC_VERBOSE})
|
||||
string(APPEND NVCC_FLAGS " -v")
|
||||
endif()
|
||||
endif(WIN32)
|
||||
|
||||
# Configure CUDA options
|
||||
set(CUTLASS_NVCC_ARCHS "50;60;61;70" CACHE STRING "The SM architectures to build code for.")
|
||||
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
|
||||
set(CUTLASS_NVCC_ARCHS "50;60;61;70;75" CACHE STRING "The SM architectures to build code for.")
|
||||
set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.")
|
||||
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
|
||||
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
|
||||
|
||||
#
|
||||
# NOTE: running with asan and CUDA requires the following environment variable:
|
||||
#
|
||||
# ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0
|
||||
#
|
||||
# without the above environment setting, an error like the following may be generated:
|
||||
#
|
||||
# *** Error: Could not detect active GPU device ID [out of memory]
|
||||
# ...
|
||||
# ==9149==ERROR: LeakSanitizer: detected memory leaks
|
||||
# ...
|
||||
#
|
||||
if(ENABLE_ASAN) # https://github.com/google/sanitizers/wiki/AddressSanitizer
|
||||
string(APPEND NVCC_FLAGS " --compiler-options -fsanitize=address --compiler-options -fno-omit-frame-pointer")
|
||||
string(APPEND CMAKE_EXE_LINKER_FLAGS " -fsanitize=address")
|
||||
endif()
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
# Configure CUDA build options
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
# Set NVCC arguments
|
||||
foreach(ARCH ${CUTLASS_NVCC_ARCHS})
|
||||
string(APPEND NVCC_FLAGS " -gencode arch=compute_${ARCH},code=sm_${ARCH}")
|
||||
if(CUTLASS_NVCC_EMBED_CUBIN)
|
||||
string(APPEND NVCC_FLAGS " -gencode arch=compute_${ARCH},code=sm_${ARCH}")
|
||||
endif()
|
||||
if(CUTLASS_NVCC_EMBED_PTX)
|
||||
string(APPEND NVCC_FLAGS " -gencode arch=compute_${ARCH},code=compute_${ARCH}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
|
||||
if (CUTLASS_NVCC_KEEP)
|
||||
string(APPEND NVCC_FLAGS " -keep")
|
||||
endif()
|
||||
@ -99,11 +141,8 @@ else()
|
||||
string(APPEND NVCC_FLAGS " -lineinfo")
|
||||
endif()
|
||||
|
||||
if (UNIX)
|
||||
string(APPEND NVCC_FLAGS " -Xcompiler -Wconversion")
|
||||
endif()
|
||||
|
||||
string(APPEND NVCC_FLAGS_DEBUG " -g")
|
||||
string(APPEND NVCC_FLAGS_RELWITHDEBINFO " -O3")
|
||||
string(APPEND NVCC_FLAGS_RELEASE " -O3")
|
||||
|
||||
# define NDEBUG for release mode to disable assertions
|
||||
@ -111,11 +150,13 @@ string(APPEND NVCC_FLAGS_RELEASE " -DNDEBUG")
|
||||
|
||||
if (CUTLASS_NATIVE_CUDA)
|
||||
set(CMAKE_CUDA_FLAGS "${NVCC_FLAGS}")
|
||||
set(CMAKE_CUDA_FLAGS_DEBUG "${NVCC_FLAGS_DEBUG}")
|
||||
set(CMAKE_CUDA_FLAGS_RELEASE "${NVCC_FLAGS_RELEASE}")
|
||||
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${NVCC_FLAGS_RELWITHDEBINFO}")
|
||||
set(CMAKE_CUDA_FLAGS_DEBUG "${NVCC_FLAGS_DEBUG}")
|
||||
else()
|
||||
set(CUDA_NVCC_FLAGS ${NVCC_FLAGS})
|
||||
set(CUDA_NVCC_FLAGS_DEBUG ${NVCC_FLAGS_DEBUG})
|
||||
set(CUDA_NVCC_FLAGS_RELWITHDEBINFO ${NVCC_FLAGS_RELWITHDEBINFO})
|
||||
set(CUDA_NVCC_FLAGS_RELEASE ${NVCC_FLAGS_RELEASE})
|
||||
endif()
|
||||
|
||||
@ -128,6 +169,11 @@ file(GLOB CUTLASS_GEMM RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/gemm/*.h)
|
||||
file(GLOB CUTLASS_UTIL RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/util/*.h)
|
||||
file(GLOB CUTLASS_DEVICE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/device/*.h)
|
||||
file(GLOB CUTLASS_CORE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/*.h)
|
||||
###################################################################################################
|
||||
#
|
||||
# Define build targets
|
||||
#
|
||||
###################################################################################################
|
||||
|
||||
source_group("cutlass\\gemm" FILES ${CUTLASS_GEMM})
|
||||
source_group("cutlass\\util" FILES ${CUTLASS_UTIL})
|
||||
@ -156,9 +202,9 @@ add_custom_target(cutlass_ide SOURCES
|
||||
if (DOXYGEN_FOUND)
|
||||
# DOT is available. Enable graph generation in the documentation
|
||||
if (DOXYGEN_DOT_EXECUTABLE)
|
||||
set(CUTLASS_ENABLE_DOXYGEN_DOT ON CACHE BOOL "Use dot to generate graphs in the doxygen documentation.")
|
||||
set(CUTLASS_ENABLE_DOXYGEN_DOT ON CACHE BOOL "Use dot to generate graphs in the doxygen documentation.")
|
||||
else()
|
||||
set(CUTLASS_ENABLE_DOXYGEN_DOT OFF CACHE BOOL "Use dot to generate graphs in the doxygen documentation." FORCE)
|
||||
set(CUTLASS_ENABLE_DOXYGEN_DOT OFF CACHE BOOL "Use dot to generate graphs in the doxygen documentation." FORCE)
|
||||
endif()
|
||||
|
||||
if (CUTLASS_ENABLE_DOXYGEN_DOT)
|
||||
@ -177,6 +223,5 @@ if (DOXYGEN_FOUND)
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
#add_subdirectory(examples/gemm)
|
||||
add_subdirectory(tools)
|
||||
add_subdirectory(examples)
|
||||
|
||||
311
CUTLASS.md
Normal file
@ -0,0 +1,311 @@
|
||||

|
||||
|
||||
# CUTLASS
|
||||
|
||||
This document is intended to accompany the CUTLASS source code, to describe the interaction between
|
||||
CUTLASS core components, and to identify their role in implementing GEMM computations efficiently in CUDA.
|
||||
|
||||
1. [Design Patterns](#S-design-patterns)
|
||||
2. [General Matrix Multiply](#S-general-matrix-multiply)
|
||||
3. [Core Components](#S-core-components)
|
||||
4. [Utilities](#S-utilities)
|
||||
|
||||
# <a name="S-design-patterns"></a> 1. Design Patterns
|
||||
|
||||
CUTLASS strives to achieve the highest performance possible on NVIDIA GPUs while also offering a
|
||||
flexible composition that an be easily applied to solve new problems related to Deep Learning and
|
||||
linear algebra. Though we intend to make CUTLASS as simple and straightforward as possible, given
|
||||
a tradeoff between simplicity and performance, CUTLASS chooses performance. Consequently, several
|
||||
design patterns are necessary to yield a composable structure while also satisfying these performance
|
||||
objectives. This section is intended to provide more detail.
|
||||
|
||||
* [Sequencing and Nesting](#S-patterns-sequencing-nesting)
|
||||
* [Tiles and Iterators](#S-patterns-tiles-iterators)
|
||||
* [Host-side Params](#S-patterns-host-side-params)
|
||||
* [Composable Shared Memory](#S-patterns-composable-shared-memory)
|
||||
|
||||
## <a name="S-patterns-sequencing-nesting"></a> Sequencing and Nesting of Collective Primitives
|
||||
|
||||
CUTLASS embodies a design paradigm exemplified by the [CUB library](https://nvlabs.github.io/cub/) for expressing collective operations. Objects expose an interface for a problem that is then decomposed into concurrent subtasks executed by cooperating threadblocks, warps, and threads. For example, a grid-level object may be constructed with base pointers to the start of a GEMM operation, add a threadblock-dependent offset to partition the problem, and then compute a per-threadblock GEMM. This in turn performs some operations as a collection of cooperating threads, while it may partition other parts of the task into warp-level subtasks.
|
||||
|
||||
## <a name="S-patterns-tiles-iterators"></a> Tiles and Iterators
|
||||
|
||||
Efficient dense linear algebra computations emphasize data movement to match the execution of mathemtical operators to the flow of data. Consequently, CUTLASS defines a rich set of primitives for partitioning a tile of data among participating threads, warps, and threadblocks. CUTLASS applies the familiar iterator design pattern to provide an abstraction layer to (1.) access these tile objects and (2.) traverse a sequence of objects embedded in a higher level data structure. These subpartitions are typically defined by compile-time constants
|
||||
specifying element type, size, and data layout. CUTLASS refers to subpartitions as _tiles_.
|
||||
|
||||
_Iterators_ are familiar design patterns in C++ that provide an abstraction for accessing individual
|
||||
elements in memory as well as traversing over a collection. GEMM kernels in CUTLASS depend on accessing
|
||||
a sequence of tiles from global memory, from shared memory, and in registers. Consequently, _tile iterators_
|
||||
are prevalent throughout the CUTLASS implementation.
|
||||
|
||||
The canonical CUTLASS tile iterator template is defined in [cutlass/tile_iterator.h](cutlass/tile_iterator.h).
|
||||
|
||||
## <a name="S-patterns-host-side-params"></a> Host-side Params structure
|
||||
|
||||
Several CUTLASS template classes exhibit a pattern in which problem-specific internal state is known at kernel launch time and remains invariant throughout the execution of a kernel. For example, tile iterators compute several offsets based on the strides of the input tensor that is added to an internal pointer when loading the elements of a tile. These are computed from the tensor stride and never updated; the per-thread internal state consists only of the internal global memory pointer.
|
||||
|
||||
CUTLASS can take advantage of this CUDA grid-invariant property by constructing the object in host code and passing a composed parameters structure to the kernel. This confers two benefits: (1.) invariant state is held in constant memory, and (2.) there is no overhead to compute the initial state by each thread.
|
||||
|
||||
The design pattern in CUTLASS is for classes with nontrivial constructors to define `struct Params` as an inner class which contains grid-invariant state. These should define a constructor and an `initialize()` method. The `Params` structure should also include a data member corresponding to each data member in the parent class, so these too can be properly constructed in host code. The parent class should define a constructor which accepts `Params const &` as its first argument.
|
||||
|
||||
For example, `cutlass::gemm::Gemm<>` should define `struct cutlass::gemm::Gemm::Params`. The latter should define data members for each data member in `cutlass::gemm::Gemm<>`.
|
||||
|
||||
|
||||
## <a name="S-patterns-composable-shared-memory"></a> Composable shared memory allocation
|
||||
|
||||
Shared memory requires explicit effort by the programmer to allocate and de-allocate. CUTLASS follows the paradigm introduced by [CUB](https://nvlabs.github.io/cub/) to define composed structures for storing data intended to be held in shared memory. Any object requiring shared memory storage for itself or its data members should define a child structure called SharedStorage. This holds data needed by the class and also instantiates SharedStorage objects for each data member.
|
||||
|
||||
To be consistent, this pattern defines a convention in which classes define internal shared memory storage requirements. Classes should consider all SharedStorage structures to be opaque other than their own child class. When the lifetimes of child objects are known to be non-overlapping, unions may be used to alias multiple SharedStorage objects to the same shared memory region and reduce overall SMEM capacity.
|
||||
|
||||
## <a name="S-patterns-loop-unrolling"></a> Loop Unrolling
|
||||
|
||||
CUTLASS requires tiles of data to be stored in registers for high-bandwidth access. Simultaneously, high-throughput math instructions
|
||||
must be issued concurrently with memory instructions to hide latency with relatively few concurrent threads. These objectives are
|
||||
achieved by unrolling loops whose iteration counts are known at compile time.
|
||||
|
||||
Consequently, most loops within the CUTLASS GEMM implementation are specified by constant values and template arguments. The CUDA compiler
|
||||
is able to unroll the loop bodies, map array elements to registers, and construct an efficient instruction schedule.
|
||||
|
||||
## <a name="S-patterns-loop-unrolling"></a> Templates
|
||||
|
||||
CUDA C++ templates and modern generic programming techniques enable CUTLASS device code to span a large design space.
|
||||
|
||||
This design space includes:
|
||||
* Mixed precision arithmetic and data storage
|
||||
* Kernels specialized for layout and problem size
|
||||
* Support for kernel fusion
|
||||
|
||||
Moreover, templates provided a structured approach to collecting compile-time constants such as tile dimensions. These
|
||||
must be template arguments to target static array allocation and take advantage of loop unrolling, constant folding,
|
||||
and function inlining.
|
||||
|
||||
# <a name="S-general-matrix-multiply"></a> 2. General Matrix Multiply
|
||||
|
||||
The following figure illustrates the hierarchical GEMM computation embodied by CUTLASS. Each stage depicts a nested level of tiling which corresponds to a layer of concurrency within the CUDA execution model and to a level within the memory hierarchy, becoming increasingly finer moving left to right.
|
||||
|
||||

|
||||
|
||||
## Threadblock-level GEMM
|
||||
|
||||
The CUTLASS GEMM kernel partitions the _C_ matrix into a 2D tiling of threadblocks.
|
||||
Each threadblock computes a matrix product whose outer dimensions _M_ and _N_ are compile-time constants. The
|
||||
GEMM's _K_ dimension is partitioned into tiles and iterated over by the GEMM _mainloop_. The shape of the matrix
|
||||
multiply operation performed by each iteration of the mainloop is referred to as _OutputTile_.
|
||||
|
||||
The threadblock loads a sequence of tiles from global memory and stores this data to shared memory. The iterative
|
||||
access and traversal of tiles in global memory are performed by a _TileLoadIterator_, and storing to a circular
|
||||
buffer in shared memory is performed by a _GlobalLoadIterator_.
|
||||
|
||||
**[Global Load Stream](cutlass/gemm/gemm_global_stream.h)** manages loading of the threadblock-scope multiplicands to the GEMM kernel. It owns an iterator into global memory for loading tiles of data, a TensorAllocation in shared memory to hold the resulting tile, and an iterator for writing the tile into this allocation. A transformer exists to optionally transform the data as it is loaded which may of use to perform type conversion or, in the case of int8 GEMM, transpose 4x4 tiles held in registers.
|
||||
|
||||
The Global Load Stream template contains members defined by the following templates:
|
||||
|
||||
* [GemmGlobalIteratorAb](cutlass/gemm/gemm_global_tile.h)
|
||||
* [Transformer](cutlass/convert.h)
|
||||
* [GemmSharedStoreTileAb](cutlass/gemm/gemm_shared_tile.h)
|
||||
|
||||
## Warp-level GEMM
|
||||
|
||||
The threadblock's _OutputTile_ is partitioned among the warps, and each computes a warp-level matrix product.
|
||||
Data is loaded from shared memory into registers, and math instructions are dispatched to CUDA Cores or Tensor Cores.
|
||||
|
||||
[**Shared Load Stream**](cutlass/gemm/gemm_shared_stream.h) manages loading of warp-level multiplicands from shared memory into registers. This owns an iterator for fetching data and the destination fragments for holding the results.
|
||||
|
||||
* [GemmSharedLoadTile{A,B}](cutlass/gemm/gemm_shared_tile.h)
|
||||
|
||||
**Matrix Multiply** computes a matrix product operation on data held in registers. Specializations exist for thread-level instructions such as single-precision fused multiply-add as well as warp-level matrix operations targeting TensorCores.
|
||||
|
||||
* [WMMA Multiply Add](cutlass/gemm/wmma_gemm_multiply_add.h)
|
||||
|
||||
## Thread-level GEMM
|
||||
|
||||
SGEMM, IGEMM, HGEMM, and DGEMM are computed by SIMT math instructions issued by thread-level matrix multiply
|
||||
procedures.
|
||||
|
||||
* [ThreadMultiplyAdd](cutlass/gemm/thread_multiply_add.h)
|
||||
* [IGEMM specialization](cutlass/gemm/igemm_multiply_add.h)
|
||||
* [HGEMM specialization](cutlass/gemm/hgemm_multiply_add.h)
|
||||
|
||||
## Epilogue
|
||||
|
||||
The [**epilogue**](cutlass/gemm/gemm_epilogue.h) iteratively selects a subset of accumulator elements held by a warp, writes them to shared memory, and loads them by different threads such that a threadblock-scoped tile store operation will make contiguous, striped accesses to global memory. Thus, the flow of data utilizes the following components:
|
||||
|
||||
1. [Transformer](cutlass/convert.h) for converting the data types of accumulator elements
|
||||
2. [GemmSharedStoreTileD](cutlass/gemm/gemm_shared_tile.h) to store to shared memory specialized to the accumulator layout.
|
||||
3. [GemmSharedLoadTileD](cutlass/gemm/gemm_shared_tile.h) to load the data from shared memory.
|
||||
4. [GemmGlobalIteratorC](cutlass/gemm/gemm_global_tile.h) to load a tile from global memory.
|
||||
5. A [functor](cutlass/gemm/linear_scaling.h) to compute an element-wise operation on the matrix product and source data (such as alpha*AB+beta*C).
|
||||
6. [GemmGlobalIteratorD](cutlass/gemm/gemm_global_tile.h) to write the output to global memory.
|
||||
|
||||
## GEMM Traits
|
||||
|
||||
[**cutlass::gemm::GemmTraits**](cutlass/gemm/gemm_traits.h) collects the structural properties of a complete GEMM computation into a single template class. As a result, the Traits classes encapsulate the the iterators and transformers for all supported GEMM operands and layouts. Low-level details needed by Traits (such as scalar types for operands, thread-block tile size, number of scalar elements per memory access within each phase, number of stages in shared memory, as well as other implementation-specific properties of the GEMM computation) are specified in class [**cutlass::gemm::GemmConfig**](cutlass/gemm/gemm_config.h).
|
||||
|
||||
|
||||
# <a name="S-core-components"></a> 3. Core Components
|
||||
|
||||
CUTLASS GEMM kernels are implemented by a set of Core components for interacting with mathematical tensor and matrix
|
||||
objects as well as constructing efficient CUDA kernels.
|
||||
|
||||
* [Tensor views](#S-core-tensor-views)
|
||||
* [Shape](#S-core-shape)
|
||||
* [Tile structure](#S-core-tile-structure)
|
||||
* [Fragment](#S-core-fragment)
|
||||
* [Predicate vector](#S-core-predicate-vector)
|
||||
|
||||
## <a name="S-core-tensor-views"></a> Tensor View
|
||||
|
||||
Matrices and tensors are typically represented as n-D arrays held in linear memory with a single base pointer and a stride vector. Element _i_ of the stride vector indicates the offset in linear memory between consecutive elements in dimension i. Consequently, the linear offset for an arbitrary element specified as an n-tuple may be computed as the dot product of the coordinate and the stride vector.
|
||||
|
||||
CUTLASS provides abstractions for interacting with multidimension tensors in device memory.
|
||||
Consequently, we define a hierarchy of pointer-like types for referencing tensors.
|
||||
|
||||
`T *` - raw pointer to elements of type T
|
||||
|
||||
`cutlass::TensorRef<T, Rank>` - reference to a tensor of elements of type T and given rank. Includes a mapping function and associated stride vector for accessing elements in linear memory.
|
||||
|
||||
`cutlass::TensorView<T, Rank>` - extends `TensorRef<>` by adding bounds information. This is a complete mathematical object which may be used as the argument to CUTLASS functions.
|
||||
|
||||
The above provide an identity maping of a logical index space to linear memory. An element
|
||||
at logical coordinate X has an offset computed as follows:
|
||||
```
|
||||
offset = dot(X, stride)
|
||||
```
|
||||
where `dot()` computes the inner product of X and a vector of "strides."
|
||||
|
||||
CUTLASS 1.1 introduces a mapping function and an additional "storage rank" to offer a flexible way to
|
||||
map the logical index space of the tensor to memory. The mapping function maps a coordinate
|
||||
of rank _R_ to an index space of rank _S_. The linear offset is computed as:
|
||||
```
|
||||
offset = dot( MapFunc(X), stride )
|
||||
```
|
||||
where stride is a vector of rank _S_.
|
||||
|
||||
CUTLASS kernels make extensive use of vectorization of memory accesses for efficiency and
|
||||
correctness. Consequently, we enforce a constraint on the strides used by mapping functions
|
||||
such that:
|
||||
|
||||
1. The "fastest-changing" stride is always 1 thereby mandating that consecutive elements in
|
||||
that rank are consecutive in linear memory.
|
||||
|
||||
2. The fastest changing rank is always last in the stride vector and not explicitly stored.
|
||||
|
||||
Thus, the stride vector used by mapping functions has length of one fewer than the rank of the
|
||||
storage tensor. These constraints are consistent with the BLAS interface of passing matrices as
|
||||
a tuple consisting of a pointer and a "leading dimension." In fact, these are rank=2 tensors
|
||||
whose fastest changing dimension is 1, and only the strided dimension is explicitly represented.
|
||||
|
||||
A typical mapping function might simply map the rows and columns of a matrix, a rank=2 tensor,
|
||||
to linear memory such that (1.) elements in the same column are consecutive in memory
|
||||
(column-major), or (2.) elements in the same row are consecutive (row-major). These can be
|
||||
accomplished by two different mapping functions whose stride vector is length=2. The first
|
||||
element is the "leading dimension."
|
||||
|
||||
The requirement that the fastest-changing stride always be of unit size need not be a limitation.
|
||||
To implement "sparse" computations or matrix operations in which matrix elements have arbitrary
|
||||
stride along both row and column, define a mapping function whose storage rank is 3. This permits
|
||||
two elements of the stride vector to have a non-unit value.
|
||||
|
||||
`cutlass::TensorView<>` extends this concept by including a size vector to specify the bounds of
|
||||
the index space. The value of each coordinate in the size vector defines the half-open range of
|
||||
indices whose smallest value is zero.
|
||||
|
||||
## <a name="S-core-shape"></a> Shape
|
||||
|
||||
To avoid complicated template metaprogramming, CUTLASS targets fixed compile-time tile sizes specified
|
||||
by a four-dimensional template `cutlass::Shape<>`. This defines the following dimensions, mirroring
|
||||
the NHWC tensor format used for convolution in Deep Learning frameworks.
|
||||
|
||||
- `D`: depth of tensor
|
||||
- `H`: first strided dimension
|
||||
- `W`: contiguous sequence of tensor elements
|
||||
- `C`: number of channels, usually used for vectorized access
|
||||
|
||||
Template specializations of `Shape` appear as arguments to numerous dependent template classes which
|
||||
must specify compile-time constant tile sizes.
|
||||
|
||||
## <a name="S-core-tile-structure"></a> Tile Structure
|
||||
|
||||
Tiled structures express an arrangement of data in memory as well as a logical mapping of concurrent CUDA
|
||||
threads to the problem space. For example, the CUTLASS GEMM
|
||||
|
||||
Tiled structures can be defined using the `cutlass::TileTraits<>` concept which defines the following
|
||||
members. Collectively, these members offer a flexible way to define a 4-D subpartition of an integer
|
||||
lattice, partition its elements among a collection of threads, and map each unique thread ID to a unique
|
||||
offset.
|
||||
|
||||
- _Tile_ (concept `Shape<>`) - describes the dimensions of the tile in terms of scalar elements
|
||||
- _Delta_ (concept `Shape<>`) - describes the distance along each logical dimension between items
|
||||
- _Iterations_ (concept `Shape<>`) - describes the number of items along each logical dimension
|
||||
- _ThreadOffset_ (concept _functor_) - implements `Coord<4> operator()() const` to determine a thread's
|
||||
initial offset in the logical 4-D coordinate space
|
||||
|
||||
The following figure illustrates the CUTLASS tile structure. The overall shape, 16-by-16, is partitioned into
|
||||
vectors of length two among 32 threads. The elements stored by thread 9 are highlighted.
|
||||
|
||||
<img src="/media/images/cutlass-tile-structure.png" alt="CUTLASS tile structure" width="30%" />
|
||||
|
||||
The `cutlass::TileTraits<>` definition that describes this arrangement may be defined as follows:
|
||||
|
||||
```
|
||||
struct ExampleTileTraits {
|
||||
|
||||
/// Overall shape of tile
|
||||
typedef Shape<1, 16, 16, 1> Tile;
|
||||
|
||||
/// Distance along each dimension of accesses
|
||||
typedef Shape<1, 4, 1, 1> Delta;
|
||||
|
||||
/// Number of memory accesses performed by each thread
|
||||
typedef Shape<1, 4, 1, 1> Iterations;
|
||||
|
||||
/// Offset function - maps each thread to a unique starting offset within the 4D tile
|
||||
struct ThreadOffset {
|
||||
|
||||
CUTLASS_DEVICE Coord<4> operator()() const {
|
||||
|
||||
typdef Shape<1, 16, 8, 2> Vectorized;
|
||||
|
||||
return make_Coord(
|
||||
0, // depth "D" dimension
|
||||
threadIdx.x / Vectorized::kW, // horisontal "H" dimension - first strided dimension
|
||||
threadIdx.x % Vectorized::kW, // vertical "W" dimension - contiguous dimension
|
||||
0
|
||||
);
|
||||
}
|
||||
};
|
||||
};
|
||||
```
|
||||
|
||||
## <a name="S-core-tile-iterator"></a> Tile Iterator
|
||||
|
||||
The iterator design pattern provides an abstraction for accessing the items in a collection in sequence. Basic
|
||||
operators defined by iterators consist of accessing an item - either a load or store - followed by traversal to
|
||||
the next item in sequence.
|
||||
|
||||
<img src="/media/images/cutlass-tile-iteration.png" alt="CUTLASS tile access and traversal" width="50%" />
|
||||
|
||||
To offer a generic solution that spans numerous data types and layouts, CUTLASS defines the _TileIterator_ concept.
|
||||
This concept provides access to a sequence of _tiles_ embedded in a tensor in addressable memory.
|
||||
|
||||
The canonical CUTLASS tile iterator template is defined in [cutlass/tile_iterator.h](cutlass/tile_iterator.h).
|
||||
|
||||
## <a name="S-core-fragment"></a> Fragment
|
||||
|
||||
A fragment is analogous to `std::array<>` in that it is a constant-sized array of elements. Typically backed by storage in the SM's register file, CUTLASS `Fragment<>` objects are used to store tiles. For threadblock- and warp-scope operations, the contents of these tiles are distributed across the partipcipating threads. In such cases, a thread's `Fragment<>` contains the part of the tile held by that thread.
|
||||
|
||||
## <a name="S-core-predicate-vector"></a> Predicate Vector
|
||||
|
||||
SIMT architectures utilize predicated execution in place of control flow when conditional code sequences are fairly short, on the order of a few machine instructions. While CUDA C++ does not include constructs at the language level for predication, PTX makes this explicit, and compilation to SASS is assumed to aggressively utilize predication. Typical applications are to initialize a sequence of bits used to mask memory operations and use these bits as predicates guarding memory load and store instructions.
|
||||
|
||||
CUTLASS provides `PredicateVector` defined in [cutlass/predicate_vector.h](cutlass/predicate_vector.h) to manage a statically-sized bit vector, store them into general purpose registers, and efficiently access them in sequence. By storing four predicates per byte in hardware registers, the CUDA compiler is able to issue specialized instructions to achieve very efficient unpacking.
|
||||
|
||||
|
||||
# <a name="S-utilities"></a> 4. Utilities
|
||||
|
||||
CUTLASS implements efficient matrix multiply computations on GPUs. It is accompanied by an extensive utility
|
||||
framework offering features such as:
|
||||
|
||||
* [cutlass::half_t](tools/util/half.h) - a host-side half-precision type
|
||||
* Components for allocating and initializing [host-side and device-side tensors](tools/util/host_tensor.h) usable by CUTLASS
|
||||
* Reference implementations of [GEMM](tools/util/reference/host/gemm.h) and [element-wise operations](tools/util/reference/host/tensor_elementwise.h)
|
||||
2
Doxyfile
@ -58,7 +58,7 @@ PROJECT_LOGO =
|
||||
# entered, it will be relative to the location where doxygen was started. If
|
||||
# left blank the current directory will be used.
|
||||
|
||||
OUTPUT_DIRECTORY = docs
|
||||
OUTPUT_DIRECTORY = doxygen
|
||||
|
||||
# If the CREATE_SUBDIRS tag is set to YES, then doxygen will create 4096 sub-
|
||||
# directories (in 2 levels) under the output directory of each output format and
|
||||
|
||||
79
README.md
@ -1,10 +1,10 @@
|
||||

|
||||
|
||||
# CUTLASS 1.0
|
||||
# CUTLASS 1.1
|
||||
|
||||
_CUTLASS 1.0.1 - June 2018_
|
||||
_CUTLASS 1.1.0 - September 2018_
|
||||
|
||||
CUTLASS 1.0 is a collection of CUDA C++ template abstractions for implementing
|
||||
CUTLASS 1.1 is a collection of CUDA C++ template abstractions for implementing
|
||||
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
|
||||
It incorporates strategies for hierarchical decomposition and data movement similar
|
||||
to those used to implement cuBLAS. CUTLASS decomposes these "moving parts" into
|
||||
@ -22,14 +22,27 @@ point (FP64) types. Furthermore, CUTLASS demonstrates CUDA's WMMA API for targe
|
||||
the programmable, high-throughput _Tensor Cores_ provided by NVIDIA's Volta architecture
|
||||
and beyond.
|
||||
|
||||
CUTLASS 1.0 has changed substantially from our preview release described in
|
||||
the [CUTLASS Parallel For All](https://devblogs.nvidia.com/parallelforall/cutlass-linear-algebra-cuda)
|
||||
post. We have decomposed the structure of the GEMM computation into deeper, structured
|
||||
primitives for loading data, computing predicate masks, streaming data at each level of
|
||||
the GEMM hierarchy, and updating the output matrix.
|
||||
CUTLASS 1.1 is described in the [CUTLASS Documentation](CUTLASS.md) and the accompanying
|
||||
[Doxygen documentation](https://nvidia.github.io/cutlass).
|
||||
We describe the structure of an efficient GEMM in our talk at the
|
||||
[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
|
||||
|
||||
CUTLASS 1.0 is described in the [Doxygen documentation](https://nvidia.github.io/cutlass)
|
||||
and our talk at the [GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
|
||||
# What's New in CUTLASS 1.1
|
||||
|
||||
* [CUTLASS Documentation](CUTLASS.md)
|
||||
* [Examples](examples/)
|
||||
* Basic GEMM, tensor views, CUTLASS utilities, batched GEMM, WMMA GEMM
|
||||
* Turing Features
|
||||
* [WMMA GEMM targeting TensorCores](tools/test/unit/gemm/wmma_integer_gemm.cu) - INT8, INT4, INT1
|
||||
* [Batched Strided GEMM](tools/test/unit/gemm/batched_strided_sgemm_128x128x8.cu)
|
||||
* [Threadblock rasterization strategies](tools/test/unit/gemm/sgemm_threadblock_swizzle_nt.cu)
|
||||
* Improved performance for adverse problem sizes and data layouts
|
||||
* Extended CUTLASS Core comonents
|
||||
* Tensor views support arbitrary matrix and tensor layouts
|
||||
* Zip iterators for structuring multiple data streams
|
||||
* Enhanced CUTLASS utilities
|
||||
* [Reference implementations](tools/util/reference) for tensor operations in [host](tools/util/reference/host) and [device](tools/util/reference/device) code
|
||||
* Added `HostMatrix<>` for simplified matrix creation
|
||||
|
||||
# Performance
|
||||
|
||||
@ -39,11 +52,11 @@ CUTLASS primitives are very efficient. When used to construct device-wide GEMM
|
||||
they exhibit performance comparable to cuBLAS for scalar GEMM
|
||||
computations. The above figure shows CUTLASS performance relative to cuBLAS
|
||||
for large matrix dimensions (M=10240, N=K=4096) running on an NVIDIA Titan V GPU
|
||||
when compiled with CUDA 9.2.
|
||||
when compiled with CUDA 10.0.
|
||||
|
||||
# Compatibility
|
||||
|
||||
CUTLASS requires CUDA 9 and performs best with [CUDA 9.2 Toolkit](ttps://developer.nvidia.com/cuda-toolkit) or later.
|
||||
CUTLASS requires CUDA 9 but performs best with [CUDA 10.0 Toolkit](ttps://developer.nvidia.com/cuda-toolkit) or later.
|
||||
|
||||
|**Operating System** | **Compiler** |
|
||||
|-----------------|----------|
|
||||
@ -63,7 +76,7 @@ any Maxwell-, Pascal-, or Volta-architecture NVIDIA GPU.
|
||||
|NVIDIA Tesla P100|
|
||||
|NVIDIA Tesla V100|
|
||||
|NVIDIA TitanV|
|
||||
|
||||
|NVIDIA GeForce RTX 2080 TI, 2080, 2070|
|
||||
|
||||
# Building CUTLASS
|
||||
|
||||
@ -79,7 +92,7 @@ $ git submodule update --init --recursive
|
||||
```
|
||||
|
||||
CUTLASS can be build with CMake starting version 3.10. By default CUTLASS will build kernels
|
||||
for CUDA architecture versions 5.0, 6.0, 6.1 and 7.0. To reduce compile time you can specify
|
||||
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0 and 7.5. To reduce compile time you can specify
|
||||
the architectures to build CUTLASS for by changing the CMake configuration setting
|
||||
`CUTLASS_NVCC_ARCHS`.
|
||||
|
||||
@ -107,13 +120,12 @@ $ ./tools/test/unit/cutlass_unit_test
|
||||
...
|
||||
...
|
||||
[----------] Global test environment tear-down
|
||||
[==========] 481 tests from 24 test cases ran. (5954 ms total)
|
||||
[ PASSED ] 481 tests.
|
||||
[==========] 946 tests from 57 test cases ran. (10812 ms total)
|
||||
[ PASSED ] 946 tests.
|
||||
```
|
||||
|
||||
All tests should pass, though the exact number of tests may vary over time.
|
||||
|
||||
|
||||
# Project Structure
|
||||
|
||||
CUTLASS is arranged as a header-only library with several example test programs
|
||||
@ -128,28 +140,41 @@ templates in the cutlass/gemm directory.
|
||||
|
||||
```
|
||||
cutlass/
|
||||
gemm/
|
||||
util/
|
||||
<core API components>
|
||||
gemm/
|
||||
util/
|
||||
<core API components>
|
||||
```
|
||||
|
||||
Several tools and test programs are also distributed with the CUTLASS library. They are
|
||||
contained in the following directories.
|
||||
|
||||
```
|
||||
examples/
|
||||
00_basic_gemm/
|
||||
01_tensor_view/
|
||||
02_cutlass_utilities/
|
||||
03_batched_gemm/
|
||||
04_tile_iterator/
|
||||
05_wmma_gemm/
|
||||
tools/
|
||||
test/
|
||||
unit/
|
||||
core/
|
||||
gemm/
|
||||
perf/
|
||||
util/
|
||||
<utilities>
|
||||
test/
|
||||
unit/
|
||||
core/
|
||||
gemm/
|
||||
perf/
|
||||
util/
|
||||
reference/
|
||||
device/
|
||||
host/
|
||||
<utilities>
|
||||
```
|
||||
|
||||
The `test/unit/` directory consist of unit tests implemented with Google Test that demonstrate
|
||||
basic usage of Core API components and complete tests of the CUTLASS GEMM computations.
|
||||
|
||||
The `tools/util` directory contains CUTLASS utilities including reference implementations of GEMM and
|
||||
several element-wise tensor operations.
|
||||
|
||||
# Performance Profiling
|
||||
|
||||
The `test/perf/` directory contains a command-line utility for launching each of the GEMM kernels.
|
||||
|
||||
@ -1,17 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
function formatFiles {
|
||||
for f in `find "$1" -type f -name "*.$2"` ; do
|
||||
COMMAND="clang-format -i $f"
|
||||
echo $COMMAND
|
||||
$COMMAND
|
||||
done
|
||||
}
|
||||
|
||||
formatFiles "cutlass" "h"
|
||||
formatFiles "tools/test" "h"
|
||||
formatFiles "tools/test" "cpp"
|
||||
formatFiles "tools/util" "h"
|
||||
|
||||
@ -28,7 +28,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
#include "cutlass/fragment.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
|
||||
160
cutlass/coord.h
@ -28,7 +28,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/util/platform.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@ -44,20 +45,27 @@ struct Identity {
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Statically-sized array specifying Coords within a tensor
|
||||
template <int N_>
|
||||
template <int Rank_, typename Index_ = int>
|
||||
struct Coord {
|
||||
//
|
||||
// Type and constant definitions
|
||||
//
|
||||
|
||||
static int const N = N_;
|
||||
/// Number of elements in Coord
|
||||
static int const kRank = Rank_;
|
||||
|
||||
/// Number of elements in Coord, aliased for compatibility
|
||||
static int const N = Rank_;
|
||||
|
||||
/// Index type used to store elements
|
||||
typedef Index_ Index;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Indices
|
||||
int idx[N];
|
||||
Index idx[kRank];
|
||||
|
||||
//
|
||||
// Methods
|
||||
@ -65,25 +73,72 @@ struct Coord {
|
||||
|
||||
/// Default ctor initializes uniformly
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord(int value = 0) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
Coord(Index value = 0) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
idx[i] = value;
|
||||
}
|
||||
}
|
||||
|
||||
/// Constructs from an array of integers
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord(int _idx[]) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
Coord(Index _idx[]) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
idx[i] = _idx[i];
|
||||
}
|
||||
}
|
||||
|
||||
/// Constructs from an array of integers
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord(Coord<kRank> const &coord) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
idx[i] = coord[i];
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a slice of the Coord which may be larger or smaller in rank
|
||||
/// than this.
|
||||
template <int Slice>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<Slice> slice(int start = 0, Index identity = 0) const {
|
||||
Coord<Slice> result;
|
||||
for (int i = 0; i < Slice; ++i) {
|
||||
if (i + start < kRank) {
|
||||
slice[i] = idx[i + start];
|
||||
}
|
||||
else {
|
||||
slice[i] = identity;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Returns true if Coord is non-zero.
|
||||
CUTLASS_HOST_DEVICE
|
||||
operator bool() const {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
if (idx[i]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Returns true if Coord is uniformly zero.
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator!() const {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
if (idx[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Element-wise addition
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord operator+(Coord const& b) const {
|
||||
Coord c;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
c.idx[i] = idx[i] + b.idx[i];
|
||||
}
|
||||
return c;
|
||||
@ -93,7 +148,7 @@ struct Coord {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord operator-(Coord const& b) const {
|
||||
Coord c;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
c.idx[i] = idx[i] - b.idx[i];
|
||||
}
|
||||
return c;
|
||||
@ -103,7 +158,7 @@ struct Coord {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord operator*(Coord const& b) const {
|
||||
Coord c;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
c.idx[i] = idx[i] * b.idx[i];
|
||||
}
|
||||
return c;
|
||||
@ -113,7 +168,7 @@ struct Coord {
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord operator/(Coord const& b) const {
|
||||
Coord c;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
c.idx[i] = idx[i] / b.idx[i];
|
||||
}
|
||||
return c;
|
||||
@ -122,7 +177,7 @@ struct Coord {
|
||||
/// In-place addition
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& operator+=(Coord const& b) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
idx[i] += b.idx[i];
|
||||
}
|
||||
return *this;
|
||||
@ -131,7 +186,7 @@ struct Coord {
|
||||
/// In-place subtraction
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& operator-=(Coord const& b) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
idx[i] -= b.idx[i];
|
||||
}
|
||||
return *this;
|
||||
@ -140,7 +195,7 @@ struct Coord {
|
||||
/// In-place multiplication
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& operator*=(Coord const& b) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
idx[i] *= b.idx[i];
|
||||
}
|
||||
return *this;
|
||||
@ -149,22 +204,22 @@ struct Coord {
|
||||
/// In-place division
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& operator/=(Coord const& b) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
idx[i] /= b.idx[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Member access operator
|
||||
CUTLASS_HOST_DEVICE int& operator[](int dim) { return idx[dim]; }
|
||||
CUTLASS_HOST_DEVICE Index& operator[](int dim) { return idx[dim]; }
|
||||
|
||||
/// Member access operator
|
||||
CUTLASS_HOST_DEVICE int const& operator[](int dim) const { return idx[dim]; }
|
||||
CUTLASS_HOST_DEVICE Index const& operator[](int dim) const { return idx[dim]; }
|
||||
|
||||
/// Computes the dot product of two Coord instances
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE T dot(Coord const& b, T sum) const {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
sum += idx[i] * b.idx[i];
|
||||
}
|
||||
return sum;
|
||||
@ -174,7 +229,7 @@ struct Coord {
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE T dot(Coord const& b) const {
|
||||
T sum = T(0);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
sum += idx[i] * b.idx[i];
|
||||
}
|
||||
return sum;
|
||||
@ -182,29 +237,29 @@ struct Coord {
|
||||
|
||||
/// Gets the index of a given Coord element
|
||||
template <int Dim>
|
||||
CUTLASS_HOST_DEVICE int& at() {
|
||||
CUTLASS_HOST_DEVICE Index& at() {
|
||||
return idx[Dim];
|
||||
}
|
||||
|
||||
/// Access via index; may limit unrolling potential
|
||||
CUTLASS_HOST_DEVICE
|
||||
int& at(int dim) { return idx[dim]; }
|
||||
Index& at(int dim) { return idx[dim]; }
|
||||
|
||||
/// Gets the index of a given Coord element
|
||||
template <int Dim>
|
||||
CUTLASS_HOST_DEVICE int const& at() const {
|
||||
CUTLASS_HOST_DEVICE Index const& at() const {
|
||||
return idx[Dim];
|
||||
}
|
||||
|
||||
/// Access via index; may limit unrolling potential
|
||||
CUTLASS_HOST_DEVICE
|
||||
int const& at(int dim) const { return idx[dim]; }
|
||||
Index const& at(int dim) const { return idx[dim]; }
|
||||
|
||||
/// Determines if two Coord<> objects are equal
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator==(Coord<N> const& b) const {
|
||||
bool operator==(Coord<kRank> const& b) const {
|
||||
bool equal = true;
|
||||
for (int i = 0; equal && i < N; ++i) {
|
||||
for (int i = 0; equal && i < kRank; ++i) {
|
||||
equal = (idx[i] == b.idx[i]);
|
||||
}
|
||||
return equal;
|
||||
@ -212,12 +267,12 @@ struct Coord {
|
||||
|
||||
/// Not equal
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator!=(Coord<N> const& b) const { return !(*this == b); }
|
||||
bool operator!=(Coord<kRank> const& b) const { return !(*this == b); }
|
||||
|
||||
/// Clamps a coordinate to a range specified by maximum and minimum values
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord& clamp(Coord<N> const& max, Coord<N> const& min = Coord<N>()) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
Coord& clamp(Coord<kRank> const& max, Coord<kRank> const& min = Coord<kRank>()) {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
idx[i] = __NV_STD_MAX(__NV_STD_MIN(idx[i], max.idx[i]), min.idx[i]);
|
||||
}
|
||||
return *this;
|
||||
@ -225,13 +280,35 @@ struct Coord {
|
||||
|
||||
/// Returns the product of all elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
int count() const {
|
||||
int product = idx[0];
|
||||
for (int i = 1; i < N; ++i) {
|
||||
Index count() const {
|
||||
Index product = idx[0];
|
||||
for (int i = 1; i < kRank; ++i) {
|
||||
product *= idx[i];
|
||||
}
|
||||
return product;
|
||||
}
|
||||
|
||||
/// Less than operator
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator<(Coord<kRank> const &b) const {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
if (!(idx[i] < b[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Less than or equals operator
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator<=(Coord<kRank> const &b) const {
|
||||
for (int i = 0; i < kRank; ++i) {
|
||||
if (!(idx[i] <= b[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -266,21 +343,10 @@ Coord<4> make_Coord(int _0, int _1, int _2, int _3) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Getter
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> get_Coord_hw(Coord<3> const& coord) { return make_Coord(coord[1], coord[2]); }
|
||||
|
||||
/// Getter
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> get_Coord_hw(Coord<4> const& coord) { return make_Coord(coord[1], coord[2]); }
|
||||
|
||||
/// Getter
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<3> get_Coord_hwc(Coord<4> const& coord) { return make_Coord(coord[1], coord[2], coord[3]); }
|
||||
|
||||
/// Getter
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<3> get_Coord_dhw(Coord<4> const& coord) { return make_Coord(coord[0], coord[1], coord[2]); }
|
||||
template <typename Shape_>
|
||||
CUTLASS_HOST_DEVICE Coord<3> make_Coord_from_shape() {
|
||||
return make_Coord(Shape_::kD, Shape_::kH, Shape_::kW);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -22,8 +22,6 @@
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
/*! \file
|
||||
\brief Helpers for printing cutlass/core objects
|
||||
*/
|
||||
@ -33,12 +31,96 @@
|
||||
#include <iosfwd>
|
||||
#include <typeinfo>
|
||||
|
||||
#include <cutlass/coord.h>
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/vector.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int Rank>
|
||||
std::ostream& operator<<(std::ostream& out, cutlass::Coord<Rank> const& coord) {
|
||||
std::ostream& operator<<(std::ostream& out, Coord<Rank> const& coord) {
|
||||
for (int i = 0; i < Rank; ++i) {
|
||||
out << (i ? ", " : "") << coord.idx[i];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to enable formatted printing of CUTLASS scalar types to an ostream
|
||||
template <typename T>
|
||||
struct ScalarIO {
|
||||
|
||||
/// Value to print
|
||||
T value;
|
||||
|
||||
/// Default ctor
|
||||
ScalarIO() { }
|
||||
|
||||
/// Constructs from a value
|
||||
ScalarIO(T value): value(value) {}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Default printing to ostream
|
||||
template <typename T>
|
||||
inline std::ostream &operator<<(std::ostream &out, ScalarIO<T> const &scalar) {
|
||||
return out << scalar.value;
|
||||
}
|
||||
|
||||
/// Printing to ostream of int8_t as integer rather than character
|
||||
template <>
|
||||
inline std::ostream &operator<<(std::ostream &out, ScalarIO<int8_t> const &scalar) {
|
||||
return out << int(scalar.value);
|
||||
}
|
||||
|
||||
/// Printing to ostream of uint8_t as integer rather than character
|
||||
template <>
|
||||
inline std::ostream &operator<<(std::ostream &out, ScalarIO<uint8_t> const &scalar) {
|
||||
return out << unsigned(scalar.value);
|
||||
}
|
||||
|
||||
/// Printing to ostream of vector of 1b elements
|
||||
template <>
|
||||
inline std::ostream &operator<<(
|
||||
std::ostream &out,
|
||||
ScalarIO<cutlass::Vector<cutlass::bin1_t, 32> > const &scalar) {
|
||||
|
||||
for (int i = 0; i < 32; i++) {
|
||||
out << int(scalar.value[i]);
|
||||
out << ((i != 31) ? ", " : "");
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Printing to ostream of vector of 4b signed integer elements
|
||||
template <>
|
||||
inline std::ostream &operator<<(
|
||||
std::ostream &out,
|
||||
ScalarIO<cutlass::Vector<cutlass::int4_t, 8> > const &scalar) {
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
out << int(scalar.value[i]);
|
||||
out << ((i != 7) ? ", " : "");
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
/// Printing to ostream of vector of 4b unsigned integer elements
|
||||
template <>
|
||||
inline std::ostream &operator<<(
|
||||
std::ostream &out,
|
||||
ScalarIO<cutlass::Vector<cutlass::uint4_t, 8> > const &scalar) {
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
out << unsigned(scalar.value[i]);
|
||||
out << ((i != 7) ? ", " : "");
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
@ -47,7 +47,9 @@
|
||||
// CUTLASS_DEVICE is an error if not compiling device code
|
||||
#endif
|
||||
|
||||
// CUTLASS_PRAGMA_UNROLL inserts a CUTLASS_PRAGMA_UNROLL if supported by the compiler
|
||||
#define CUTLASS_ASSERT(x) assert(x)
|
||||
|
||||
// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler.
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#if defined(_MSC_VER)
|
||||
#define CUTLASS_PRAGMA_UNROLL __pragma("unroll")
|
||||
@ -61,7 +63,22 @@
|
||||
#define CUTLASS_PRAGMA_NO_UNROLL
|
||||
#endif
|
||||
|
||||
#define CUTLASS_ASSERT(x) assert(x)
|
||||
#define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL
|
||||
|
||||
// A small helper class to dump a type at compile time
|
||||
// Usage:: DumpType<Class>::Class
|
||||
template <typename T>
|
||||
struct DebugType {};
|
||||
|
||||
template <typename T>
|
||||
void DebugTypeFunc(T const& t) {
|
||||
T::t;
|
||||
}
|
||||
|
||||
// A small helper class to dump a compile time constant at compile time
|
||||
// Usage: DumpValue<Class::kConstant>::kConstant
|
||||
template <int Value>
|
||||
struct DebugValue {};
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
|
||||
@ -29,9 +29,9 @@
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <cutlass/shape.h>
|
||||
#include <cutlass/util/cutlass_math.h>
|
||||
#include <cutlass/vector.h>
|
||||
#include "cutlass/shape.h"
|
||||
#include "cutlass/util/cutlass_math.h"
|
||||
#include "cutlass/vector.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@ -72,7 +72,7 @@ provides access to element at (d, h, w, c)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int kAlignment_>
|
||||
template <int alignment>
|
||||
struct StorageType {
|
||||
typedef uint64_t Type;
|
||||
};
|
||||
@ -108,9 +108,11 @@ struct Fragment : public AlignedStruct<kAlignment_> {
|
||||
typedef Element_ Element;
|
||||
/// The number of elements.
|
||||
static int const kElements = kElements_;
|
||||
/// Alignment
|
||||
static int const kAlignment = kAlignment_;
|
||||
|
||||
/// Clear a fragment.
|
||||
CUTLASS_DEVICE void clear() {
|
||||
CUTLASS_HOST_DEVICE void clear() {
|
||||
// Avoid element-wise access for sub 32b element type
|
||||
if (kAlignment_ >= 8 && (kElements * sizeof(Element)) % 8 == 0) {
|
||||
uint64_t* ptr = reinterpret_cast<uint64_t*>(storage);
|
||||
@ -135,14 +137,10 @@ struct Fragment : public AlignedStruct<kAlignment_> {
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE Element& operator[](int i) {
|
||||
assert(i < kElements_);
|
||||
return reinterpret_cast<Element*>(storage)[i];
|
||||
}
|
||||
CUTLASS_HOST_DEVICE Element& operator[](int i) { return reinterpret_cast<Element*>(storage)[i]; }
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE Element const& operator[](int i) const {
|
||||
assert(i < kElements_);
|
||||
CUTLASS_HOST_DEVICE Element const& operator[](int i) const {
|
||||
return reinterpret_cast<Element const*>(storage)[i];
|
||||
}
|
||||
|
||||
@ -188,35 +186,35 @@ struct FragmentIterator {
|
||||
|
||||
/// Ctor.
|
||||
template <typename OtherFragment_>
|
||||
CUTLASS_DEVICE FragmentIterator(OtherFragment_& fragment, int offset = 0)
|
||||
CUTLASS_HOST_DEVICE FragmentIterator(OtherFragment_& fragment, int offset = 0)
|
||||
: pointer(reinterpret_cast<Element*>(&fragment[offset])) {
|
||||
static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
|
||||
CUTLASS_HOST_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
|
||||
int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
|
||||
return reinterpret_cast<AccessType const&>(pointer[imm]);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType& at(int d, int h, int w, int c = 0) {
|
||||
CUTLASS_HOST_DEVICE AccessType& at(int d, int h, int w, int c = 0) {
|
||||
int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
|
||||
return reinterpret_cast<AccessType&>(pointer[imm]);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType const& operator[](int i) const {
|
||||
CUTLASS_HOST_DEVICE AccessType const& operator[](int i) const {
|
||||
return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType& operator[](int i) {
|
||||
CUTLASS_HOST_DEVICE AccessType& operator[](int i) {
|
||||
return reinterpret_cast<AccessType&>(pointer[i * kElementsPerAccess]);
|
||||
}
|
||||
|
||||
/// Is the iterator valid?
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
|
||||
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
|
||||
|
||||
/// The pointer.
|
||||
Element* pointer;
|
||||
@ -246,28 +244,28 @@ struct FragmentConstIterator {
|
||||
|
||||
/// Ctor.
|
||||
template <typename OtherFragment_>
|
||||
CUTLASS_DEVICE FragmentConstIterator(OtherFragment_& fragment, int offset = 0)
|
||||
CUTLASS_HOST_DEVICE FragmentConstIterator(OtherFragment_& fragment, int offset = 0)
|
||||
: pointer(reinterpret_cast<Element const*>(&fragment[offset])) {
|
||||
static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
|
||||
}
|
||||
/// Create from non-constant FragmentIterator
|
||||
CUTLASS_DEVICE FragmentConstIterator(
|
||||
CUTLASS_HOST_DEVICE FragmentConstIterator(
|
||||
FragmentIterator<Fragment_, Iterations_, AccessType_> const& rhs_)
|
||||
: pointer(reinterpret_cast<Element const*>(rhs_.offset)) {}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
|
||||
CUTLASS_HOST_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
|
||||
int const imm = ComputeOffsetFromStrides<IterationsStrides>::get(d, h, w, c);
|
||||
return reinterpret_cast<AccessType const&>(pointer[imm]);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE AccessType const& operator[](int i) const {
|
||||
CUTLASS_HOST_DEVICE AccessType const& operator[](int i) const {
|
||||
return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
|
||||
}
|
||||
|
||||
/// Is the iterator valid?
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
|
||||
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
|
||||
|
||||
/// The pointer.
|
||||
Element const* pointer;
|
||||
|
||||
@ -1,135 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines accessors for loading and storing fragments to memory efficiently.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/load_store.h>
|
||||
#include <cutlass/vector.h>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <IteratorFragment::Kind kIteratorFragment,
|
||||
int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentLoad {};
|
||||
|
||||
template <int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentLoad<IteratorFragment::kWmmaMatrix,
|
||||
kAccessSize,
|
||||
Scalar_,
|
||||
Memory_,
|
||||
FragmentElement_,
|
||||
kStride> {
|
||||
/// The output type.
|
||||
typedef FragmentElement_ AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) {
|
||||
value.load(&pointer[offset], kStride);
|
||||
}
|
||||
};
|
||||
|
||||
template <int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentLoad<IteratorFragment::kScalar,
|
||||
kAccessSize,
|
||||
Scalar_,
|
||||
Memory_,
|
||||
FragmentElement_,
|
||||
kStride> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) {
|
||||
Load<Scalar_, kAccessSize, Memory_>::load(value, pointer, offset);
|
||||
}
|
||||
};
|
||||
|
||||
template <IteratorFragment::Kind kIteratorFragment,
|
||||
int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentStore {};
|
||||
|
||||
template <int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentStore<IteratorFragment::kWmmaMatrix,
|
||||
kAccessSize,
|
||||
Scalar_,
|
||||
Memory_,
|
||||
FragmentElement_,
|
||||
kStride> {
|
||||
/// The input type.
|
||||
typedef FragmentElement_ AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) {
|
||||
value.store(&pointer[offset], kStride);
|
||||
}
|
||||
};
|
||||
|
||||
template <int kAccessSize,
|
||||
typename Scalar_,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride>
|
||||
struct FragmentStore<IteratorFragment::kScalar,
|
||||
kAccessSize,
|
||||
Scalar_,
|
||||
Memory_,
|
||||
FragmentElement_,
|
||||
kStride> {
|
||||
/// The input type.
|
||||
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) {
|
||||
Store<Scalar_, kAccessSize, Memory_>::store(value, pointer, offset);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} /// namespace cutlass
|
||||
@ -27,52 +27,59 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
#include "cutlass/fragment.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_>
|
||||
template < typename ScalarAlphaBeta_,
|
||||
typename ScalarAccum_,
|
||||
bool fragMul2 = true /*number of element per fragment is multiple of 2*/
|
||||
>
|
||||
struct FragmentMultiplyAdd {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<1, 1, 1, 1> InstructionShape;
|
||||
/// The type for A.
|
||||
typedef Scalar_ ScalarA;
|
||||
/// The type for B.
|
||||
typedef Scalar_ ScalarB;
|
||||
/// The type for C and D.
|
||||
typedef Scalar_ ScalarC;
|
||||
/// The type for alpha and beta
|
||||
typedef ScalarAlphaBeta_ ScalarAlphaBeta;
|
||||
/// The type for accumlator
|
||||
typedef ScalarAccum_ ScalarAccum;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE FragmentMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b.
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply(Scalar_ a, FragmentB_ const& b, FragmentCd_& d) {
|
||||
CUTLASS_DEVICE void multiply(ScalarAlphaBeta a, FragmentB_ const& b, FragmentCd_& d) {
|
||||
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
|
||||
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
|
||||
for (int j = 0; j < FragmentCd_::kElements; ++j) {
|
||||
d[j] = a * b[j * kReduction + 0];
|
||||
d[j] = b[j * kReduction + 0];
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d[j] += a * b[j * kReduction + k];
|
||||
d[j] += b[j * kReduction + k];
|
||||
}
|
||||
d[j] = a * ScalarAlphaBeta(d[j]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Multiply : d = a*b + c.
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply_add(Scalar_ a,
|
||||
CUTLASS_DEVICE void multiply_add(ScalarAlphaBeta a,
|
||||
FragmentB_ const& b,
|
||||
FragmentCd_ const& c,
|
||||
FragmentCd_& d) {
|
||||
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
|
||||
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
|
||||
for (int j = 0; j < FragmentCd_::kElements; ++j) {
|
||||
d[j] = a * b[j * kReduction + 0] + c[j];
|
||||
d[j] = b[j * kReduction + 0];
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d[j] += a * b[j * kReduction + k];
|
||||
d[j] += b[j * kReduction + k];
|
||||
}
|
||||
d[j] = a * ScalarAlphaBeta(d[j]) + ScalarAlphaBeta(c[j]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@ -80,15 +87,13 @@ struct FragmentMultiplyAdd {
|
||||
|
||||
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
|
||||
template <>
|
||||
struct FragmentMultiplyAdd<half> {
|
||||
struct FragmentMultiplyAdd<half, half, true> {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<1, 1, 2, 1> InstructionShape;
|
||||
/// The type for A.
|
||||
typedef half ScalarA;
|
||||
/// The type for B.
|
||||
typedef half ScalarB;
|
||||
/// The type for C and D.
|
||||
typedef half ScalarC;
|
||||
typedef Shape<1, 1, 1, 1> InstructionShape;
|
||||
/// The type for alpha and beta
|
||||
typedef half ScalarAlphaBeta;
|
||||
/// The type for accumlator
|
||||
typedef half ScalarAccum;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE FragmentMultiplyAdd() {}
|
||||
@ -97,17 +102,19 @@ struct FragmentMultiplyAdd<half> {
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply(half a, FragmentB_ const& b, FragmentCd_& d) {
|
||||
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
|
||||
|
||||
// Assemble a half2 from a.
|
||||
__half2 const a_half2 = __half2half2(a);
|
||||
// The input.
|
||||
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
|
||||
// The output.
|
||||
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
|
||||
|
||||
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
|
||||
// Assemble a half2 from a.
|
||||
__half2 const a_half2 = __half2half2(a);
|
||||
|
||||
int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements);
|
||||
|
||||
for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
|
||||
d_half2[j] = __hmul2(a_half2, b_half2[j * kReduction + 0]);
|
||||
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
|
||||
}
|
||||
@ -115,6 +122,7 @@ struct FragmentMultiplyAdd<half> {
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
/// Multiply : d = a*b + c.
|
||||
template <typename FragmentB_, typename FragmentCd_>
|
||||
CUTLASS_DEVICE void multiply_add(half a,
|
||||
@ -122,17 +130,19 @@ struct FragmentMultiplyAdd<half> {
|
||||
FragmentCd_ const& c,
|
||||
FragmentCd_& d) {
|
||||
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
|
||||
// Assemble a half2 from a.
|
||||
__half2 const a_half2 = __half2half2(a);
|
||||
// The inputs.
|
||||
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
|
||||
__half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);
|
||||
// The output.
|
||||
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
|
||||
|
||||
// Assemble a half2 from a.
|
||||
__half2 const a_half2 = __half2half2(a);
|
||||
|
||||
int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements);
|
||||
for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
|
||||
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + 0], c_half2[j]);
|
||||
|
||||
for (int k = 1; k < kReduction; ++k) {
|
||||
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
|
||||
}
|
||||
|
||||
@ -27,7 +27,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/vector.h>
|
||||
#include "cutlass/vector.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -39,11 +39,12 @@ struct ClearAccumulators {
|
||||
/// The shared storage.
|
||||
struct SharedStorage {};
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ClearAccumulators() {}
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ClearAccumulators(SharedStorage& shared_storage) {}
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ClearAccumulators() {}
|
||||
|
||||
/// Clear the fragment.
|
||||
template <typename Fragment_>
|
||||
CUTLASS_DEVICE void clear(Fragment_& fragment) {
|
||||
|
||||
@ -27,13 +27,13 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/gemm_epilogue.h>
|
||||
#include <cutlass/gemm/gemm_epilogue_traits.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/gemm/gemm_shared_tile.h>
|
||||
#include <cutlass/gemm/gemm_traits.h>
|
||||
#include <cutlass/gemm/thread_multiply_add.h>
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/gemm_epilogue.h"
|
||||
#include "cutlass/gemm/gemm_epilogue_traits.h"
|
||||
#include "cutlass/gemm/gemm_global_tile.h"
|
||||
#include "cutlass/gemm/gemm_shared_tile.h"
|
||||
#include "cutlass/gemm/gemm_traits.h"
|
||||
#include "cutlass/gemm/thread_multiply_add.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -41,10 +41,10 @@ namespace gemm {
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
/// The tile size for threadblock-level GEMM (K-by-N-by-M).
|
||||
typename OutputTile_,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of scalars per LDG for B.
|
||||
@ -62,7 +62,7 @@ struct DgemmConfig
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, double, double, double>,
|
||||
ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, double, double, double>,
|
||||
/// The number of scalars per LDG for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
@ -82,7 +82,14 @@ struct DgemmConfig
|
||||
/// The number of scalars per LDS for D.
|
||||
1,
|
||||
/// The number of stages in shared memory.
|
||||
2> {};
|
||||
2,
|
||||
/// kResidueSeparate
|
||||
false,
|
||||
/// kResidueInPrologue
|
||||
false,
|
||||
/// kLaunchBounds
|
||||
false
|
||||
>{};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -91,12 +98,12 @@ template <
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
/// The tile size for threadblock-level GEMM (K-by-N-by-M)
|
||||
typename OutputTile_ = Shape<8, 64, 128>,
|
||||
/// The functor to use in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<double>,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_ = Shape<8, 8, 8>,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_ = Shape<8, 8, 8>,
|
||||
/// The number of doubles loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of doubles loaded in one LDG for B.
|
||||
@ -105,7 +112,7 @@ template <
|
||||
typename Index_ = int,
|
||||
/// The DGEMM config.
|
||||
typename GemmConfig_ =
|
||||
DgemmConfig<OutputTile_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_>,
|
||||
DgemmConfig<OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_>,
|
||||
/// The traits class for the epilogue.
|
||||
typename GemmEpilogueTraits_ =
|
||||
SimplifiedGemmEpilogueTraits<GemmConfig_, EpilogueFunctor_, Index_> >
|
||||
|
||||
83
cutlass/gemm/fp16_sgemm_multiply_add.h
Normal file
@ -0,0 +1,83 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template implementing matrix multiply-add operations on fragments.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/fragment.h"
|
||||
#include "cutlass/gemm/thread_multiply_add.h"
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Template performing matrix multiply-add operation within a thread
|
||||
template <typename ThreadGemmShape_,
|
||||
typename ThreadsPerWarp_>
|
||||
struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, half, half, float> {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<1, 1, 1, 1> InstructionShape;
|
||||
/// The shape of a thread-leveel matrix multiply accumulate.
|
||||
typedef ThreadGemmShape_ ThreadGemmShape;
|
||||
/// Aliased to "AccumulatorsPerThread" for compatibility. Expect to be renamed in CUTLASS v2.0
|
||||
typedef ThreadGemmShape AccumulatorsPerThread;
|
||||
/// The number of threads per warp.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of accumulators per warp.
|
||||
typedef typename ShapeMul<ThreadGemmShape, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
|
||||
/// The type for A. specialized to half
|
||||
typedef half ScalarA;
|
||||
/// The fragment for A.
|
||||
typedef Fragment<ScalarA, AccumulatorsPerThread::kW> FragmentA;
|
||||
/// The type for B. specialized to half
|
||||
typedef half ScalarB;
|
||||
/// The fragment for B.
|
||||
typedef Fragment<ScalarB, AccumulatorsPerThread::kH> FragmentB;
|
||||
/// The type for C and D. specialized to float
|
||||
typedef float ScalarC;
|
||||
/// The accumulators.
|
||||
typedef Fragment<ScalarC, AccumulatorsPerThread::kH * AccumulatorsPerThread::kW, 16> Accumulators;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ThreadMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b + c.
|
||||
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
|
||||
FragmentB const& b,
|
||||
Accumulators const& c,
|
||||
Accumulators& d) {
|
||||
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
|
||||
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
|
||||
d[j * AccumulatorsPerThread::kW + i] = static_cast<ScalarC>(a[i]) * static_cast<ScalarC>(b[j]) + c[j * AccumulatorsPerThread::kW + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
152
cutlass/gemm/fp16_sgemm_traits.h
Normal file
@ -0,0 +1,152 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defies structural properties of single-precision GEMM where any number of the input/output
|
||||
could be fp16 or fp32. The accumulator type stays in fp32
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/gemm_epilogue.h"
|
||||
#include "cutlass/gemm/gemm_epilogue_traits.h"
|
||||
#include "cutlass/gemm/gemm_global_tile.h"
|
||||
#include "cutlass/gemm/gemm_shared_tile.h"
|
||||
#include "cutlass/gemm/gemm_traits.h"
|
||||
#include "cutlass/gemm/fp16_sgemm_multiply_add.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_,
|
||||
/// The type for A
|
||||
typename ScalarA_,
|
||||
/// The type for B
|
||||
typename ScalarB_,
|
||||
/// The type for C
|
||||
typename ScalarC_,
|
||||
/// The type for D
|
||||
typename ScalarD_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_ = 1>
|
||||
struct Fp16SgemmConfig : public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
ScalarA_,
|
||||
/// The scalar type for B.
|
||||
ScalarB_,
|
||||
/// The scalar type for C.
|
||||
ScalarC_,
|
||||
/// The scalar type for D.
|
||||
ScalarD_,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, ScalarA_, ScalarB_, float /*for sgemm accum is float*/>,
|
||||
/// The number of scalars per LDG for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDS for A.
|
||||
4,
|
||||
/// The number of scalars per LDG for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
4,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
1,
|
||||
/// The number of scalars per STS for D.
|
||||
4,
|
||||
/// The number of scalars per LDS for D.
|
||||
1,
|
||||
/// The number of stages in shared memory.
|
||||
2> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
typename OutputTile_ = Shape<8, 128, 128>,
|
||||
/// The type for A
|
||||
typename ScalarA_ = half,
|
||||
/// The type for B
|
||||
typename ScalarB_ = half,
|
||||
/// The type for C
|
||||
typename ScalarC_ = half,
|
||||
/// The type for D
|
||||
typename ScalarD_ = half,
|
||||
/// the Type for alpha and beta,
|
||||
typename Scalar_ = half,
|
||||
/// The functor to use in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<Scalar_, FragmentMultiplyAdd<Scalar_, float/*accumulator type*/> >,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_ = Shape<8, 8, 8>,
|
||||
/// The number of floats loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of floats loaded in one LDG for B.
|
||||
int kScalarsPerLdgB_ = 1,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The SGEMM config.
|
||||
typename GemmConfig_ =
|
||||
Fp16SgemmConfig<OutputTile_,
|
||||
ThreadGemmShape_,
|
||||
ScalarA_,
|
||||
ScalarB_,
|
||||
ScalarC_,
|
||||
ScalarD_,
|
||||
kScalarsPerLdgA_,
|
||||
kScalarsPerLdgB_>,
|
||||
/// The traits class for the epilogue.
|
||||
typename GemmEpilogueTraits_ =
|
||||
SimplifiedGemmEpilogueTraits<GemmConfig_, EpilogueFunctor_, Index_> >
|
||||
struct Fp16SgemmSgemmTraits : public SimplifiedGemmTraits<
|
||||
// The layout for A.
|
||||
kLayoutA_,
|
||||
// The layout for B.
|
||||
kLayoutB_,
|
||||
// The config.
|
||||
GemmConfig_,
|
||||
// The epilogue.
|
||||
GemmEpilogue<GemmEpilogueTraits_>,
|
||||
// The index.
|
||||
Index_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -31,16 +31,17 @@
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/util/platform.h>
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/util/platform.h"
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GEMM kernel with launch bounds specified
|
||||
template <typename Gemm_>
|
||||
__global__ /*__launch_bounds__(Gemm_::kThreads)*/ void gemm_kernel(typename Gemm_::Params params) {
|
||||
__global__ __launch_bounds__(Gemm_::kThreads)
|
||||
void gemm_kernel(typename Gemm_::Params params) {
|
||||
// Declare shared memory.
|
||||
__shared__ typename Gemm_::SharedStorage shared_storage;
|
||||
|
||||
@ -52,28 +53,37 @@ __global__ /*__launch_bounds__(Gemm_::kThreads)*/ void gemm_kernel(typename Gemm
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, typename Index_ = int>
|
||||
struct GemmDesc {
|
||||
/// The dimensions of the GEMM.
|
||||
Index_ m, n, k;
|
||||
/// The alpha/beta scaling values.
|
||||
Scalar_ alpha, beta;
|
||||
/// The source matrix A.
|
||||
void const* d_a;
|
||||
/// The stride for A.
|
||||
Index_ lda;
|
||||
/// The source matrix B.
|
||||
void const* d_b;
|
||||
/// The stride for B.
|
||||
Index_ ldb;
|
||||
/// The source matrix C.
|
||||
void const* d_c;
|
||||
/// The stride for C.
|
||||
Index_ ldc;
|
||||
/// The destination matrix D.
|
||||
void* d_d;
|
||||
/// The stride for D.
|
||||
Index_ ldd;
|
||||
/// GEMM kernel without launch bounds specified
|
||||
template <typename Gemm_>
|
||||
__global__ /* __launch_bounds__(Gemm_::kThreads) */
|
||||
void gemm_kernel_nolb(typename Gemm_::Params params) {
|
||||
// Declare shared memory.
|
||||
__shared__ typename Gemm_::SharedStorage shared_storage;
|
||||
|
||||
// Construct the GEMM object.
|
||||
Gemm_ gemm(params, shared_storage);
|
||||
// Run GEMM.
|
||||
gemm.multiply_add();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for launching the GEMM kernel with or without launch bounds
|
||||
template <typename Gemm, bool WithLaunchBounds>
|
||||
struct Launch {
|
||||
Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
|
||||
gemm_kernel<Gemm><<< grid, block, 0, stream >>>(params);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for launching the GEMM kernel with or without launch bounds
|
||||
template <typename Gemm>
|
||||
struct Launch<Gemm, false> {
|
||||
Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
|
||||
gemm_kernel_nolb<Gemm><<< grid, block, 0, stream >>>(params);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -100,86 +110,52 @@ struct Gemm {
|
||||
/// The index.
|
||||
typedef typename Traits::Index Index;
|
||||
|
||||
/// Define the mainloop iteration size
|
||||
typedef typename Traits::MultiplyAdd MultiplyAdd;
|
||||
|
||||
/// The number of threads.
|
||||
static int const kThreads = Traits::GemmConfig::kThreads;
|
||||
|
||||
/// The params.
|
||||
struct Params : public Traits::Params {
|
||||
CUTLASS_HOST_DEVICE int initialize(Index m,
|
||||
Index n,
|
||||
Index k,
|
||||
ScalarEpilogue alpha,
|
||||
ScalarA const* d_a,
|
||||
Index lda,
|
||||
ScalarB const* d_b,
|
||||
Index ldb,
|
||||
ScalarEpilogue beta,
|
||||
ScalarC const* d_c,
|
||||
Index ldc,
|
||||
ScalarD* d_d,
|
||||
Index ldd) {
|
||||
GemmDesc<ScalarEpilogue, Index> desc;
|
||||
desc.m = m;
|
||||
desc.n = n;
|
||||
desc.k = k;
|
||||
desc.alpha = alpha;
|
||||
desc.beta = beta;
|
||||
desc.d_a = reinterpret_cast<void const*>(d_a);
|
||||
desc.lda = lda;
|
||||
desc.d_b = reinterpret_cast<void const*>(d_b);
|
||||
desc.ldb = ldb;
|
||||
desc.d_c = reinterpret_cast<void const*>(d_c);
|
||||
desc.ldc = ldc;
|
||||
desc.d_d = reinterpret_cast<void*>(d_d);
|
||||
desc.ldd = ldd;
|
||||
return Traits::Params::initialize(desc);
|
||||
}
|
||||
};
|
||||
// Number of warp-level multiply-accumulate steps executed by each warp.
|
||||
static Index const kWarpGemmSteps =
|
||||
Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;
|
||||
|
||||
// Make sure we have at least 2 unrolling steps or our pipeling is not going to work.
|
||||
static_assert(kWarpGemmSteps >= 2, "The pipelining assumes at least two steps");
|
||||
|
||||
/// Use the params object defined in traits
|
||||
typedef typename Traits::Params Params;
|
||||
|
||||
//
|
||||
// Static function members
|
||||
//
|
||||
|
||||
/// Support for NVRTC
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
/// Launch the kernel.
|
||||
static __host__ cudaError_t launch(Params const& params,
|
||||
cudaStream_t stream = cudaStreamDefault) {
|
||||
// Setup the grid.
|
||||
dim3 grid;
|
||||
grid.x = (params.m + Traits::OutputTile::kW - 1) / Traits::OutputTile::kW;
|
||||
grid.y = (params.n + Traits::OutputTile::kH - 1) / Traits::OutputTile::kH;
|
||||
|
||||
// The number of threads.
|
||||
dim3 block;
|
||||
block.x = kThreads;
|
||||
|
||||
// Launch the kernel.
|
||||
void const* params_ = reinterpret_cast<void const*>(¶ms);
|
||||
Launch<This_, GemmTraits_::GemmConfig::kLaunchBounds>(
|
||||
params, params.grid, params.block, stream);
|
||||
|
||||
return cudaLaunchKernel(reinterpret_cast<void*>(&gemm_kernel<This_>),
|
||||
grid,
|
||||
block,
|
||||
const_cast<void**>(¶ms_),
|
||||
0,
|
||||
stream);
|
||||
return cudaGetLastError();
|
||||
}
|
||||
|
||||
/// Launch the kernel.
|
||||
static __host__ cudaError_t launch(CUfunction kernel,
|
||||
Params const& params,
|
||||
CUstream stream = CU_STREAM_LEGACY) {
|
||||
// Setup the grid.
|
||||
dim3 grid;
|
||||
grid.x = (params.m + Traits::OutputTile::kW - 1) / Traits::OutputTile::kW;
|
||||
grid.y = (params.n + Traits::OutputTile::kH - 1) / Traits::OutputTile::kH;
|
||||
|
||||
// The number of threads.
|
||||
dim3 block;
|
||||
block.x = kThreads;
|
||||
|
||||
// Launch the kernel.
|
||||
void* params_[] = {const_cast<void*>(reinterpret_cast<void const*>(¶ms))};
|
||||
|
||||
// return cudaLaunchKernel(reinterpret_cast<void*>(&gemm_kernel<This_>), grid, block,
|
||||
// const_cast<void**>(¶ms_), 0, stream);
|
||||
CUresult result = cuLaunchKernel(
|
||||
kernel, grid.x, grid.y, grid.z, block.x, block.y, block.z, 0, stream, params_, 0);
|
||||
kernel,
|
||||
params.grid.x, params.grid.y, params.grid.z,
|
||||
params.block.x, params.block.y, params.block.z,
|
||||
0, stream, params_, 0);
|
||||
|
||||
if (result != CUDA_SUCCESS) {
|
||||
return cudaErrorLaunchFailure;
|
||||
@ -189,39 +165,41 @@ struct Gemm {
|
||||
|
||||
#endif
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE Gemm(Params const& params_, SharedStorage& shared_storage_)
|
||||
: params(params_), shared_storage(shared_storage_) {}
|
||||
|
||||
/// Consume a single iteration of the loop.
|
||||
template <bool kIsLastIteration>
|
||||
CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_stream,
|
||||
typename Traits::SharedLoadStream& shared_load_stream,
|
||||
typename Traits::MultiplyAdd::Accumulators& accumulators,
|
||||
/// Computes a warp-level GEMM on data held in shared memory
|
||||
template <bool Residue, bool LastIteration>
|
||||
CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream,
|
||||
typename Traits::SharedStream& shared_load_stream,
|
||||
typename MultiplyAdd::Accumulators& accumulators,
|
||||
Index outer_k) {
|
||||
// If that's the last "load iteration" update the predicates.
|
||||
if (!kIsLastIteration) {
|
||||
global_stream.move_to_residue<false>(outer_k);
|
||||
// If residue portion and not calculating residue in prolog, update residue predicates now.
|
||||
if (Residue && outer_k <= Traits::OutputTile::kD) {
|
||||
global_to_shared_stream.residue(outer_k);
|
||||
}
|
||||
|
||||
// Load data for the next iteration of the main loop.
|
||||
if (!kIsLastIteration) {
|
||||
global_stream.copy();
|
||||
// Load data for the next iteration of the main loop (unless it's the last iteration).
|
||||
if (!LastIteration) {
|
||||
global_to_shared_stream.copy();
|
||||
}
|
||||
|
||||
// The unrolling steps for the main loop.
|
||||
int const kUnrollingSteps =
|
||||
Traits::MultiplyAdd::AccumulatorsPerWarp::kD / Traits::MultiplyAdd::InstructionShape::kD;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int step = 0; step < kUnrollingSteps - 1; ++step) {
|
||||
for (int step = 0; step < kWarpGemmSteps - 1; ++step) {
|
||||
// Trigger the copy from shared memory for the next A/B values.
|
||||
shared_load_stream.copy(step + 1);
|
||||
|
||||
// Make sure the values are available for the current iteration to do the multiply-add.
|
||||
shared_load_stream.commit(step);
|
||||
|
||||
MultiplyAdd multiply_add;
|
||||
|
||||
// Do the math on the fragments of the current iteration.
|
||||
typename Traits::MultiplyAdd multiply_add;
|
||||
multiply_add.multiply_add(shared_load_stream.fragment_a(step),
|
||||
shared_load_stream.fragment_b(step),
|
||||
accumulators,
|
||||
@ -232,28 +210,25 @@ struct Gemm {
|
||||
Traits::shared_load_fence(true);
|
||||
|
||||
// Commit the data in shared memory for A/B.
|
||||
if (!kIsLastIteration) {
|
||||
global_stream.commit();
|
||||
if (!LastIteration) {
|
||||
global_to_shared_stream.commit();
|
||||
}
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
Traits::shared_store_fence(true);
|
||||
|
||||
// Trigger the loads for the next iteration (if needed).
|
||||
if (!kIsLastIteration) {
|
||||
if (!LastIteration) {
|
||||
// Move to the next stage for the load (if it makes sense).
|
||||
shared_load_stream.inc_stage();
|
||||
// Trigger the copy from shared memory for the next loop iteration.
|
||||
shared_load_stream.copy(0);
|
||||
}
|
||||
|
||||
// Make sure the values are available for the current iteration to do the multiply-add.
|
||||
shared_load_stream.commit(kUnrollingSteps - 1);
|
||||
shared_load_stream.commit(kWarpGemmSteps - 1);
|
||||
|
||||
// Do the math on the fragments of the current iteration.
|
||||
typename Traits::MultiplyAdd multiply_add;
|
||||
multiply_add.multiply_add(shared_load_stream.fragment_a(kUnrollingSteps - 1),
|
||||
shared_load_stream.fragment_b(kUnrollingSteps - 1),
|
||||
MultiplyAdd multiply_add;
|
||||
multiply_add.multiply_add(shared_load_stream.fragment_a(kWarpGemmSteps - 1),
|
||||
shared_load_stream.fragment_b(kWarpGemmSteps - 1),
|
||||
accumulators,
|
||||
accumulators);
|
||||
}
|
||||
@ -262,76 +237,112 @@ struct Gemm {
|
||||
CUTLASS_DEVICE void multiply_add() {
|
||||
// Swizzle the IDs of the block (to enable better cache behavior).
|
||||
typename Traits::BlockSwizzle block_swizzle;
|
||||
dim3 block = block_swizzle.swizzle();
|
||||
|
||||
// Scale the id.
|
||||
block.x *= Traits::OutputTile::kW;
|
||||
block.y *= Traits::OutputTile::kH;
|
||||
Coord<3> threadblock_offset =
|
||||
block_swizzle.get_threadblock_offset(make_Coord_from_shape<Traits::OutputTile>());
|
||||
|
||||
// We may want to use shared memory to clear the registers.
|
||||
typedef typename Traits::ClearAccumulators ClearAccumulators;
|
||||
|
||||
// The streams to read A/B from global memory to shared memory.
|
||||
typename Traits::GlobalLoadStream global_stream(params, shared_storage, block);
|
||||
typename Traits::GlobalLoadStream global_to_shared_stream(
|
||||
params.global_to_shared_stream,
|
||||
shared_storage.main_loop.global_to_shared_stream,
|
||||
shared_storage.main_loop.threadblock_tile.reference(),
|
||||
params.problem_size.knm(),
|
||||
threadblock_offset);
|
||||
|
||||
// update A and B pointer offset based on batch_id and batch_stride_offset
|
||||
//global_to_shared_stream.add_pointer_offset(block_swizzle.get_batch_id(), params.batch_stride_A, params.batch_stride_B);
|
||||
global_to_shared_stream += make_Coord(block_swizzle.get_batch_id(), 0, 0);
|
||||
|
||||
// Create the accumulator clear.
|
||||
ClearAccumulators clear(shared_storage.main_loop.clear);
|
||||
ClearAccumulators clear;
|
||||
|
||||
// By how much we unroll the main loop.
|
||||
Index const kUnroll = static_cast<Index>(Traits::OutputTile::kD);
|
||||
|
||||
// If we do not have enough steps in the main loop, trigger the residue code.
|
||||
global_stream.move_to_residue<true>(params.k);
|
||||
// Deal with residue in prolog.
|
||||
global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD);
|
||||
|
||||
// Fetch the fragments for A and B from global memory.
|
||||
global_stream.copy();
|
||||
global_to_shared_stream.copy();
|
||||
|
||||
// Copy the elements to shared memory (after transformation if needed).
|
||||
global_stream.commit();
|
||||
global_to_shared_stream.commit();
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
Traits::shared_store_fence(false);
|
||||
|
||||
// Rollback to the beginning of the GEMM-K dimension. It may have no impact.
|
||||
global_stream.rollback();
|
||||
|
||||
// The unrolling steps for the main loop.
|
||||
int const kUnrollingSteps =
|
||||
Traits::MultiplyAdd::AccumulatorsPerWarp::kD / Traits::MultiplyAdd::InstructionShape::kD;
|
||||
|
||||
// Make sure we have at least 2 unrolling steps or our pipeling is not going to work.
|
||||
static_assert(kUnrollingSteps >= 2, "The pipelining assumes at least two steps");
|
||||
// Rollback to the beginning of the first tile (if residue exists).
|
||||
global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD);
|
||||
|
||||
// The stream of data from shared memory to fragments.
|
||||
typename Traits::SharedLoadStream shared_load_stream(params, shared_storage);
|
||||
typename Traits::SharedStream shared_load_stream(
|
||||
params.shared_stream,
|
||||
shared_storage.main_loop.threadblock_tile.reference());
|
||||
|
||||
// Trigger the copy from shared memory for the 1st stream.
|
||||
shared_load_stream.copy(0);
|
||||
|
||||
// Allocate the accumulators.
|
||||
typename Traits::MultiplyAdd::Accumulators accumulators;
|
||||
typename MultiplyAdd::Accumulators accumulators;
|
||||
|
||||
// Clear the accumulators.
|
||||
clear.clear(accumulators);
|
||||
|
||||
// The loop index.
|
||||
Index outer_k = params.k - kUnroll;
|
||||
// Initial index
|
||||
Index outer_k = params.problem_size[0] - Traits::OutputTile::kD;
|
||||
|
||||
// Enter the main loop and iterate.
|
||||
for (; outer_k > 0; outer_k -= kUnroll) {
|
||||
consume_tile<false>(global_stream, shared_load_stream, accumulators, outer_k);
|
||||
}
|
||||
// Check if we are computing residue in prolog or not.
|
||||
if (Traits::GemmConfig::kResidueInProlog) {
|
||||
|
||||
// Residual loop.
|
||||
for (; outer_k > -kUnroll; outer_k -= kUnroll) {
|
||||
consume_tile<true>(global_stream, shared_load_stream, accumulators, outer_k);
|
||||
// Execute all mainloop iterations but the last one.
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) {
|
||||
consume_tile<false, false>(
|
||||
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
|
||||
|
||||
}
|
||||
|
||||
// Don't load data for the last "residue" portion since we've already computed the residue.
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
|
||||
consume_tile<false, true>(
|
||||
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
|
||||
|
||||
}
|
||||
} else {
|
||||
// When kResidueSeparate = true, execute all mainloop iterations but the last two without any
|
||||
// consideration for K-residue or predicate updates. This improves the steady state of some
|
||||
// kernels.
|
||||
if (Traits::GemmConfig::kResidueSeparate) {
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
|
||||
consume_tile<false, false>(
|
||||
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Execute remaining tiles with K-residue predicate updates enabled.
|
||||
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
|
||||
consume_tile<true, false>(
|
||||
global_to_shared_stream, shared_load_stream, accumulators, outer_k);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Epilogue.
|
||||
typedef typename Traits::Epilogue Epilogue;
|
||||
Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.m, params.n);
|
||||
epilogue.epilogue(cutlass::make_Coord(0, block.y, block.x), accumulators);
|
||||
Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.problem_size.knm());
|
||||
epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id());
|
||||
}
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// The params.
|
||||
Params const& params;
|
||||
/// The shared storage.
|
||||
|
||||
145
cutlass/gemm/gemm_config.h
Normal file
@ -0,0 +1,145 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines properties of GEMM computation that impose some constraints on caller.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/shape.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The scalar type for A.
|
||||
typename ScalarA_,
|
||||
/// The scalar type for B.
|
||||
typename ScalarB_,
|
||||
/// The scalar type for C.
|
||||
typename ScalarC_,
|
||||
/// The scalar type for D.
|
||||
typename ScalarD_,
|
||||
/// The threadblock tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// The functor to do the math.
|
||||
typename MultiplyAdd_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
int kScalarsPerStsA_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdsA_,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
int kScalarsPerStsB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
int kScalarsPerLdsB_,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
int kScalarsPerLdgCAndStgD_,
|
||||
/// The number of scalars per STS for D.
|
||||
int kScalarsPerStsD_,
|
||||
/// The number of scalars per LDS for D.
|
||||
int kScalarsPerLdsD_,
|
||||
/// The number of stages in shared memory to do single/double/triple-buffering.
|
||||
int kStages_,
|
||||
/// If true, residue is computed in mainloop. If false, separate loops are instantiated.
|
||||
bool kResidueSeparate_ = false,
|
||||
/// Is residue performed in prologue?
|
||||
bool kResidueInProlog_ = false,
|
||||
/// If true, kernel is launched with CUDA launch bounds specified
|
||||
bool kLaunchBounds_ = true>
|
||||
struct GemmConfig {
|
||||
//
|
||||
/// The scalar for A.
|
||||
typedef ScalarA_ ScalarA;
|
||||
/// The scalar for B.
|
||||
typedef ScalarB_ ScalarB;
|
||||
/// The scalar for C.
|
||||
typedef ScalarC_ ScalarC;
|
||||
/// The scalar for D.
|
||||
typedef ScalarD_ ScalarD;
|
||||
|
||||
/// The tile.
|
||||
typedef OutputTile_ OutputTile;
|
||||
/// The functor to do D = A*B + C.
|
||||
typedef MultiplyAdd_ MultiplyAdd;
|
||||
/// The shape of the instruction.
|
||||
typedef typename MultiplyAdd::InstructionShape InstructionShape;
|
||||
/// The shape of warp-level GEMM
|
||||
typedef typename MultiplyAdd::AccumulatorsPerWarp AccumulatorsPerWarp;
|
||||
/// The accumulators.
|
||||
typedef typename MultiplyAdd::Accumulators Accumulators;
|
||||
|
||||
/// The number of warps.
|
||||
typedef typename ShapeDiv<OutputTile, AccumulatorsPerWarp>::Shape Warps;
|
||||
/// The default warp size (32 threads per warp).
|
||||
static int const kWarpSize = cutlass::kWarpSize;
|
||||
/// The numnber of threads.
|
||||
static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;
|
||||
|
||||
/// The number of scalars per LDG/STS/LDS for A.
|
||||
static int const kScalarsPerLdgA = kScalarsPerLdgA_;
|
||||
static int const kScalarsPerStsA = kScalarsPerStsA_;
|
||||
static int const kScalarsPerLdsA = kScalarsPerLdsA_;
|
||||
|
||||
/// The number of scalars per LDG/STS/LDS for B.
|
||||
static int const kScalarsPerLdgB = kScalarsPerLdgB_;
|
||||
static int const kScalarsPerStsB = kScalarsPerStsB_;
|
||||
static int const kScalarsPerLdsB = kScalarsPerLdsB_;
|
||||
|
||||
/// The number of scalars per LDG for C.
|
||||
static int const kScalarsPerLdgC = kScalarsPerLdgCAndStgD_;
|
||||
|
||||
/// The number of scalars per STS/LDS/STG for D.
|
||||
static int const kScalarsPerStgD = kScalarsPerLdgCAndStgD_;
|
||||
static int const kScalarsPerStsD = kScalarsPerStsD_;
|
||||
static int const kScalarsPerLdsD = kScalarsPerLdsD_;
|
||||
|
||||
/// The number of accumulators that are going to be fed from one LDS A/B.
|
||||
static int const kAccumulatorsPerLdsA = kScalarsPerLdsA / InstructionShape::kD;
|
||||
static int const kAccumulatorsPerLdsB = kScalarsPerLdsB / InstructionShape::kD;
|
||||
|
||||
/// The number of stages in shared memory to implement double, triple, more-buffering.
|
||||
static int const kStages = kStages_;
|
||||
|
||||
/// If true, mainloop is instantiated twice. The first instantiation contains no predicate
|
||||
// updates and is more efficient for some kernels. If false, only a single mainloop is
|
||||
// instantaited.
|
||||
static bool const kResidueSeparate = kResidueSeparate_;
|
||||
|
||||
/// If true, residue is computed in the prologue.
|
||||
static bool const kResidueInProlog = kResidueInProlog_;
|
||||
|
||||
/// If true, kernel is launched with launch bounds specified
|
||||
static bool const kLaunchBounds = kLaunchBounds_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
203
cutlass/gemm/gemm_coord.h
Normal file
@ -0,0 +1,203 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief GemmCoord is a structure derived from Coord<4> that specifies a location within the
|
||||
coordinate system of a GEMM problem.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/util/platform.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// GemmCoord is a structure derived from Coord<4> that specifies a location within the
|
||||
/// coordinate space of a GEMM problem.
|
||||
struct GemmCoord : public Coord<4, int> {
|
||||
|
||||
/// Integer-valued index
|
||||
typedef int Index;
|
||||
|
||||
/// Base type is a Coord of rank=4
|
||||
typedef Coord<4, Index> Base;
|
||||
|
||||
/// GEMM K dimension - inner dimension of the GEMM problem
|
||||
static int const kK = 0;
|
||||
|
||||
/// GEMM N dimension - columns of the output C matrix
|
||||
static int const kN = 1;
|
||||
|
||||
/// GEMM M dimension - rows of the output C matrix
|
||||
static int const kM = 2;
|
||||
|
||||
/// Batch dimension - for generalizing to larger problems
|
||||
static int const kBatch = 3;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord() { }
|
||||
|
||||
/// Constructs from Coord<3> and a batch
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord(Coord<3, Index> const &coord, Index _batch = 0): Base(make_Coord(coord[0], coord[1], coord[2], _batch)) { }
|
||||
|
||||
/// Constructs from Coord<4>
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord(Coord<4, Index> const &coord): Base(coord) { }
|
||||
|
||||
/// Constructs from an array of coordinate elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord(Index coord[4]): Base(coord) { }
|
||||
|
||||
/// Helper to construct from a K, N, M, batch variables
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord(Index k, Index n, Index m, Index batch = 0): Base(make_Coord(k, n, m, batch)) { }
|
||||
|
||||
/// Returns the GEMM M coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index const & m() const { return this->at(kM); }
|
||||
|
||||
/// Returns reference to the GEMM M coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index & m() { return this->at(kM); }
|
||||
|
||||
/// Returns the GEMM N coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index const & n() const { return this->at(kN); }
|
||||
|
||||
/// Returns reference to the GEMM N coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index & n() { return this->at(kN); }
|
||||
|
||||
/// Returns the GEMM K coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index const & k() const { return this->at(kK); }
|
||||
|
||||
/// Returns reference to the GEMM K coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index & k() { return this->at(kK); }
|
||||
|
||||
/// Returns the GEMM batch coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index const & batch() const { return this->at(kBatch); }
|
||||
|
||||
/// Returns reference to the GEMM batch coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index & batch() { return this->at(kBatch); }
|
||||
|
||||
/// Obtains a Coord<3> from GemmCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<3> knm() const {
|
||||
return make_Coord(k(), n(), m());
|
||||
}
|
||||
|
||||
/// Obtains a Coord<2> from GemmCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> nm() const {
|
||||
return make_Coord(n(), m());
|
||||
}
|
||||
|
||||
/// Obtains a Coord<2> from GemmCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> km() const {
|
||||
return make_Coord(k(), m());
|
||||
}
|
||||
|
||||
/// Obtains a Coord<2> from GemmCoord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> kn() const {
|
||||
return make_Coord(k(), n());
|
||||
}
|
||||
|
||||
//
|
||||
// Coord operators
|
||||
//
|
||||
|
||||
/// Element-wise addition
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord operator+(Base const& b) const {
|
||||
return GemmCoord(Base::operator+(b));
|
||||
}
|
||||
|
||||
/// Element-wise subtraction
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord operator-(Base const& b) const {
|
||||
return GemmCoord(Base::operator-(b));
|
||||
}
|
||||
|
||||
/// Element-wise multiplication
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord operator*(Base const& b) const {
|
||||
return GemmCoord(Base::operator*(b));
|
||||
}
|
||||
|
||||
/// Element-wise division
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord operator/(Base const& b) const {
|
||||
return GemmCoord(Base::operator/(b));
|
||||
}
|
||||
|
||||
/// In-place addition
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord& operator+=(Base const& b) {
|
||||
Base::operator+=(b);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place subtraction
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord& operator-=(Base const& b) {
|
||||
Base::operator-=(b);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place multiplication
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord& operator*=(Base const& b) {
|
||||
Base::operator*=(b);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place division
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmCoord& operator/=(Base const& b) {
|
||||
Base::operator/=(b);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
205
cutlass/gemm/gemm_desc.h
Normal file
@ -0,0 +1,205 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements a software-pipelined efficient GEMM.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/gemm_coord.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
/// GEMM problem description
|
||||
template <
|
||||
/// Source accumulator matrix type
|
||||
typename AType_,
|
||||
/// Destination accumulator type
|
||||
typename BType_,
|
||||
/// Source accumulator matrix type
|
||||
typename CType_,
|
||||
/// Destination accumulator type
|
||||
typename DType_,
|
||||
/// Scalar type for alpha and beta
|
||||
typename SType_,
|
||||
/// Index type for dimensions and strides
|
||||
typename Index_ = int
|
||||
> struct GemmDesc {
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Index type for dimensions and strides
|
||||
typedef Index_ Index;
|
||||
|
||||
/// Source accumulator matrix type
|
||||
typedef AType_ AType;
|
||||
|
||||
/// Tensor reference to A operand
|
||||
typedef TensorRef<AType const, 2> TensorRefA;
|
||||
|
||||
/// Destination accumulator type
|
||||
typedef BType_ BType;
|
||||
|
||||
/// Tensor reference to B operand
|
||||
typedef TensorRef<BType const, 2> TensorRefB;
|
||||
|
||||
/// Source accumulator matrix type
|
||||
typedef CType_ CType;
|
||||
|
||||
/// Tensor reference to C operand
|
||||
typedef TensorRef<CType const, 2> TensorRefC;
|
||||
|
||||
/// Destination accumulator type
|
||||
typedef DType_ DType;
|
||||
|
||||
/// Tensor reference to D operand
|
||||
typedef TensorRef<DType, 2> TensorRefD;
|
||||
|
||||
/// Scalar type for alpha and beta
|
||||
typedef SType_ SType;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// The dimensions of the GEMM.
|
||||
GemmCoord problem_size;
|
||||
|
||||
/// The alpha scaling values.
|
||||
SType alpha;
|
||||
|
||||
/// The source matrix A.
|
||||
TensorRefA A;
|
||||
|
||||
/// batch stride for A operand
|
||||
long long batch_stride_A;
|
||||
|
||||
/// The source matrix B.
|
||||
TensorRefB B;
|
||||
|
||||
/// batch stride for B operand
|
||||
long long batch_stride_B;
|
||||
|
||||
/// The beta scaling values.
|
||||
SType beta;
|
||||
|
||||
/// The source matrix C.
|
||||
TensorRefC C;
|
||||
|
||||
/// batch stride for C operand
|
||||
long long batch_stride_C;
|
||||
|
||||
/// The destination matrix D.
|
||||
TensorRefD D;
|
||||
|
||||
/// batch stride for D operand
|
||||
long long batch_stride_D;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmDesc(): problem_size(0, 0, 0, 1), alpha(1), beta(0) {}
|
||||
|
||||
/// Constructor for basic GEMM with batch count = 1
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmDesc(Coord<3> _problem_size,
|
||||
SType _alpha,
|
||||
TensorRefA const &_A,
|
||||
TensorRefB const &_B,
|
||||
SType _beta,
|
||||
TensorRefC const &_C,
|
||||
TensorRefD const &_D
|
||||
):
|
||||
problem_size(_problem_size[0], _problem_size[1], _problem_size[2], 1),
|
||||
alpha(_alpha),
|
||||
A(_A),
|
||||
batch_stride_A(0),
|
||||
B(_B),
|
||||
batch_stride_B(0),
|
||||
beta(_beta),
|
||||
C(_C),
|
||||
batch_stride_C(0),
|
||||
D(_D),
|
||||
batch_stride_D(0) {}
|
||||
|
||||
/// Constructor for basic GEMM with batch count = 1
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmDesc(GemmCoord _problem_size,
|
||||
SType _alpha,
|
||||
TensorRefA const &_A,
|
||||
TensorRefB const &_B,
|
||||
SType _beta,
|
||||
TensorRefC const &_C,
|
||||
TensorRefD const &_D
|
||||
):
|
||||
problem_size(_problem_size.k(), _problem_size.n(), _problem_size.m(), 1),
|
||||
alpha(_alpha),
|
||||
A(_A),
|
||||
batch_stride_A(0),
|
||||
B(_B),
|
||||
batch_stride_B(0),
|
||||
beta(_beta),
|
||||
C(_C),
|
||||
batch_stride_C(0),
|
||||
D(_D),
|
||||
batch_stride_D(0) {
|
||||
|
||||
assert(_problem_size.batch() == 1);
|
||||
}
|
||||
|
||||
/// Constructor for strided batch GEMM GEMM
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmDesc(GemmCoord _problem_size,
|
||||
SType _alpha,
|
||||
TensorRefA const &_A,
|
||||
long long _batch_stride_A,
|
||||
TensorRefB const &_B,
|
||||
long long _batch_stride_B,
|
||||
SType _beta,
|
||||
TensorRefC const &_C,
|
||||
long long _batch_stride_C,
|
||||
TensorRefD const &_D,
|
||||
long long _batch_stride_D
|
||||
):
|
||||
problem_size(_problem_size),
|
||||
alpha(_alpha),
|
||||
A(_A),
|
||||
batch_stride_A(_batch_stride_A),
|
||||
B(_B),
|
||||
batch_stride_B(_batch_stride_B),
|
||||
beta(_beta),
|
||||
C(_C),
|
||||
batch_stride_C(_batch_stride_C),
|
||||
D(_D),
|
||||
batch_stride_D(_batch_stride_D) {}
|
||||
};
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -29,26 +29,15 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/fragment.h>
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/fragment.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
CUTLASS_DEVICE bool is_zero(T x) {
|
||||
return x == T(0);
|
||||
}
|
||||
|
||||
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
|
||||
CUTLASS_DEVICE bool is_zero(half x) { return reinterpret_cast<int16_t&>(x) == int16_t(0); }
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmEpilogueTraits_>
|
||||
struct GemmEpilogue {
|
||||
/// The traits class.
|
||||
@ -85,9 +74,7 @@ struct GemmEpilogue {
|
||||
/// The shared store transformer for D.
|
||||
typedef typename Traits::SharedStoreTransformerD SharedStoreTransformerD;
|
||||
/// The iterator to load D in shared memory.
|
||||
typedef typename Traits::SharedLoadIteratorD SharedLoadIteratorD;
|
||||
/// The shared load transformer for D.
|
||||
typedef Copy<typename SharedLoadIteratorD::Fragment> SharedLoadTransformerD;
|
||||
typedef typename Traits::SharedLoadStreamD SharedLoadStreamD;
|
||||
|
||||
/// The index.
|
||||
typedef typename Traits::Index Index;
|
||||
@ -100,33 +87,28 @@ struct GemmEpilogue {
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GemmEpilogue(Params const& params_,
|
||||
SharedStorage& shared_storage_,
|
||||
Index m_,
|
||||
Index n_)
|
||||
: params(params_), shared_storage(shared_storage_), m(m_), n(n_) {}
|
||||
Coord<3> const& _problem_size)
|
||||
: params(params_), shared_storage(shared_storage_), problem_size(_problem_size), functor(params_.functor) {}
|
||||
|
||||
/// Execute the epilogue.
|
||||
CUTLASS_DEVICE void epilogue(Coord<3> const& block, Accumulators& accumulators) {
|
||||
if (is_zero(params.functor.beta)) {
|
||||
epilogue_with_or_without_beta<true>(block, accumulators);
|
||||
CUTLASS_DEVICE void epilogue(Accumulators& accumulators,
|
||||
Coord<3> const& block = make_Coord(0, 0, 0),
|
||||
int batch_id = 0) {
|
||||
if (functor.source_required()) {
|
||||
epilogue_with_or_without_beta<true>(accumulators, block, batch_id);
|
||||
} else {
|
||||
epilogue_with_or_without_beta<false>(block, accumulators);
|
||||
epilogue_with_or_without_beta<false>(accumulators, block, batch_id);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool kBetaIsZero_>
|
||||
CUTLASS_DEVICE void epilogue_with_or_without_beta(Coord<3> const& block,
|
||||
Accumulators& accumulators) {
|
||||
|
||||
// The problem size.
|
||||
Coord<3> const bounds = cutlass::make_Coord(0, n, m);
|
||||
|
||||
// The functor.
|
||||
Functor functor(params.functor);
|
||||
template <bool kSourceRequired>
|
||||
CUTLASS_DEVICE void epilogue_with_or_without_beta(Accumulators& accumulators,
|
||||
Coord<3> const& block,
|
||||
int batch_id) {
|
||||
// The C fragment.
|
||||
typename GlobalLoadIteratorC::Fragment fragment_c;
|
||||
// The transformed C fragment.
|
||||
typename GlobalTransformerC::OutputFragment transformed_c;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
// Compute pointer and predicate offsets for C and D global iterators.
|
||||
@ -136,6 +118,7 @@ struct GemmEpilogue {
|
||||
Iterations::kW +
|
||||
params.stride_h) *
|
||||
h;
|
||||
|
||||
int const predicate_offset =
|
||||
((params.iterator_d.predicate_inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
|
||||
params.iterator_d.predicate_inc_advance) *
|
||||
@ -145,32 +128,40 @@ struct GemmEpilogue {
|
||||
|
||||
// The iterator to load the elements of the C matrix.
|
||||
GlobalLoadIteratorC global_load_iterator(
|
||||
params.iterator_c, bounds, block, pointer_offset, predicate_offset);
|
||||
params.iterator_c, problem_size, block, pointer_offset, predicate_offset);
|
||||
|
||||
// update C pointer offset based on batch_id and batch_stride_offset
|
||||
//global_load_iterator.add_pointer_offset(batch_id * params.batch_stride_offset_c);
|
||||
global_load_iterator += make_Coord(batch_id, 0, 0);
|
||||
|
||||
// The transformer for C.
|
||||
GlobalTransformerC transformer_c;
|
||||
// The transformer for D.
|
||||
GlobalTransformerD transformer_d;
|
||||
// The iterator to store into the D matrix.
|
||||
GlobalStoreIteratorD global_store_iterator(
|
||||
params.iterator_d, bounds, block, pointer_offset, predicate_offset);
|
||||
params.iterator_d, problem_size, block, pointer_offset, predicate_offset);
|
||||
|
||||
// update D pointer offset based on batch_id and batch_stride_offset
|
||||
//global_store_iterator.add_pointer_offset(batch_id * params.batch_stride_offset_d);
|
||||
global_store_iterator += make_Coord(batch_id, 0, 0);
|
||||
|
||||
// The transformer to transform before storing to shared memory.
|
||||
SharedStoreTransformerD shared_store_transformer;
|
||||
typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d;
|
||||
|
||||
// The iterator to store to shared memory.
|
||||
SharedStoreIteratorD shared_store_iterator(params.shared_store_iterator_d,
|
||||
shared_storage.shared_stream.store);
|
||||
SharedStoreIteratorD shared_store_iterator(
|
||||
params.shared_store_iterator_d,
|
||||
reinterpret_cast<typename SharedStoreIteratorD::Scalar*>(shared_storage.data()));
|
||||
|
||||
// The iterator to load from shared memory. TODO: Use a stream.
|
||||
SharedLoadIteratorD shared_load_iterator(params.shared_load_iterator_d,
|
||||
shared_storage.shared_stream.load);
|
||||
SharedLoadStreamD shared_load_stream(
|
||||
params.shared_load_stream_d,
|
||||
reinterpret_cast<typename SharedLoadStreamD::Scalar*>(shared_storage.data()));
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
// Load the C matrix into fragment.
|
||||
if (!kBetaIsZero_) {
|
||||
iterator_load(global_load_iterator, fragment_c);
|
||||
if (kSourceRequired) {
|
||||
global_load_iterator.load_post_increment(fragment_c);
|
||||
}
|
||||
|
||||
// Make sure we can write to shared memory.
|
||||
@ -180,33 +171,33 @@ struct GemmEpilogue {
|
||||
int const offset = (h * Iterations::kW + w) * SharedStoreIteratorD::Fragment::kElements;
|
||||
|
||||
shared_store_transformer.transform(accumulators, offset, shared_store_transformed_d);
|
||||
shared_iterator_store(shared_store_iterator, shared_store_transformed_d);
|
||||
shared_store_iterator.store_post_increment(shared_store_transformed_d);
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
shared_store_fence();
|
||||
|
||||
// Copy the accumulators back to registers from shared memory.
|
||||
typename SharedLoadIteratorD::Fragment fetched_d;
|
||||
shared_iterator_load(shared_load_iterator, fetched_d);
|
||||
shared_load_stream.copy();
|
||||
shared_load_stream.commit();
|
||||
|
||||
// Do the math.
|
||||
typename GlobalTransformerD::InputFragment fragment_d;
|
||||
|
||||
if (kBetaIsZero_) {
|
||||
functor.evaluate(fetched_d, fragment_d);
|
||||
} else {
|
||||
if (kSourceRequired) {
|
||||
// Transform C fragment.
|
||||
transformer_c.transform(fragment_c, transformed_c);
|
||||
// Do the math.
|
||||
functor.evaluate(fetched_d, transformed_c, fragment_d);
|
||||
functor.evaluate(shared_load_stream.fragment(), transformed_c, fragment_d);
|
||||
} else {
|
||||
functor.evaluate(shared_load_stream.fragment(), fragment_d);
|
||||
}
|
||||
|
||||
// Transform D fragment.
|
||||
typename GlobalTransformerD::OutputFragment transformed_d;
|
||||
transformer_d.transform(fragment_d, transformed_d);
|
||||
typename GlobalTransformerD::OutputFragment global_transformed_d;
|
||||
transformer_d.transform(fragment_d, global_transformed_d);
|
||||
|
||||
// Copy the results to global memory.
|
||||
iterator_store(global_store_iterator, transformed_d);
|
||||
global_store_iterator.store_post_increment(global_transformed_d);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -222,7 +213,9 @@ struct GemmEpilogue {
|
||||
/// The shared storage.
|
||||
SharedStorage& shared_storage;
|
||||
/// The dimensions of the GEMM.
|
||||
Index m, n;
|
||||
Coord<3> problem_size;
|
||||
// The functor.
|
||||
Functor functor;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -27,13 +27,13 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/gemm/gemm_global_stream.h>
|
||||
#include <cutlass/gemm/gemm_shared_stream.h>
|
||||
#include <cutlass/gemm/linear_scaling.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include <cutlass/tile_iterator.h>
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/gemm/gemm_global_stream.h"
|
||||
#include "cutlass/gemm/gemm_shared_stream.h"
|
||||
#include "cutlass/gemm/linear_scaling.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
#include "cutlass/tile_iterator.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -57,8 +57,8 @@ template <
|
||||
typename SharedStoreIteratorD_,
|
||||
/// The shared store transformer for D.
|
||||
typename SharedStoreTransformerD_,
|
||||
/// The iterator to load D from shared memory.
|
||||
typename SharedLoadIteratorD_,
|
||||
/// The stream to load D from shared memory.
|
||||
typename SharedLoadStreamD_,
|
||||
/// The number of iterations in the epilogue.
|
||||
typename Iterations_,
|
||||
/// The iterations strides.
|
||||
@ -86,8 +86,8 @@ struct GemmEpilogueTraits {
|
||||
typedef SharedStoreIteratorD_ SharedStoreIteratorD;
|
||||
/// The shared store transformer for D.
|
||||
typedef SharedStoreTransformerD_ SharedStoreTransformerD;
|
||||
/// The iterator to store D in shared memory.
|
||||
typedef SharedLoadIteratorD_ SharedLoadIteratorD;
|
||||
/// The stream to store D in shared memory.
|
||||
typedef SharedLoadStreamD_ SharedLoadStreamD;
|
||||
/// typedef typename GemmConfig::EpilogueIterations Iterations;
|
||||
typedef Iterations_ Iterations;
|
||||
/// The iterations strides.
|
||||
@ -118,14 +118,15 @@ struct GemmEpilogueTraits {
|
||||
typename GlobalStoreIteratorD::Params iterator_d;
|
||||
/// The params for the D shared store iterator.
|
||||
typename SharedStoreIteratorD::Params shared_store_iterator_d;
|
||||
/// The params for the D shared load iterator.
|
||||
typename SharedLoadIteratorD::Params shared_load_iterator_d;
|
||||
/// The params for the D shared load stream.
|
||||
typename SharedLoadStreamD::Params shared_load_stream_d;
|
||||
/// The functor params.
|
||||
typename Functor::Params functor;
|
||||
|
||||
/// Setup the params.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
|
||||
|
||||
// The parameters for the functor.
|
||||
int error_code = functor.initialize(desc);
|
||||
if (error_code) {
|
||||
@ -133,20 +134,27 @@ struct GemmEpilogueTraits {
|
||||
}
|
||||
|
||||
// At the end of the H iteration, we jump over a number of columns.
|
||||
this->stride_h = desc.ldd * Delta::kH;
|
||||
this->stride_h = desc.D.leading_dim() * Delta::kH;
|
||||
// Nothing to do here.
|
||||
this->stride_w = 0;
|
||||
|
||||
// Setup the params for the global memory iterator for C.
|
||||
error_code = iterator_c.initialize(
|
||||
reinterpret_cast<ScalarC const*>(desc.d_c), desc.ldc, desc.n, stride_w, Delta::kW);
|
||||
error_code = iterator_c.initialize(desc.C.data(),
|
||||
desc.batch_stride_C,
|
||||
desc.C.leading_dim(),
|
||||
desc.problem_size[1],
|
||||
stride_w,
|
||||
Delta::kW);
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// Setup the params for the global memory iterator for D.
|
||||
return iterator_d.initialize(
|
||||
reinterpret_cast<ScalarD*>(desc.d_d), desc.ldd, desc.n, stride_w, Delta::kW);
|
||||
return iterator_d.initialize(desc.D.data(),
|
||||
desc.batch_stride_D,
|
||||
desc.D.leading_dim(),
|
||||
desc.problem_size[1],
|
||||
stride_w,
|
||||
Delta::kW);
|
||||
}
|
||||
};
|
||||
|
||||
@ -155,13 +163,20 @@ struct GemmEpilogueTraits {
|
||||
// The storage for the store iterator.
|
||||
typename SharedStoreIteratorD::SharedStorage store;
|
||||
// The storage for the store iterator.
|
||||
typename SharedLoadIteratorD::SharedStorage load;
|
||||
typename SharedLoadStreamD::SharedStorage load;
|
||||
};
|
||||
|
||||
/// The shared memory to swizzle the data in the epilogue.
|
||||
struct SharedStorage {
|
||||
// The storage for the shared stream D.
|
||||
StreamSharedStorage shared_stream;
|
||||
|
||||
//
|
||||
//
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ScalarD* data() { return reinterpret_cast<ScalarD*>(&shared_stream.load); }
|
||||
};
|
||||
};
|
||||
|
||||
@ -192,7 +207,10 @@ struct GemmEpilogueTraitsHelper {
|
||||
/// The traits class to build the iterator to store to shared memory for D.
|
||||
typedef GemmSharedStoreTileDTraits<
|
||||
// The pointer is float.
|
||||
typename Functor::Scalar,
|
||||
// typename Functor::Scalar,
|
||||
// Functor::Scalar is alpha, beta type, in mixed precision, alpha and beta may not be the same with accumulation.
|
||||
// In this case Functor::ScalarAccum is needed
|
||||
typename Functor::ScalarAccum,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
@ -221,7 +239,10 @@ struct GemmEpilogueTraitsHelper {
|
||||
/// The traits class to build the iterator to load from shared memory for D.
|
||||
typedef GemmSharedLoadTileDTraits<
|
||||
// The pointer is float.
|
||||
typename Functor::Scalar,
|
||||
// typename Functor::Scalar,
|
||||
// Functor::Scalar is alpha, beta type, in mixed precision, alpha and beta may not be the same with accumulation.
|
||||
// In this case Functor::ScalarAccum is needed
|
||||
typename Functor::ScalarAccum,
|
||||
// The output tile size.
|
||||
typename GemmConfig_::OutputTile,
|
||||
// The number of warps.
|
||||
@ -242,6 +263,8 @@ struct GemmEpilogueTraitsHelper {
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorD;
|
||||
/// The stream to load D.
|
||||
typedef SharedLoadStream<SharedLoadIteratorD> SharedLoadStreamD;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for C^N.
|
||||
typedef GemmGlobalTileCdTraits<
|
||||
@ -314,8 +337,8 @@ struct SimplifiedGemmEpilogueTraits : public GemmEpilogueTraits<
|
||||
typename Helper_::SharedStoreIteratorD,
|
||||
// The shared store transformer for D.
|
||||
typename Helper_::SharedStoreTransformerD,
|
||||
// The iterator to load D from shared memory.
|
||||
typename Helper_::SharedLoadIteratorD,
|
||||
// The stream to load D from shared memory.
|
||||
typename Helper_::SharedLoadStreamD,
|
||||
// The number of iterations.
|
||||
typename Helper_::Iterations,
|
||||
// The strides between iterations.
|
||||
|
||||
@ -29,9 +29,10 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/iterator_access.h>
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/gemm/gemm_global_tile.h"
|
||||
#include "cutlass/tile_allocation.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -39,6 +40,8 @@ namespace gemm {
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Identifies multiplicand
|
||||
GemmOperand::Kind Operand,
|
||||
/// The load iterator.
|
||||
typename LoadIterator_,
|
||||
/// The store iterator to copy to shared memory.
|
||||
@ -46,7 +49,9 @@ template <
|
||||
/// The transformer to be applied after the data has been copied from global memory.
|
||||
typename Transformer_>
|
||||
|
||||
struct GlobalLoadStreamBase {
|
||||
struct GlobalLoadStream {
|
||||
/// Indicates the type of GEMM operand
|
||||
static GemmOperand::Kind const kOperand = Operand;
|
||||
/// The load iterator.
|
||||
typedef LoadIterator_ LoadIterator;
|
||||
/// The transformer.
|
||||
@ -75,6 +80,15 @@ struct GlobalLoadStreamBase {
|
||||
typedef typename LoadIterator::Pointer Pointer;
|
||||
/// The index.
|
||||
typedef typename LoadIterator::Index Index;
|
||||
/// The tile
|
||||
typedef typename LoadIterator::Tile Tile;
|
||||
|
||||
/// Shared memory allocation for the tile
|
||||
typedef TileAllocation<typename StoreIterator::Scalar, typename StoreIterator::Tile>
|
||||
ThreadblockTileStorage;
|
||||
|
||||
/// Tensor reference to threadblock tile
|
||||
typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
@ -82,56 +96,73 @@ struct GlobalLoadStreamBase {
|
||||
typename LoadIterator::Params load_iterator;
|
||||
// The store iterator.
|
||||
typename StoreIterator::Params store_iterator;
|
||||
// Offset to residue.
|
||||
Index offset_to_residue;
|
||||
|
||||
/// Setup the params.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc, Pointer pointer, Index ld) {
|
||||
int error_code = load_iterator.initialize(desc, pointer, ld);
|
||||
CUTLASS_HOST_DEVICE int initialize(Pointer pointer,
|
||||
long long batch_stride,
|
||||
Index ldm,
|
||||
Index _offset_to_residue) {
|
||||
|
||||
offset_to_residue = _offset_to_residue;
|
||||
int error_code = load_iterator.initialize(pointer, batch_stride, ldm);
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
return store_iterator.initialize();
|
||||
}
|
||||
};
|
||||
|
||||
/// The amount of storage in shared memory needed to store the tile.
|
||||
typedef typename StoreIterator::SharedStorage SharedStoreStorage;
|
||||
/// Contains private storage in shared memory needed by the objects within this class. Note,
|
||||
/// this is *NOT* the shared memory allocation for the GEMM threadblock tile. That necessarily
|
||||
/// exists outside this class, as it is also needed by the warp-level shared=>RF stream.
|
||||
struct SharedStorage {};
|
||||
|
||||
/// The storage in shared memory needed by that stream.
|
||||
union SharedStorage {
|
||||
// The load iterator.
|
||||
typename LoadIterator::SharedStorage load_iterator;
|
||||
// The store iterator.
|
||||
SharedStoreStorage store_iterator;
|
||||
};
|
||||
//
|
||||
// Static member functions
|
||||
//
|
||||
|
||||
/// Maps a coordinate in the GEMM's (K, N, M) coordinate system to global memory
|
||||
CUTLASS_DEVICE static Coord<3> project_coordinate(Coord<3> const& coord, Index d_offset = 0) {
|
||||
bool const kKstrided =
|
||||
GemmMultiplicandTraits<typename LoadIterator::Tile, kOperand, kLayout>::kKstrided;
|
||||
Coord<3> tile_coord = ProjectOperand<kOperand, kKstrided>::project(coord);
|
||||
return make_Coord(
|
||||
tile_coord[0] + d_offset, tile_coord[1], tile_coord[2] / LoadIterator::Tile::kC);
|
||||
}
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GlobalLoadStreamBase(Params const& params,
|
||||
SharedStorage& shared_storage,
|
||||
Coord<3> const bounds,
|
||||
Coord<3> const& block)
|
||||
: load_iterator(params.load_iterator, bounds, block),
|
||||
CUTLASS_DEVICE GlobalLoadStream(
|
||||
Params const& _params,
|
||||
SharedStorage& shared_storage,
|
||||
ThreadblockTileRef const& threadblock_tile_ref,
|
||||
Coord<3> const bounds,
|
||||
Coord<3> const& _threadblock_offset)
|
||||
: params(_params),
|
||||
multiplicand_bounds(project_coordinate(bounds, 1)),
|
||||
threadblock_offset(project_coordinate(_threadblock_offset)),
|
||||
load_iterator(params.load_iterator,
|
||||
project_coordinate(bounds, 1), /*multiplicant_bounds*/
|
||||
project_coordinate(_threadblock_offset) /*threablock_offset*/),
|
||||
transformer(),
|
||||
store_iterator(params.store_iterator, shared_storage.store_iterator)
|
||||
|
||||
store_iterator(params.store_iterator, threadblock_tile_ref.data())
|
||||
{
|
||||
load_iterator.initialize_predicates(multiplicand_bounds, threadblock_offset);
|
||||
fetched_fragment.clear();
|
||||
}
|
||||
|
||||
|
||||
/// Load the data from shared memory to the fetch fragment.
|
||||
CUTLASS_DEVICE void copy() { iterator_load(load_iterator, fetched_fragment); }
|
||||
CUTLASS_DEVICE void copy() { load_iterator.load_post_increment(fetched_fragment); }
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit() {
|
||||
transformer.transform(fetched_fragment, transformed_fragment);
|
||||
iterator_store(store_iterator, transformed_fragment);
|
||||
store_iterator.store_post_increment(transformed_fragment);
|
||||
store_iterator.inc_stage();
|
||||
}
|
||||
|
||||
/// Move to the beginning of the residue code. That's a new code path in CUTLASS 1.0.1.
|
||||
CUTLASS_DEVICE void move_to_residue(Index k) { load_iterator.move_to_residue(k); }
|
||||
|
||||
/// Execute the residue code.
|
||||
CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
|
||||
load_iterator.residue(k);
|
||||
@ -140,9 +171,43 @@ struct GlobalLoadStreamBase {
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to the beginning of the GEMM-k dimension.
|
||||
CUTLASS_DEVICE void rollback() { load_iterator.rollback(); }
|
||||
/// Move to the residue portion.
|
||||
CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {
|
||||
Index kResidue = k % kTileK;
|
||||
if (kResidue) {
|
||||
residue(kResidue);
|
||||
}
|
||||
load_iterator.add_pointer_offset(params.offset_to_residue * load_iterator.stride_advance());
|
||||
}
|
||||
|
||||
/// Rollback to the beginning of the first tile
|
||||
CUTLASS_DEVICE void rollback(void) {
|
||||
load_iterator.initialize_predicates(multiplicand_bounds, threadblock_offset);
|
||||
|
||||
int const kBlock = kOperand == GemmOperand::kA
|
||||
? (kLayout == MatrixLayout::kColumnMajor ? Tile::kH : Tile::kW)
|
||||
: (kLayout == MatrixLayout::kRowMajor ? Tile::kH : Tile::kW);
|
||||
|
||||
load_iterator.add_pointer_offset(-(params.offset_to_residue + kBlock) *
|
||||
load_iterator.stride_advance());
|
||||
}
|
||||
|
||||
/// Adds a Coord<3> to the underlying global load iterator
|
||||
CUTLASS_DEVICE GlobalLoadStream &operator+=(Coord<3> const &offset) {
|
||||
load_iterator += offset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters
|
||||
Params params;
|
||||
/// Multiplicand bounds
|
||||
Coord<3> multiplicand_bounds;
|
||||
/// Threadblock offset
|
||||
Coord<3> threadblock_offset;
|
||||
/// The iterator.
|
||||
LoadIterator load_iterator;
|
||||
/// The fragment to fetch from shared memory.
|
||||
@ -155,28 +220,6 @@ struct GlobalLoadStreamBase {
|
||||
StoreIterator store_iterator;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The load iterator.
|
||||
typename LoadIterator_,
|
||||
/// The store iterator to copy to shared memory.
|
||||
typename StoreIterator_,
|
||||
/// The transformer to be applied after the data has been copied from global memory.
|
||||
typename Transformer_ = Copy<typename LoadIterator_::Fragment> >
|
||||
|
||||
struct GlobalLoadStream : public GlobalLoadStreamBase<LoadIterator_, StoreIterator_, Transformer_> {
|
||||
/// The base class.
|
||||
typedef GlobalLoadStreamBase<LoadIterator_, StoreIterator_, Transformer_> Base;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GlobalLoadStream(typename Base::Params const& params,
|
||||
typename Base::SharedStorage& shared_storage,
|
||||
Coord<3> const& bounds,
|
||||
Coord<3> const& block)
|
||||
: Base(params, shared_storage, bounds, block) {}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
@ -27,14 +27,14 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/util/platform.h>
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/util/platform.h"
|
||||
|
||||
#include <cutlass/gemm/gemm_operand.h>
|
||||
#include <cutlass/matrix_traits.h>
|
||||
#include <cutlass/predicate_vector.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include <cutlass/tile_iterator.h>
|
||||
#include "cutlass/gemm/gemm_operand.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
#include "cutlass/tile_iterator.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -80,20 +80,24 @@ struct GemmGlobalTileTraits {
|
||||
static int const kAccessSize = kAccessSize_;
|
||||
/// The memory space.
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGlobal;
|
||||
|
||||
/// The tile shape
|
||||
typedef typename ReshapeTile<Tile_, kAccessSize_>::Tile Tile;
|
||||
typedef Tile_ Tile;
|
||||
/// The vectorized tile shape
|
||||
typedef typename ReshapeTile<Tile_, kAccessSize_>::Tile VectorizedTile;
|
||||
/// The threads shape
|
||||
typedef typename ReshapeThreads<Tile, Threads_>::Threads Threads;
|
||||
typedef typename ReshapeThreads<VectorizedTile, Threads_>::Threads Threads;
|
||||
/// The relative offset between two elements in the H/W dimension in adjacent threads.
|
||||
typedef Shape<1, 1, Tile::kC> ThreadsDelta;
|
||||
|
||||
typedef Shape<1, 1, VectorizedTile::kC> ThreadsDelta;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, Threads::kH, Threads::kW * kAccessSize> Delta;
|
||||
|
||||
/// Strides for immediate offset computation
|
||||
typedef Shape<0, 0, Threads::kW * ThreadsDelta::kW, kAccessSize> ImmediateOffsetStrides;
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<1, Tile::kH / Threads::kH, Tile::kW / Threads::kW, Tile::kC / kAccessSize>
|
||||
typedef Shape<1,
|
||||
VectorizedTile::kH / Threads::kH,
|
||||
VectorizedTile::kW / Threads::kW,
|
||||
VectorizedTile::kC / kAccessSize>
|
||||
Iterations;
|
||||
|
||||
typedef GemmMultiplicandTraits<Tile, kOperand, kLayout> MultiplicandTraits;
|
||||
@ -165,7 +169,6 @@ struct GemmGlobalIteratorAb
|
||||
Index_> {
|
||||
/// This class.
|
||||
typedef GemmGlobalIteratorAb<TileTraits_, Index_> This_; /// The base class.
|
||||
|
||||
typedef TileLoadIterator<TileTraits_,
|
||||
typename TileTraits_::Scalar,
|
||||
TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
|
||||
@ -175,6 +178,8 @@ struct GemmGlobalIteratorAb
|
||||
Base;
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
|
||||
/// The tile
|
||||
typedef typename TileTraits_::Tile Tile;
|
||||
/// Fragment type loaded by the iterator
|
||||
typedef typename Base::Fragment Fragment;
|
||||
/// The scalar.
|
||||
@ -195,8 +200,9 @@ struct GemmGlobalIteratorAb
|
||||
|
||||
struct Params : public BaseParams {
|
||||
/// Initializes params to load a strip-mined tile, given pointer and stride_h.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc, Scalar const* ptr, Index stride_h) {
|
||||
CUTLASS_HOST_DEVICE int initialize(Scalar const* ptr,
|
||||
long long stride_d,
|
||||
Index stride_h) {
|
||||
Index inc_d = 0;
|
||||
Index inc_advance = 0;
|
||||
// Move by some columns for each iteration in the H dimension.
|
||||
@ -221,99 +227,36 @@ struct GemmGlobalIteratorAb
|
||||
(Base::Iterations::kH - 1) * inc_h;
|
||||
}
|
||||
|
||||
// The dimensions of the tile.
|
||||
int const kH = TileTraits_::Tile::kH;
|
||||
int const kW = TileTraits_::Tile::kW * TileTraits_::kAccessSize;
|
||||
|
||||
// Move to the residue.
|
||||
Index const kBlock = kAdvance == IteratorAdvance::kH ? kH : kW;
|
||||
// The jump in the gemm-k dimension.
|
||||
Index const stride = kAdvance == IteratorAdvance::kH ? stride_h : 1;
|
||||
|
||||
// Compute the offset to the residue and how to "come" back.
|
||||
Index const kResidue = desc.k % kBlock;
|
||||
if (kResidue > 0) {
|
||||
move_to_residue_offset = (desc.k - kResidue) * stride;
|
||||
} else {
|
||||
move_to_residue_offset = (desc.k - kBlock) * stride;
|
||||
}
|
||||
|
||||
Base::Params::initialize(ptr, 0, stride_h, 1, inc_d, inc_h, 0, inc_advance);
|
||||
Base::Params::initialize(
|
||||
ptr, stride_d, stride_h, 1, inc_d, inc_h, 0, inc_advance);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// The extra offset to control moving to the residue.
|
||||
Index move_to_residue_offset;
|
||||
};
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GemmGlobalIteratorAb(Params const& _params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(_params) {
|
||||
thread_offset = thread_offset_func();
|
||||
// The column.
|
||||
Index block_h = thread_offset[1];
|
||||
// The contiguous dimension.
|
||||
Index block_w = thread_offset[2];
|
||||
/// Offset of an individual lane from the start of the tile
|
||||
Coord<4> thread_offset;
|
||||
/// The parameters
|
||||
Params params;
|
||||
/// The predicates.
|
||||
PredicateVector predicates;
|
||||
|
||||
// Add the blocks indices.
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
block_h += block[1];
|
||||
block_w += block[2];
|
||||
|
||||
} else {
|
||||
block_h += block[2];
|
||||
block_w += block[1];
|
||||
}
|
||||
|
||||
// Setup the pointer.
|
||||
params.pointer += (block_h * params.stride_h + block_w);
|
||||
|
||||
// Initialize predicates
|
||||
initialize_predicates(bounds, make_Coord(0, block_h, block_w));
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Load<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::load(value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Increment the pointer in the H dimension.
|
||||
CUTLASS_DEVICE void inc_h() { params.pointer += params.inc_h; }
|
||||
/// Increment the pointer in the D dimension.
|
||||
CUTLASS_DEVICE void inc_d() { params.pointer += params.inc_d; }
|
||||
/// Increment the pointer to move to the next iteration.
|
||||
CUTLASS_DEVICE void inc_advance() { params.pointer += params.inc_advance; }
|
||||
|
||||
/// Initialize the predicates.
|
||||
CUTLASS_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& block) {
|
||||
CUTLASS_HOST_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& block_offset) {
|
||||
// Setup the masks to control loads.
|
||||
predicates.fill(0);
|
||||
|
||||
int bounds_h, bounds_w;
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
bounds_w = bounds[2] - block[2];
|
||||
bounds_h = bounds[1];
|
||||
|
||||
} else {
|
||||
bounds_w = bounds[1];
|
||||
bounds_h = bounds[2] - block[1];
|
||||
}
|
||||
|
||||
// Fill in the bits of the predicate vector.
|
||||
for (int d = 0; d < Base::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Base::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Base::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < Base::Iterations::kC; ++c) {
|
||||
bool flag = w * Base::Delta::kW < bounds_w;
|
||||
bool flag = w * Base::Delta::kW + thread_offset[2] + block_offset[2] < bounds[2];
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
flag = flag && (h * Base::Delta::kH + d * Base::Delta::kD) < bounds_h;
|
||||
flag =
|
||||
flag &&
|
||||
(h * Base::Delta::kH + d * Base::Delta::kD) + thread_offset[1] + block_offset[1] <
|
||||
bounds[1];
|
||||
} else {
|
||||
flag = flag && (h * Base::Delta::kH) < bounds_h;
|
||||
flag = flag && (h * Base::Delta::kH) + thread_offset[1] + block_offset[1] < bounds[1];
|
||||
}
|
||||
int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
|
||||
predicates.set(bit, flag);
|
||||
@ -323,31 +266,44 @@ struct GemmGlobalIteratorAb
|
||||
}
|
||||
}
|
||||
|
||||
/// Move to residue portion.
|
||||
CUTLASS_DEVICE void move_to_residue(Index k) {
|
||||
// Store the pointer and the predicates.
|
||||
stored_pointer = params.pointer;
|
||||
stored_predicates = predicates;
|
||||
/// Ctor.
|
||||
CUTLASS_HOST_DEVICE GemmGlobalIteratorAb(Params const& _params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& threadblock_offset,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(_params) {
|
||||
thread_offset = thread_offset_func();
|
||||
// Setup the pointer.
|
||||
params.pointer += ((threadblock_offset[1] + thread_offset[1]) * params.stride_h +
|
||||
(threadblock_offset[2] + thread_offset[2]));
|
||||
|
||||
// Move the pointer to the residue.
|
||||
params.pointer += params.move_to_residue_offset;
|
||||
}
|
||||
|
||||
// The dimensions of the tile.
|
||||
int const kH = TileTraits_::Tile::kH;
|
||||
int const kW = TileTraits_::Tile::kW * TileTraits_::kAccessSize;
|
||||
/// Increment the pointer in the W dimension.
|
||||
CUTLASS_HOST_DEVICE void inc_w() { Base::inc_w(); }
|
||||
/// Increment the pointer in the H dimension.
|
||||
CUTLASS_HOST_DEVICE void inc_h() { params.pointer += params.inc_h; }
|
||||
/// Increment the pointer in the D dimension.
|
||||
CUTLASS_HOST_DEVICE void inc_d() { params.pointer += params.inc_d; }
|
||||
/// Increment the pointer to move to the next iteration.
|
||||
CUTLASS_HOST_DEVICE void inc_advance() { params.pointer += params.inc_advance; }
|
||||
|
||||
// The unrolling factor.
|
||||
int const kUnroll = kAdvance == IteratorAdvance::kH ? kH : kW;
|
||||
|
||||
// Clear the predicates for the residue. TODO: We can do something smarter.
|
||||
int const kResidue = (int)(k % (Index)kUnroll);
|
||||
if (kResidue > 0) {
|
||||
residue(kResidue);
|
||||
}
|
||||
/// Loads a single fragment element from memory
|
||||
CUTLASS_HOST_DEVICE void load_element(
|
||||
typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Load<Scalar,
|
||||
Base::kAccessSize,
|
||||
Base::kMemorySpace,
|
||||
Base::kFragmentElementType,
|
||||
typename Base::FragmentElement,
|
||||
Base::Tile::kW,
|
||||
Base::kAccessSize * sizeof(Scalar)>::load(value, params.pointer, offset);
|
||||
}
|
||||
|
||||
/// That's the residue! Update the predicates.
|
||||
CUTLASS_DEVICE void residue(Index k) {
|
||||
CUTLASS_HOST_DEVICE void residue(Index k) {
|
||||
// The coordinates of the thread.
|
||||
Index block_h = thread_offset[1];
|
||||
// The contiguous dimension.
|
||||
@ -375,26 +331,63 @@ struct GemmGlobalIteratorAb
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile and initialize predicates.
|
||||
CUTLASS_DEVICE void rollback() {
|
||||
params.pointer = stored_pointer;
|
||||
predicates = stored_predicates;
|
||||
}
|
||||
|
||||
/// Is the iterator valid?
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
|
||||
/// Is the valid?
|
||||
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const {
|
||||
int const bit = ComputeOffsetFromShape<typename Base::Iterations>::get(d, h, w, c);
|
||||
return predicates[bit];
|
||||
}
|
||||
|
||||
/// Offset of an individual lane from the start of the tile
|
||||
Coord<4> thread_offset;
|
||||
/// The parameters
|
||||
Params params;
|
||||
/// The pointer.
|
||||
typename Base::Scalar const* stored_pointer;
|
||||
/// The predicates.
|
||||
PredicateVector predicates, stored_predicates;
|
||||
/// Adds a vector offset to the iterator
|
||||
CUTLASS_HOST_DEVICE GemmGlobalIteratorAb & operator+=(Coord<3> const &offset) {
|
||||
|
||||
long long _offset = offset.template dot<long long>(
|
||||
make_Coord(params.stride_d, params.stride_h, params.stride_w)
|
||||
);
|
||||
|
||||
params.pointer += _offset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset) { params.pointer += offset; }
|
||||
|
||||
CUTLASS_HOST_DEVICE Index stride_advance(void) {
|
||||
Index stride = params.stride_h;
|
||||
if (kAdvance == IteratorAdvance::kW) {
|
||||
stride = params.stride_w;
|
||||
}
|
||||
return stride;
|
||||
}
|
||||
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void load_post_increment(Fragment& fragment) {
|
||||
typename Base::FragmentIterator frag_iterator(fragment);
|
||||
for (int d = 0; d < Base::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Base::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Base::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < Base::Iterations::kC; ++c) {
|
||||
if (valid(d, h, w, c)) {
|
||||
load_element(
|
||||
reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
|
||||
d,
|
||||
h,
|
||||
w,
|
||||
c);
|
||||
}
|
||||
}
|
||||
if (w < Base::Iterations::kW - 1) {
|
||||
inc_w();
|
||||
}
|
||||
}
|
||||
if (h < Base::Iterations::kH - 1) {
|
||||
inc_h();
|
||||
}
|
||||
}
|
||||
if (d < Base::Iterations::kD - 1) {
|
||||
inc_d();
|
||||
}
|
||||
}
|
||||
inc_advance();
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -433,6 +426,8 @@ struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
struct Params {
|
||||
/// The pointer.
|
||||
Pointer pointer;
|
||||
/// The stride in the D dimension
|
||||
long long stride_d;
|
||||
/// The stride in the H dimension to setup the thread in the block.
|
||||
Index stride_h;
|
||||
/// The strides to increment the pointer.
|
||||
@ -443,17 +438,23 @@ struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
Index predicate_offset;
|
||||
|
||||
/// Setup the params.
|
||||
CUTLASS_HOST_DEVICE int initialize(
|
||||
Pointer pointer, Index ld, Index bound, Index epilogue_stride_w, Index epilogue_delta_w) {
|
||||
CUTLASS_HOST_DEVICE int initialize(Pointer pointer,
|
||||
long long batch_stride,
|
||||
Index ldm,
|
||||
Index bound,
|
||||
Index epilogue_stride_w,
|
||||
Index epilogue_delta_w) {
|
||||
// The pointer.
|
||||
this->pointer = pointer;
|
||||
// Stride per batch
|
||||
stride_d = batch_stride;
|
||||
// Each column of the matrix.
|
||||
stride_h = TileTraits_::ThreadsDelta::kH * ld;
|
||||
stride_h = TileTraits_::ThreadsDelta::kH * ldm;
|
||||
// Each thread output 1 column per iteration. The stride between columns is given by the
|
||||
// number of scalars that are loaded per LDS for B.
|
||||
inc_h = ld * TileTraits_::kStrideH;
|
||||
inc_h = ldm * TileTraits_::kStrideH;
|
||||
inc_advance =
|
||||
(ld - ld * TileTraits_::kStrideH * (Base::Iterations::kH - 1)) + epilogue_stride_w;
|
||||
(ldm - ldm * TileTraits_::kStrideH * (Base::Iterations::kH - 1)) + epilogue_stride_w;
|
||||
|
||||
predicate_offset = bound;
|
||||
predicate_inc_h = TileTraits_::kStrideH;
|
||||
@ -464,75 +465,173 @@ struct GemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
}
|
||||
};
|
||||
|
||||
/// Parameters.
|
||||
Params params;
|
||||
/// Offset of an individual lane from the start of the tile
|
||||
Coord<4> thread_offset;
|
||||
/// The predicates for the row.
|
||||
cutlass::PredicateVector<Base::Iterations::kW> predicates;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GemmGlobalIteratorCd() {}
|
||||
CUTLASS_HOST_DEVICE GemmGlobalIteratorCd(Params const& _params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block_offset,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(_params) {
|
||||
thread_offset = thread_offset_func();
|
||||
// Prepare the vector of predicates.
|
||||
for (int i = 0; i < Base::Iterations::kW; ++i) {
|
||||
predicates.set(i, thread_offset[2] + i * Base::Delta::kW < bounds[2]);
|
||||
}
|
||||
}
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GemmGlobalIteratorCd(Params const& params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block,
|
||||
int offset = 0,
|
||||
int pred_offset = 0,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(params) {
|
||||
CUTLASS_HOST_DEVICE GemmGlobalIteratorCd(Params const& _params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block,
|
||||
int offset = 0,
|
||||
int pred_offset = 0,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: params(_params) {
|
||||
thread_offset = thread_offset_func();
|
||||
// Each warp works on a different column of the tile.
|
||||
int const h = thread_offset[1] + block[1];
|
||||
// Each lane writes a different element.
|
||||
int const w = thread_offset[2] + block[2];
|
||||
// Setup the pointer.
|
||||
this->params.pointer += ((h * params.stride_h + w) + offset);
|
||||
params.pointer += ((h * params.stride_h + w) + offset);
|
||||
|
||||
// Prepare the vector of predicates.
|
||||
for (int i = 0; i < Base::Iterations::kW; ++i) {
|
||||
predicates.set(i, w + i * Base::Delta::kW < bounds[2]);
|
||||
}
|
||||
this->params.predicate_offset -= (h + pred_offset);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Load<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::load(value, params.pointer, imm);
|
||||
params.predicate_offset -= (h + pred_offset);
|
||||
}
|
||||
|
||||
/// Increment the pointer in the C dimension.
|
||||
CUTLASS_DEVICE void inc_c() {}
|
||||
CUTLASS_HOST_DEVICE void inc_c() {}
|
||||
/// Increment the pointer in the W dimension.
|
||||
CUTLASS_DEVICE void inc_w() {}
|
||||
CUTLASS_HOST_DEVICE void inc_w() {}
|
||||
/// Increment the pointer in the H dimension.
|
||||
CUTLASS_DEVICE void inc_h() {
|
||||
CUTLASS_HOST_DEVICE void inc_h() {
|
||||
params.pointer += params.inc_h;
|
||||
params.predicate_offset -= params.predicate_inc_h;
|
||||
}
|
||||
/// Increment the pointer in the D dimension.
|
||||
CUTLASS_DEVICE void inc_d() {}
|
||||
CUTLASS_HOST_DEVICE void inc_d() {}
|
||||
/// Increment the pointer to move to the next iteration.
|
||||
CUTLASS_DEVICE void inc_advance() {
|
||||
CUTLASS_HOST_DEVICE void inc_advance() {
|
||||
params.pointer += params.inc_advance;
|
||||
this->params.predicate_offset -= params.predicate_inc_advance;
|
||||
params.predicate_offset -= params.predicate_inc_advance;
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void set(typename Base::AccessType const& value, int d, int h, int w, int c) {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Store<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::store(
|
||||
value, params.pointer, imm);
|
||||
/// Adds a vector offset to the iterator
|
||||
CUTLASS_HOST_DEVICE GemmGlobalIteratorCd & operator+=(Coord<3> const &offset) {
|
||||
long long _offset = offset.template dot<long long>(
|
||||
make_Coord(params.stride_d, params.stride_h, 1)
|
||||
);
|
||||
params.pointer += _offset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Test the validity of the iterator.
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
|
||||
/// Loads a single fragment element from memory.
|
||||
CUTLASS_HOST_DEVICE void load_element(
|
||||
typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
|
||||
Load<Scalar,
|
||||
Base::kAccessSize,
|
||||
Base::kMemorySpace,
|
||||
Base::kFragmentElementType,
|
||||
typename Base::FragmentElement,
|
||||
Base::Tile::kW,
|
||||
Base::kAccessSize * sizeof(Scalar)>::load(value, params.pointer, offset);
|
||||
}
|
||||
|
||||
/// Stores a single fragment element into memory.
|
||||
CUTLASS_HOST_DEVICE void store_element(
|
||||
typename Base::AccessType const& value, int d, int h, int w, int c) {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
|
||||
Store<Scalar,
|
||||
Base::kAccessSize,
|
||||
Base::kMemorySpace,
|
||||
Base::kFragmentElementType,
|
||||
typename Base::FragmentElement,
|
||||
Base::Tile::kW,
|
||||
Base::kAccessSize * sizeof(Scalar)>::store(value, params.pointer, offset);
|
||||
}
|
||||
|
||||
/// Test the validity of the
|
||||
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const {
|
||||
return predicates.at(w) && params.predicate_offset > 0;
|
||||
}
|
||||
|
||||
/// The predicates for the row.
|
||||
cutlass::PredicateVector<Base::Iterations::kW> predicates;
|
||||
/// add pointer offset
|
||||
CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset) { params.pointer += offset; }
|
||||
|
||||
/// Loads and increments iterator
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void load_post_increment(Fragment& fragment) {
|
||||
typename Base::FragmentIterator frag_iterator(fragment);
|
||||
for (int d = 0; d < Base::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Base::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Base::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < Base::Iterations::kC; ++c) {
|
||||
if (valid(d, h, w, c)) {
|
||||
load_element(
|
||||
reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
|
||||
d,
|
||||
h,
|
||||
w,
|
||||
c);
|
||||
}
|
||||
}
|
||||
if (w < Base::Iterations::kW - 1) {
|
||||
inc_w();
|
||||
}
|
||||
}
|
||||
if (h < Base::Iterations::kH - 1) {
|
||||
inc_h();
|
||||
}
|
||||
}
|
||||
if (d < Base::Iterations::kD - 1) {
|
||||
inc_d();
|
||||
}
|
||||
}
|
||||
inc_advance();
|
||||
}
|
||||
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void store_post_increment(Fragment& fragment) {
|
||||
typename Base::FragmentIterator frag_iterator(fragment);
|
||||
for (int d = 0; d < Base::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Base::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Base::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < Base::Iterations::kC; ++c) {
|
||||
if (valid(d, h, w, c)) {
|
||||
store_element(
|
||||
reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
|
||||
d,
|
||||
h,
|
||||
w,
|
||||
c);
|
||||
}
|
||||
}
|
||||
if (w < Base::Iterations::kW - 1) {
|
||||
inc_w();
|
||||
}
|
||||
}
|
||||
if (h < Base::Iterations::kH - 1) {
|
||||
inc_h();
|
||||
}
|
||||
}
|
||||
if (d < Base::Iterations::kD - 1) {
|
||||
inc_d();
|
||||
}
|
||||
}
|
||||
inc_advance();
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -28,9 +28,9 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/matrix_traits.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include <cutlass/util/platform.h>
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
#include "cutlass/util/platform.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
@ -28,7 +28,8 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/gemm/gemm_shared_tile.h>
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/gemm_shared_tile.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -56,6 +57,11 @@ struct SharedLoadStream {
|
||||
"");
|
||||
/// The output fragment.
|
||||
typedef TransformedFragment Fragment;
|
||||
/// Scalar data type
|
||||
typedef typename Iterator::Scalar Scalar;
|
||||
|
||||
/// Reference type to a tensor
|
||||
typedef TensorRef<Scalar, 4> TensorRef;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
@ -73,29 +79,38 @@ struct SharedLoadStream {
|
||||
CUTLASS_DEVICE SharedLoadStream() {}
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE SharedLoadStream(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
this->initialize(params, shared_storage);
|
||||
CUTLASS_DEVICE SharedLoadStream(Params const ¶ms, TensorRef const &ref) {
|
||||
this->initialize(params, ref);
|
||||
}
|
||||
|
||||
/// Initialize the stream.
|
||||
CUTLASS_DEVICE void initialize(Params const ¶ms, SharedStorage &shared_storage) {
|
||||
CUTLASS_DEVICE void initialize(Params const ¶ms, TensorRef const &ref) {
|
||||
// The iterator.
|
||||
iterator = Iterator(params.iterator, shared_storage);
|
||||
iterator = Iterator(params.iterator, ref.data());
|
||||
// The transformer.
|
||||
transformer = Transformer();
|
||||
}
|
||||
|
||||
/// Load the data from shared memory to the fetch fragment.
|
||||
CUTLASS_DEVICE void copy(FetchedFragment &fetched) { shared_iterator_load(iterator, fetched); }
|
||||
CUTLASS_DEVICE void copy() { iterator.load_post_increment(fetched[0]); }
|
||||
|
||||
/// Load the data from shared memory to the fetch fragment.
|
||||
CUTLASS_DEVICE void copy(int d, FetchedFragment &fetched) {
|
||||
shared_iterator_load(iterator, fetched, d);
|
||||
}
|
||||
CUTLASS_DEVICE void copy(int step) { iterator.load(fetched[step % 2], step); }
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit(FetchedFragment &fetched, TransformedFragment &transformed) {
|
||||
transformer.transform(fetched, transformed);
|
||||
CUTLASS_DEVICE void commit() { transformer.transform(fetched[0], transformed[0]); }
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit(int step) {
|
||||
transformer.transform(fetched[step % 2], transformed[step % 2]);
|
||||
}
|
||||
|
||||
/// Returns the fragment for the given step
|
||||
CUTLASS_DEVICE TransformedFragment &fragment(int step = 0) { return transformed[step % 2]; }
|
||||
|
||||
/// Returns the fragment for the given step
|
||||
CUTLASS_DEVICE TransformedFragment const &fragment(int step = 0) const {
|
||||
return transformed[step % 2];
|
||||
}
|
||||
|
||||
/// Increment the stage.
|
||||
@ -103,8 +118,12 @@ struct SharedLoadStream {
|
||||
|
||||
/// The iterator.
|
||||
Iterator iterator;
|
||||
/// Fetched fragment
|
||||
FetchedFragment fetched[2];
|
||||
/// The transformer.
|
||||
Transformer transformer;
|
||||
/// Transformed fragment
|
||||
TransformedFragment transformed[2];
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -27,7 +27,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/gemm/gemm_operand.h>
|
||||
#include "cutlass/gemm/gemm_operand.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
251
cutlass/gemm/gemm_stream_pair.h
Normal file
@ -0,0 +1,251 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines a pair of GEMM tile streams
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
#include "cutlass/tile_allocation.h"
|
||||
#include "cutlass/tile_iterator.h"
|
||||
|
||||
#include "cutlass/gemm/clear_accumulators.h"
|
||||
#include "cutlass/gemm/gemm_config.h"
|
||||
#include "cutlass/gemm/gemm_global_stream.h"
|
||||
#include "cutlass/gemm/gemm_operand.h"
|
||||
#include "cutlass/gemm/gemm_shared_stream.h"
|
||||
#include "cutlass/gemm/threadblock_swizzle.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Collect the global load streams for multiplicands.
|
||||
template <typename StreamA_, typename StreamB_, bool kResidueInProlog_>
|
||||
struct GlobalLoadStreamPair {
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Stream for A multiplicand
|
||||
typedef StreamA_ StreamA;
|
||||
|
||||
/// Stream for B multiplicand
|
||||
typedef StreamB_ StreamB;
|
||||
|
||||
/// Parameters object
|
||||
struct Params {
|
||||
/// Parameters object for StreamA
|
||||
typename StreamA::Params stream_a;
|
||||
|
||||
/// Parameters object for StreamB
|
||||
typename StreamB::Params stream_b;
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
/// Constructs a global load stream pair Params object
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(typename StreamA::Params const &_params_A, typename StreamB::Params const &_params_B)
|
||||
: stream_a(_params_A), stream_b(_params_B) {}
|
||||
};
|
||||
|
||||
/// Assumes the A stream defines the index type
|
||||
typedef typename StreamA::Index Index;
|
||||
|
||||
/// Shared memory allocation for threadblock-scoped GEMM tile
|
||||
typedef ZipTileAllocation<typename StreamA::ThreadblockTileStorage,
|
||||
typename StreamB::ThreadblockTileStorage>
|
||||
ThreadblockTileStorage;
|
||||
|
||||
/// ZipTensorRef to threadblock tiles
|
||||
typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef;
|
||||
|
||||
/// Defines a structure containing shared storage for each pair
|
||||
struct SharedStorage {
|
||||
typename StreamA::SharedStorage stream_a;
|
||||
typename StreamB::SharedStorage stream_b;
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Stream for A multiplicand
|
||||
StreamA stream_a;
|
||||
|
||||
/// Stream for B multiplicand
|
||||
StreamB stream_b;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GlobalLoadStreamPair(Params const ¶ms,
|
||||
SharedStorage &shared_storage,
|
||||
ThreadblockTileRef const &threadblock_tile_ref,
|
||||
Coord<3> const &bounds,
|
||||
Coord<3> const &block_offset = make_Coord(0, 0, 0))
|
||||
: stream_a(params.stream_a,
|
||||
shared_storage.stream_a,
|
||||
threadblock_tile_ref.first,
|
||||
bounds,
|
||||
block_offset),
|
||||
stream_b(params.stream_b,
|
||||
shared_storage.stream_b,
|
||||
threadblock_tile_ref.second,
|
||||
bounds,
|
||||
block_offset) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
GlobalLoadStreamPair & operator+=(Coord<3> const offset) {
|
||||
stream_a += offset;
|
||||
stream_b += offset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Trigger the copies from shared memory to registers.
|
||||
CUTLASS_DEVICE void copy() {
|
||||
stream_a.copy();
|
||||
stream_b.copy();
|
||||
}
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit() {
|
||||
stream_a.commit();
|
||||
stream_b.commit();
|
||||
}
|
||||
|
||||
/// Execute the residue code.
|
||||
CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
|
||||
stream_a.residue(k, skip_clear);
|
||||
stream_b.residue(k, skip_clear);
|
||||
}
|
||||
|
||||
/// Move to residue.
|
||||
CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {
|
||||
if (kResidueInProlog_) {
|
||||
stream_a.move_to_residue(k, kTileK);
|
||||
stream_b.move_to_residue(k, kTileK);
|
||||
} else if (k < kTileK) {
|
||||
residue(k, true);
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile.
|
||||
CUTLASS_DEVICE void rollback(bool kRollback) {
|
||||
if (kResidueInProlog_ && kRollback) {
|
||||
stream_a.rollback();
|
||||
stream_b.rollback();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Collect the global load streams for multiplicands.
|
||||
template <typename StreamA_, typename StreamB_>
|
||||
struct SharedStreamPair {
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Stream for A multiplicand
|
||||
typedef StreamA_ StreamA;
|
||||
|
||||
/// Stream for B multiplicand
|
||||
typedef StreamB_ StreamB;
|
||||
|
||||
/// Parameters object passed to load iterators
|
||||
struct Params {
|
||||
///
|
||||
typename StreamA::Params stream_a;
|
||||
|
||||
///
|
||||
typename StreamB::Params stream_b;
|
||||
};
|
||||
|
||||
/// Shared memory allocation for threadblock-scoped GEMM tile
|
||||
typedef ZipTensorRef<typename StreamA::TensorRef,
|
||||
typename StreamB::TensorRef >
|
||||
ThreadblockTileRef;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// The stream for A.
|
||||
StreamA stream_a;
|
||||
|
||||
/// The stream for B.
|
||||
StreamB stream_b;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Construct with the composable structure
|
||||
CUTLASS_DEVICE SharedStreamPair(Params const ¶ms, ThreadblockTileRef const &threadblock_tile_ref)
|
||||
: stream_a(params.stream_a, threadblock_tile_ref.first),
|
||||
stream_b(params.stream_b, threadblock_tile_ref.second) {}
|
||||
|
||||
/// Trigger the copies from shared memory to registers.
|
||||
CUTLASS_DEVICE void copy(int step) {
|
||||
stream_a.copy(step);
|
||||
stream_b.copy(step);
|
||||
}
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit(int step) {
|
||||
stream_a.commit(step);
|
||||
stream_b.commit(step);
|
||||
}
|
||||
|
||||
/// The fragment A.
|
||||
CUTLASS_DEVICE
|
||||
typename StreamA::TransformedFragment const &fragment_a(int step) const {
|
||||
return stream_a.fragment(step);
|
||||
}
|
||||
|
||||
/// The fragment B.
|
||||
CUTLASS_DEVICE
|
||||
typename StreamB::TransformedFragment const &fragment_b(int step) const {
|
||||
return stream_b.fragment(step);
|
||||
}
|
||||
|
||||
/// Increment the stage.
|
||||
CUTLASS_DEVICE void inc_stage() {
|
||||
stream_a.inc_stage();
|
||||
stream_b.inc_stage();
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -27,117 +27,27 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/gemm/clear_accumulators.h>
|
||||
#include <cutlass/gemm/gemm_global_stream.h>
|
||||
#include <cutlass/gemm/gemm_operand.h>
|
||||
#include <cutlass/gemm/gemm_shared_stream.h>
|
||||
#include <cutlass/gemm/identity_block_swizzle.h>
|
||||
#include <cutlass/matrix_traits.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include <cutlass/tile_iterator.h>
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
#include "cutlass/tile_allocation.h"
|
||||
#include "cutlass/tile_iterator.h"
|
||||
#include "cutlass/kernel_launch.h"
|
||||
|
||||
#include "cutlass/gemm/clear_accumulators.h"
|
||||
#include "cutlass/gemm/gemm_config.h"
|
||||
#include "cutlass/gemm/gemm_desc.h"
|
||||
#include "cutlass/gemm/gemm_stream_pair.h"
|
||||
#include "cutlass/gemm/gemm_global_stream.h"
|
||||
#include "cutlass/gemm/gemm_operand.h"
|
||||
#include "cutlass/gemm/gemm_shared_stream.h"
|
||||
#include "cutlass/gemm/threadblock_swizzle.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The scalar type for A.
|
||||
typename ScalarA_,
|
||||
/// The scalar type for B.
|
||||
typename ScalarB_,
|
||||
/// The scalar type for C.
|
||||
typename ScalarC_,
|
||||
/// The scalar type for D.
|
||||
typename ScalarD_,
|
||||
/// The output tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// The functor to do the math.
|
||||
typename MultiplyAdd_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
int kScalarsPerStsA_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdsA_,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
int kScalarsPerStsB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
int kScalarsPerLdsB_,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
int kScalarsPerLdgCAndStgD_,
|
||||
/// The number of scalars per STS for D.
|
||||
int kScalarsPerStsD_,
|
||||
/// The number of scalars per LDS for D.
|
||||
int kScalarsPerLdsD_,
|
||||
/// The number of stages in shared memory to do single/double/triple-buffering.
|
||||
int kStages_,
|
||||
/// Do we do the residue in the prologue?
|
||||
bool kResidueInPrologue_ = false>
|
||||
|
||||
struct GemmConfig {
|
||||
//
|
||||
/// The scalar for A.
|
||||
typedef ScalarA_ ScalarA;
|
||||
/// The scalar for B.
|
||||
typedef ScalarB_ ScalarB;
|
||||
/// The scalar for C.
|
||||
typedef ScalarC_ ScalarC;
|
||||
/// The scalar for D.
|
||||
typedef ScalarD_ ScalarD;
|
||||
|
||||
/// The tile.
|
||||
typedef OutputTile_ OutputTile;
|
||||
/// The functor to do D = A*B + C.
|
||||
typedef MultiplyAdd_ MultiplyAdd;
|
||||
/// The shape of the instruction.
|
||||
typedef typename MultiplyAdd::InstructionShape InstructionShape;
|
||||
/// The number of accumulators per warp.
|
||||
typedef typename MultiplyAdd::AccumulatorsPerWarp AccumulatorsPerWarp;
|
||||
/// The accumulators.
|
||||
typedef typename MultiplyAdd::Accumulators Accumulators;
|
||||
|
||||
/// The number of warps.
|
||||
typedef typename ShapeDiv<OutputTile, AccumulatorsPerWarp>::Shape Warps;
|
||||
/// The default warp size (32 threads per warp).
|
||||
static int const kWarpSize = cutlass::kWarpSize;
|
||||
/// The numnber of threads.
|
||||
static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;
|
||||
|
||||
/// The number of scalars per LDG/STS/LDS for A.
|
||||
static int const kScalarsPerLdgA = kScalarsPerLdgA_;
|
||||
static int const kScalarsPerStsA = kScalarsPerStsA_;
|
||||
static int const kScalarsPerLdsA = kScalarsPerLdsA_;
|
||||
|
||||
/// The number of scalars per LDG/STS/LDS for B.
|
||||
static int const kScalarsPerLdgB = kScalarsPerLdgB_;
|
||||
static int const kScalarsPerStsB = kScalarsPerStsB_;
|
||||
static int const kScalarsPerLdsB = kScalarsPerLdsB_;
|
||||
|
||||
/// The number of scalars per LDG for C.
|
||||
static int const kScalarsPerLdgC = kScalarsPerLdgCAndStgD_;
|
||||
|
||||
/// The number of scalars per STS/LDS/STG for D.
|
||||
static int const kScalarsPerStgD = kScalarsPerLdgCAndStgD_;
|
||||
static int const kScalarsPerStsD = kScalarsPerStsD_;
|
||||
static int const kScalarsPerLdsD = kScalarsPerLdsD_;
|
||||
|
||||
/// The number of accumulators that are going to be fed from one LDS A/B.
|
||||
static int const kAccumulatorsPerLdsA = kScalarsPerLdsA / InstructionShape::kD;
|
||||
static int const kAccumulatorsPerLdsB = kScalarsPerLdsB / InstructionShape::kD;
|
||||
|
||||
/// The number of stages in shared memory to implement double, triple, more-buffering.
|
||||
static int const kStages = kStages_;
|
||||
|
||||
/// Do we do the residue in the prologue?
|
||||
static bool const kResidueInPrologue = kResidueInPrologue_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind, typename GemmConfig_>
|
||||
struct GemmTileTraitsHelperA {};
|
||||
|
||||
@ -416,60 +326,6 @@ struct GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmTraits_, bool kResidueInPrologue_ = GemmTraits_::kResidueInPrologue>
|
||||
struct GemmResidue {
|
||||
/// Move to residue portion.
|
||||
template <bool kIsPrologue>
|
||||
static CUTLASS_DEVICE void move_to_residue(typename GemmTraits_::GlobalLoadStreamA& stream_a,
|
||||
typename GemmTraits_::GlobalLoadStreamB& stream_b,
|
||||
typename GemmTraits_::Index k) {
|
||||
// The new code path in CUTLASS 1.0.1: We treat the residue in the prologue so we can have
|
||||
// complete main loops after that. It helps simplify the logic in the main loop.
|
||||
if (kIsPrologue) {
|
||||
stream_a.move_to_residue(k);
|
||||
stream_b.move_to_residue(k);
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile and initialize predicates.
|
||||
static CUTLASS_DEVICE void rollback(typename GemmTraits_::GlobalLoadStreamA& stream_a,
|
||||
typename GemmTraits_::GlobalLoadStreamB& stream_b) {
|
||||
stream_a.rollback();
|
||||
stream_b.rollback();
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmTraits_>
|
||||
struct GemmResidue<GemmTraits_, false> {
|
||||
/// Move to residue portion.
|
||||
template <bool kIsPrologue>
|
||||
static CUTLASS_DEVICE void move_to_residue(typename GemmTraits_::GlobalLoadStreamA& stream_a,
|
||||
typename GemmTraits_::GlobalLoadStreamB& stream_b,
|
||||
typename GemmTraits_::Index k) {
|
||||
// The index.
|
||||
typedef typename GemmTraits_::Index Index;
|
||||
// By how much we unroll the main loop.
|
||||
Index const kUnroll = static_cast<Index>(GemmTraits_::OutputTile::kD);
|
||||
|
||||
// Call the residue code. That's the same path as CUTLASS 1.0.0.
|
||||
if (kIsPrologue && k < kUnroll) {
|
||||
stream_a.residue(k, true);
|
||||
stream_b.residue(k, true);
|
||||
} else if (k <= kUnroll) {
|
||||
stream_a.residue(k, false);
|
||||
stream_b.residue(k, false);
|
||||
}
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile and initialize predicates.
|
||||
static CUTLASS_DEVICE void rollback(typename GemmTraits_::GlobalLoadStreamA& stream_a,
|
||||
typename GemmTraits_::GlobalLoadStreamB& stream_b) {}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The GEMM configuration.
|
||||
typename GemmConfig_,
|
||||
@ -488,27 +344,27 @@ template <
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The tool used to clear accumulators.
|
||||
typename ClearAccumulators_ = ClearAccumulators<typename GemmConfig_::Accumulators::Scalar> >
|
||||
typename ClearAccumulators_ = ClearAccumulators<typename GemmConfig_::Accumulators::Element> >
|
||||
|
||||
struct GemmTraits {
|
||||
/// This class.
|
||||
/// This traits
|
||||
typedef GemmTraits<GemmConfig_,
|
||||
GlobalLoadStreamA_,
|
||||
GlobalLoadStreamB_,
|
||||
SharedLoadStreamA_,
|
||||
SharedLoadStreamB_,
|
||||
Epilogue_,
|
||||
BlockSwizzle_,
|
||||
Index_,
|
||||
ClearAccumulators_>
|
||||
This_;
|
||||
GlobalLoadStreamA_,
|
||||
GlobalLoadStreamB_,
|
||||
SharedLoadStreamA_,
|
||||
SharedLoadStreamB_,
|
||||
Epilogue_,
|
||||
BlockSwizzle_,
|
||||
Index_,
|
||||
ClearAccumulators_> This_;
|
||||
|
||||
/// The struct that consumes this Traits
|
||||
typedef typename cutlass::gemm::Gemm<This_> KernelClass;
|
||||
|
||||
/// The configuration.
|
||||
typedef GemmConfig_ GemmConfig;
|
||||
/// The output tile.
|
||||
typedef typename GemmConfig::OutputTile OutputTile;
|
||||
/// Is the residue treated in the prologue?
|
||||
static bool const kResidueInPrologue = GemmConfig::kResidueInPrologue;
|
||||
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typedef GlobalLoadStreamA_ GlobalLoadStreamA;
|
||||
@ -544,18 +400,30 @@ struct GemmTraits {
|
||||
/// Clear the accumulators.
|
||||
typedef ClearAccumulators_ ClearAccumulators;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
/// The dimensions of the GEMM.
|
||||
Index m, n, k;
|
||||
/// The params for the A stream.
|
||||
typename GlobalLoadStreamA::Params global_stream_a;
|
||||
/// The params for the B stream.
|
||||
typename GlobalLoadStreamB::Params global_stream_b;
|
||||
/// The params for the A stream from shared memory.
|
||||
typename SharedLoadStreamA::Params shared_stream_a;
|
||||
/// The params for the B stream from shared memory.
|
||||
typename SharedLoadStreamB::Params shared_stream_b;
|
||||
/// Assemble the global load streams for A/B.
|
||||
typedef GlobalLoadStreamPair<GlobalLoadStreamA,
|
||||
GlobalLoadStreamB,
|
||||
GemmConfig::kResidueInProlog>
|
||||
GlobalLoadStream;
|
||||
|
||||
/// Memory needed to store the threadblock-scoped GEMM tile
|
||||
typedef typename GlobalLoadStream::ThreadblockTileStorage ThreadblockTileStorage;
|
||||
|
||||
/// Assemble the shared load streams for A/B.
|
||||
typedef SharedStreamPair<SharedLoadStreamA, SharedLoadStreamB> SharedStream;
|
||||
|
||||
/// Parameters object constructable on the host.
|
||||
struct Params : public KernelLaunchConfiguration {
|
||||
|
||||
/// GEMM problem size
|
||||
GemmCoord problem_size;
|
||||
|
||||
/// Parameters object for the global load stream
|
||||
typename GlobalLoadStream::Params global_to_shared_stream;
|
||||
|
||||
/// Parameters object for the shared load stream
|
||||
typename SharedStream::Params shared_stream;
|
||||
|
||||
/// The params for the epilogue.
|
||||
typename Epilogue::Params epilogue;
|
||||
|
||||
@ -563,21 +431,36 @@ struct GemmTraits {
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
|
||||
// Set the problem size.
|
||||
this->m = desc.m;
|
||||
this->n = desc.n;
|
||||
this->k = desc.k;
|
||||
problem_size = desc.problem_size;
|
||||
|
||||
// Initialize the iterator for A.
|
||||
int error_code =
|
||||
global_stream_a.initialize(desc, reinterpret_cast<ScalarA const*>(desc.d_a), desc.lda);
|
||||
// Compute grid dimensions
|
||||
BlockSwizzle block_swizzle;
|
||||
this->block = dim3(GemmConfig::kThreads);
|
||||
this->grid = block_swizzle.get_grid_layout(
|
||||
problem_size,
|
||||
make_Coord_from_shape<OutputTile>());
|
||||
|
||||
// Compute offset to residue.
|
||||
Index gemm_k = problem_size[0];
|
||||
Index offset_to_residue = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
|
||||
|
||||
// Initialize parameters objects for
|
||||
int error_code = global_to_shared_stream.stream_a.initialize(
|
||||
desc.A.data(),
|
||||
desc.batch_stride_A,
|
||||
desc.A.leading_dim(),
|
||||
offset_to_residue
|
||||
);
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
}
|
||||
|
||||
// Initialize the iterator for B.
|
||||
error_code =
|
||||
global_stream_b.initialize(desc, reinterpret_cast<ScalarB const*>(desc.d_b), desc.ldb);
|
||||
error_code = global_to_shared_stream.stream_b.initialize(
|
||||
desc.B.data(),
|
||||
desc.batch_stride_B,
|
||||
desc.B.leading_dim(),
|
||||
offset_to_residue
|
||||
);
|
||||
|
||||
if (error_code) {
|
||||
return error_code;
|
||||
@ -586,24 +469,81 @@ struct GemmTraits {
|
||||
// The epilogue.
|
||||
return epilogue.initialize(desc);
|
||||
}
|
||||
};
|
||||
|
||||
// The storage for A.
|
||||
template <typename GlobalLoadStream_, typename SharedLoadStream_>
|
||||
union StreamSharedStorage {
|
||||
// The storage needed by the global stream.
|
||||
typename GlobalLoadStream_::SharedStorage global;
|
||||
// The storage needed by the shared stream.
|
||||
typename SharedLoadStream_::SharedStorage shared;
|
||||
/// Helper to construct a GEMM params using a BLAS-like API
|
||||
CUTLASS_HOST_DEVICE int initialize(Index m,
|
||||
Index n,
|
||||
Index k,
|
||||
typename Epilogue::Scalar alpha,
|
||||
ScalarA const* d_a,
|
||||
Index lda,
|
||||
ScalarB const* d_b,
|
||||
Index ldb,
|
||||
typename Epilogue::Scalar beta,
|
||||
ScalarC const* d_c,
|
||||
Index ldc,
|
||||
ScalarD* d_d,
|
||||
Index ldd) {
|
||||
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
|
||||
GemmCoord(k, n, m, 1),
|
||||
alpha,
|
||||
TensorRef<ScalarA const, 2>(d_a, lda),
|
||||
TensorRef<ScalarB const, 2>(d_b, ldb),
|
||||
beta,
|
||||
TensorRef<ScalarC const, 2>(d_c, ldc),
|
||||
TensorRef<ScalarD, 2>(d_d, ldd)
|
||||
);
|
||||
|
||||
return this->initialize(desc);
|
||||
}
|
||||
|
||||
/// Helper to construct a batched GEMM params
|
||||
CUTLASS_HOST_DEVICE int initialize(Index m,
|
||||
Index n,
|
||||
Index k,
|
||||
typename Epilogue::Scalar alpha,
|
||||
ScalarA const* d_a,
|
||||
Index lda,
|
||||
long long int batch_stride_A,
|
||||
ScalarB const* d_b,
|
||||
Index ldb,
|
||||
long long int batch_stride_B,
|
||||
typename Epilogue::Scalar beta,
|
||||
ScalarC const* d_c,
|
||||
Index ldc,
|
||||
long long int batch_stride_C,
|
||||
ScalarD* d_d,
|
||||
Index ldd,
|
||||
long long int batch_stride_D,
|
||||
Index batch_count) {
|
||||
|
||||
GemmDesc<ScalarA, ScalarB, ScalarC, ScalarD, typename Epilogue::Scalar> desc(
|
||||
GemmCoord(k, n, m, batch_count),
|
||||
alpha,
|
||||
TensorRef<ScalarA const, 2>(d_a, lda),
|
||||
batch_stride_A,
|
||||
TensorRef<ScalarB const, 2>(d_b, ldb),
|
||||
batch_stride_B,
|
||||
beta,
|
||||
TensorRef<ScalarC const, 2>(d_c, ldc),
|
||||
batch_stride_C,
|
||||
TensorRef<ScalarD, 2>(d_d, ldd),
|
||||
batch_stride_D
|
||||
);
|
||||
|
||||
return this->initialize(desc);
|
||||
}
|
||||
};
|
||||
|
||||
// The storage for the main loop + prologue.
|
||||
struct MainLoopSharedStorage {
|
||||
// The storage to shuffle the A matrix in shared memory.
|
||||
StreamSharedStorage<GlobalLoadStreamA, SharedLoadStreamA> stream_a;
|
||||
// The storage to shuffle the B matrix in shared memory.
|
||||
StreamSharedStorage<GlobalLoadStreamB, SharedLoadStreamB> stream_b;
|
||||
// The storage to clear the accumulators if needed.
|
||||
/// Stores the threadblock tile
|
||||
ThreadblockTileStorage threadblock_tile;
|
||||
|
||||
/// Storage for GEMM global stream
|
||||
typename GlobalLoadStream::SharedStorage global_to_shared_stream;
|
||||
|
||||
/// Storage for clearing accumulators
|
||||
typename ClearAccumulators::SharedStorage clear;
|
||||
};
|
||||
|
||||
@ -615,108 +555,18 @@ struct GemmTraits {
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
/// Assemble the global load streams for A/B.
|
||||
struct GlobalLoadStream {
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE GlobalLoadStream(Params const& params,
|
||||
SharedStorage& shared_storage,
|
||||
dim3 const& block)
|
||||
: stream_a(params.global_stream_a,
|
||||
shared_storage.main_loop.stream_a.global,
|
||||
cutlass::make_Coord(0, params.k, params.m),
|
||||
cutlass::make_Coord(0, 0, block.x)),
|
||||
stream_b(params.global_stream_b,
|
||||
shared_storage.main_loop.stream_b.global,
|
||||
cutlass::make_Coord(0, params.k, params.n),
|
||||
make_Coord(0, 0, block.y)) {}
|
||||
|
||||
/// Trigger the copies from shared memory to registers.
|
||||
CUTLASS_DEVICE void copy() {
|
||||
stream_a.copy();
|
||||
stream_b.copy();
|
||||
}
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit() {
|
||||
stream_a.commit();
|
||||
stream_b.commit();
|
||||
}
|
||||
|
||||
/// Move to residue portion.
|
||||
template <bool kIsPrologue>
|
||||
CUTLASS_DEVICE void move_to_residue(Index k) {
|
||||
GemmResidue<This_>::move_to_residue<kIsPrologue>(stream_a, stream_b, k);
|
||||
}
|
||||
|
||||
/// Rollback to beginning of first tile and initialize predicates.
|
||||
CUTLASS_DEVICE void rollback() { GemmResidue<This_>::rollback(stream_a, stream_b); }
|
||||
|
||||
/// The stream for A.
|
||||
GlobalLoadStreamA stream_a;
|
||||
/// The stream for B.
|
||||
GlobalLoadStreamB stream_b;
|
||||
};
|
||||
|
||||
/// Assemble the shared load stream for A/B.
|
||||
struct SharedLoadStream {
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE SharedLoadStream(Params const& params, SharedStorage& shared_storage) {
|
||||
stream_a.initialize(params.shared_stream_a, shared_storage.main_loop.stream_a.shared);
|
||||
stream_b.initialize(params.shared_stream_b, shared_storage.main_loop.stream_b.shared);
|
||||
}
|
||||
|
||||
/// Trigger the copies from shared memory to registers.
|
||||
CUTLASS_DEVICE void copy(int step) {
|
||||
stream_a.copy(step, fetched_a[step % 2]);
|
||||
stream_b.copy(step, fetched_b[step % 2]);
|
||||
}
|
||||
|
||||
/// Commit the data.
|
||||
CUTLASS_DEVICE void commit(int step) {
|
||||
stream_a.commit(fetched_a[step % 2], transformed_a[step % 2]);
|
||||
stream_b.commit(fetched_b[step % 2], transformed_b[step % 2]);
|
||||
}
|
||||
|
||||
/// The fragment A.
|
||||
CUTLASS_DEVICE typename SharedLoadStreamA::Fragment const& fragment_a(int step) const {
|
||||
return transformed_a[step % 2];
|
||||
}
|
||||
|
||||
/// The fragment B.
|
||||
CUTLASS_DEVICE typename SharedLoadStreamB::Fragment const& fragment_b(int step) const {
|
||||
return transformed_b[step % 2];
|
||||
}
|
||||
|
||||
/// Increment the stage.
|
||||
CUTLASS_DEVICE void inc_stage() {
|
||||
stream_a.inc_stage();
|
||||
stream_b.inc_stage();
|
||||
}
|
||||
|
||||
/// The stream for A.
|
||||
SharedLoadStreamA stream_a;
|
||||
/// The fragments to fetch A.
|
||||
typename SharedLoadStreamA::FetchedFragment fetched_a[2];
|
||||
/// The fragments to transform A.
|
||||
typename SharedLoadStreamA::TransformedFragment transformed_a[2];
|
||||
/// The stream for B.
|
||||
SharedLoadStreamB stream_b;
|
||||
/// The fragments to fetch B.
|
||||
typename SharedLoadStreamB::FetchedFragment fetched_b[2];
|
||||
/// The fragments to transform B.
|
||||
typename SharedLoadStreamB::TransformedFragment transformed_b[2];
|
||||
};
|
||||
|
||||
/// The memory fence for shared loads.
|
||||
static CUTLASS_DEVICE void shared_load_fence(bool in_loop) {
|
||||
if (SharedLoadStreamA::Iterator::kRequiresLoadFence ||
|
||||
SharedLoadStreamB::Iterator::kRequiresLoadFence) {
|
||||
__syncthreads();
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
/// The memory fence for shared stores.
|
||||
static CUTLASS_DEVICE void shared_store_fence(bool in_loop) { __syncthreads(); }
|
||||
static CUTLASS_DEVICE void shared_store_fence(bool in_loop) {
|
||||
__syncthreads();
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -735,7 +585,10 @@ struct SimplifiedGemmTraitsHelper {
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorA;
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA>
|
||||
typedef GlobalLoadStream<GemmOperand::kA,
|
||||
GlobalLoadIteratorA,
|
||||
SharedStoreIteratorA,
|
||||
GlobalTransformerA>
|
||||
GlobalLoadStreamA;
|
||||
|
||||
/// The global iterator to load B from global memory.
|
||||
@ -750,7 +603,10 @@ struct SimplifiedGemmTraitsHelper {
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorB;
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB>
|
||||
typedef GlobalLoadStream<GemmOperand::kB,
|
||||
GlobalLoadIteratorB,
|
||||
SharedStoreIteratorB,
|
||||
GlobalTransformerB>
|
||||
GlobalLoadStreamB;
|
||||
|
||||
/// The iterator to load A from shared memory.
|
||||
|
||||
@ -29,10 +29,10 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/matrix_traits.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/gemm/gemm_global_tile.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -63,14 +63,14 @@ struct HgemmCrosswiseGlobalTileTraits : public GemmGlobalTileTraits<
|
||||
/// The threads.
|
||||
typedef typename Base::Threads Threads;
|
||||
/// The threads strides.
|
||||
typedef Shape<1, 2, Base::Tile::kC> ThreadsDelta;
|
||||
typedef Shape<1, 2, Base::VectorizedTile::kC> ThreadsDelta;
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<Base::Threads::kH * 2, 1, Base::Threads::kW, Base::kAccessSize> Delta;
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<Base::Tile::kH / Base::Threads::kH / 2,
|
||||
typedef Shape<Base::VectorizedTile::kH / Base::Threads::kH / 2,
|
||||
2,
|
||||
Base::Tile::kW / Base::Threads::kW,
|
||||
Base::Tile::kC / Base::kAccessSize>
|
||||
Base::VectorizedTile::kW / Base::Threads::kW,
|
||||
Base::VectorizedTile::kC / Base::kAccessSize>
|
||||
Iterations;
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
struct ThreadOffset {
|
||||
|
||||
@ -28,9 +28,9 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
#include "cutlass/fragment.h"
|
||||
|
||||
#include <cutlass/gemm/thread_multiply_add.h>
|
||||
#include "cutlass/gemm/thread_multiply_add.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -38,16 +38,18 @@ namespace gemm {
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Template performing matrix multiply-add operation within a thread
|
||||
template <typename AccumulatorsPerThread_, typename ThreadsPerWarp_>
|
||||
struct ThreadMultiplyAdd<AccumulatorsPerThread_, ThreadsPerWarp_, half, half, half> {
|
||||
template <typename ThreadGemmShape_, typename ThreadsPerWarp_>
|
||||
struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, half, half, half> {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<1, 1, 2, 1> InstructionShape;
|
||||
/// The number of accumulators per thread.
|
||||
typedef AccumulatorsPerThread_ AccumulatorsPerThread;
|
||||
typedef ThreadGemmShape_ ThreadGemmShape;
|
||||
/// Aliased for compatibility. Will be removed for CUTLASS v2.0.
|
||||
typedef ThreadGemmShape AccumulatorsPerThread;
|
||||
/// The number of threads per warp.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of accumulators per warp.
|
||||
typedef typename ShapeMul<AccumulatorsPerThread, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
|
||||
typedef typename ShapeMul<ThreadGemmShape, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
|
||||
/// The type for A.
|
||||
typedef half ScalarA;
|
||||
/// The fragment for A.
|
||||
@ -88,9 +90,9 @@ struct ThreadMultiplyAdd<AccumulatorsPerThread_, ThreadsPerWarp_, half, half, ha
|
||||
int const k0 = (2 * j + 0) * (AccumulatorsPerThread::kW / 2) + i;
|
||||
int const k1 = (2 * j + 1) * (AccumulatorsPerThread::kW / 2) + i;
|
||||
|
||||
// Compute the product a[i] * b[j].H0_H0.
|
||||
// Compute the product a[i] * b[j].low.
|
||||
d_half2[k0] = __hfma2(a_half2[i], __low2half2(b_half2[j]), c_half2[k0]);
|
||||
// Compute the product a[i] * b[j].H1_H1.
|
||||
// Compute the product a[i] * b[j].high.
|
||||
d_half2[k1] = __hfma2(a_half2[i], __high2half2(b_half2[j]), c_half2[k1]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -29,7 +29,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cutlass/fragment.h>
|
||||
#include "cutlass/fragment.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
@ -27,18 +27,18 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/gemm_epilogue.h>
|
||||
#include <cutlass/gemm/gemm_epilogue_traits.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/gemm/gemm_shared_tile.h>
|
||||
#include <cutlass/gemm/gemm_traits.h>
|
||||
#include <cutlass/gemm/hgemm_global_tile.h>
|
||||
#include <cutlass/gemm/hgemm_multiply_add.h>
|
||||
#include <cutlass/gemm/hgemm_swizzle.h>
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/gemm_epilogue.h"
|
||||
#include "cutlass/gemm/gemm_epilogue_traits.h"
|
||||
#include "cutlass/gemm/gemm_global_tile.h"
|
||||
#include "cutlass/gemm/gemm_shared_tile.h"
|
||||
#include "cutlass/gemm/gemm_traits.h"
|
||||
#include "cutlass/gemm/hgemm_global_tile.h"
|
||||
#include "cutlass/gemm/hgemm_multiply_add.h"
|
||||
#include "cutlass/gemm/hgemm_swizzle.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -48,46 +48,52 @@ namespace gemm {
|
||||
template <
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_ = 2,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_ = 2>
|
||||
struct HgemmConfig
|
||||
: public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
half,
|
||||
/// The scalar type for B.
|
||||
half,
|
||||
/// The scalar type for C.
|
||||
half,
|
||||
/// The scalar type for D.
|
||||
half,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, half, half, half>,
|
||||
/// The number of scalars per LDG for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDS for A.
|
||||
8,
|
||||
/// The number of scalars per LDG for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
8,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
2,
|
||||
/// The number of scalars per STS for D.
|
||||
8,
|
||||
/// The number of scalars per LDS for D.
|
||||
2,
|
||||
/// The number of stages in shared memory.
|
||||
2> {};
|
||||
struct HgemmConfig : public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
half,
|
||||
/// The scalar type for B.
|
||||
half,
|
||||
/// The scalar type for C.
|
||||
half,
|
||||
/// The scalar type for D.
|
||||
half,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, half, half, half>,
|
||||
/// The number of scalars per LDG for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDS for A.
|
||||
8,
|
||||
/// The number of scalars per LDG for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
8,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
2,
|
||||
/// The number of scalars per STS for D.
|
||||
8,
|
||||
/// The number of scalars per LDS for D.
|
||||
2,
|
||||
/// The number of stages in shared memory.
|
||||
2,
|
||||
/// kResidueSeparate
|
||||
false,
|
||||
/// kResidueInPrologue
|
||||
true,
|
||||
/// kLaunchBounds
|
||||
false
|
||||
> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -147,7 +153,6 @@ struct HgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_>
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkewA = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^T.
|
||||
@ -215,7 +220,6 @@ struct HgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_>
|
||||
GemmConfig_::kScalarsPerLdgB>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew for B.
|
||||
static int const kSkewB = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
@ -266,8 +270,8 @@ template <
|
||||
typename OutputTile_,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_ = Shape<8, 8, 16>,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_,
|
||||
/// The number of halfs loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 2,
|
||||
/// The number of halfs loaded in one LDG for B.
|
||||
@ -276,8 +280,7 @@ template <
|
||||
typename Index_ = int>
|
||||
struct HgemmTraitsHelper {
|
||||
/// The HGEMM config.
|
||||
typedef HgemmConfig<OutputTile_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_>
|
||||
GemmConfig;
|
||||
typedef HgemmConfig<OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_> GemmConfig;
|
||||
/// The GEMM config for A.
|
||||
typedef HgemmTileTraitsHelperA<kLayoutA_, GemmConfig> GemmTileTraitsHelperA;
|
||||
/// The GEMM config for B.
|
||||
@ -296,7 +299,10 @@ struct HgemmTraitsHelper {
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorA;
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA>
|
||||
typedef GlobalLoadStream<GemmOperand::kA,
|
||||
GlobalLoadIteratorA,
|
||||
SharedStoreIteratorA,
|
||||
GlobalTransformerA>
|
||||
GlobalLoadStreamA;
|
||||
|
||||
/// The iterator to load B from global memory.
|
||||
@ -312,7 +318,10 @@ struct HgemmTraitsHelper {
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorB;
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB>
|
||||
typedef GlobalLoadStream<GemmOperand::kB,
|
||||
GlobalLoadIteratorB,
|
||||
SharedStoreIteratorB,
|
||||
GlobalTransformerB>
|
||||
GlobalLoadStreamB;
|
||||
|
||||
/// The iterator to load A from shared memory
|
||||
@ -354,8 +363,8 @@ template <
|
||||
typename OutputTile_ = Shape<8, 128, 128>,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<half>,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_ = Shape<8, 8, 16>,
|
||||
/// Tile size for warp-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_ = Shape<8, 8, 16>,
|
||||
/// The number of halfs loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 2,
|
||||
/// The number of halfs loaded in one LDG for B.
|
||||
@ -367,7 +376,7 @@ template <
|
||||
kLayoutB_,
|
||||
OutputTile_,
|
||||
EpilogueFunctor_,
|
||||
AccumulatorsPerThread_,
|
||||
ThreadGemmShape_,
|
||||
kScalarsPerLdgA_,
|
||||
kScalarsPerLdgB_,
|
||||
Index_> >
|
||||
|
||||
@ -28,13 +28,13 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/fragment.h>
|
||||
#include <cutlass/gemm/gemm_global_stream.h>
|
||||
#include <cutlass/gemm/gemm_shared_stream.h>
|
||||
#include <cutlass/gemm/igemm_global_tile.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include <cutlass/tile_iterator.h>
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/fragment.h"
|
||||
#include "cutlass/gemm/gemm_global_stream.h"
|
||||
#include "cutlass/gemm/gemm_shared_stream.h"
|
||||
#include "cutlass/gemm/igemm_global_tile.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
#include "cutlass/tile_iterator.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -269,8 +269,8 @@ struct IgemmEpilogueTraits : public GemmEpilogueTraits<
|
||||
typename Helper_::SharedStoreIteratorD,
|
||||
// The shared store transformer for D.
|
||||
typename Helper_::SharedStoreTransformerD,
|
||||
// The iterator to load D from shared memory.
|
||||
typename Helper_::SharedLoadIteratorD,
|
||||
// The stream to load D from shared memory.
|
||||
typename Helper_::SharedLoadStreamD,
|
||||
// The iterations.
|
||||
typename Helper_::Iterations,
|
||||
// The strides between iterations.
|
||||
@ -294,9 +294,8 @@ struct IgemmEpilogue : public GemmEpilogue<GemmEpilogueTraits_> {
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const& params_,
|
||||
typename Base::SharedStorage& shared_storage_,
|
||||
typename Base::Index m_,
|
||||
typename Base::Index n_)
|
||||
: Base(params_, shared_storage_, m_, n_) {}
|
||||
Coord<3> const& _problem_size)
|
||||
: Base(params_, shared_storage_, _problem_size) {}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -309,9 +308,8 @@ struct IgemmEpilogue<GemmEpilogueTraits_, true> : public GemmEpilogue<GemmEpilog
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const& params_,
|
||||
typename Base::SharedStorage& shared_storage_,
|
||||
typename Base::Index m_,
|
||||
typename Base::Index n_)
|
||||
: Base(params_, shared_storage_, m_, n_) {}
|
||||
Coord<3> const& _problem_size)
|
||||
: Base(params_, shared_storage_, _problem_size) {}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -32,9 +32,9 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/matrix_traits.h>
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/gemm/gemm_global_tile.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -67,10 +67,10 @@ struct IgemmGlobalTileTraits : public GemmGlobalTileTraits<
|
||||
/// The strides in each dimension between different loads/stores.
|
||||
typedef Shape<Base::Threads::kH * 4, 1, Base::Threads::kW, Base::kAccessSize> Delta;
|
||||
/// The number of iterations needed to load/store the tile.
|
||||
typedef Shape<Base::Tile::kH / Base::Threads::kH / 4,
|
||||
typedef Shape<Base::VectorizedTile::kH / Base::Threads::kH / 4,
|
||||
4,
|
||||
Base::Tile::kW / Base::Threads::kW,
|
||||
Base::Tile::kC / Base::kAccessSize>
|
||||
Base::VectorizedTile::kW / Base::Threads::kW,
|
||||
Base::VectorizedTile::kC / Base::kAccessSize>
|
||||
Iterations;
|
||||
|
||||
/// Computes the thread offset in (H, W) based on thread ID
|
||||
@ -86,24 +86,11 @@ struct IgemmGlobalTileTraits : public GemmGlobalTileTraits<
|
||||
|
||||
public:
|
||||
/// The threads strides.
|
||||
typedef Shape<1, 4, Base::Tile::kC> ThreadsDelta;
|
||||
typedef Shape<1, 4, Base::VectorizedTile::kC> ThreadsDelta;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Deprecated. Please use IgemmGlobalTileTraits instead.
|
||||
|
||||
template <GemmOperand::Kind kOperand_,
|
||||
MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename Tile_,
|
||||
typename Threads_,
|
||||
int kAccessSize_>
|
||||
struct IgemmContiguousGlobalTileTraits
|
||||
: public IgemmGlobalTileTraits<kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TileTraits_, typename Index_ = int>
|
||||
struct IgemmGlobalIteratorAb : public GemmGlobalIteratorAb<TileTraits_, Index_> {
|
||||
/// The base class.
|
||||
@ -114,11 +101,11 @@ struct IgemmGlobalIteratorAb : public GemmGlobalIteratorAb<TileTraits_, Index_>
|
||||
/// Constructor.
|
||||
CUTLASS_DEVICE IgemmGlobalIteratorAb(typename Base::Params const& _params,
|
||||
const Coord<3>& bounds,
|
||||
const Coord<3>& block,
|
||||
const Coord<3>& threadblock_offset,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: Base(_params, bounds, block, thread_offset_func), in_residue_(false), mask_(0xffffffff) {
|
||||
: Base(_params, bounds, threadblock_offset, thread_offset_func), mask_(0xffffffff) {
|
||||
// The number of elements read in a single iteration.
|
||||
int const kBlock = TileTraits_::Tile::kW * TileTraits_::kAccessSize;
|
||||
int const kBlock = TileTraits_::Tile::kW;
|
||||
// The residue.
|
||||
int const kResidue = (int)(bounds[1] % kBlock);
|
||||
|
||||
@ -129,28 +116,12 @@ struct IgemmGlobalIteratorAb : public GemmGlobalIteratorAb<TileTraits_, Index_>
|
||||
}
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
Base::get(value, d, h, w, c);
|
||||
if (in_residue_) {
|
||||
reinterpret_cast<uint32_t&>(value) &= mask_;
|
||||
}
|
||||
CUTLASS_DEVICE void load_element(
|
||||
typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
Base::load_element(value, d, h, w, c);
|
||||
reinterpret_cast<uint32_t&>(value) &= mask_;
|
||||
}
|
||||
|
||||
/// Move to residue portion.
|
||||
CUTLASS_DEVICE void move_to_residue(typename Base::Index k) {
|
||||
Base::move_to_residue(k);
|
||||
in_residue_ = true;
|
||||
}
|
||||
|
||||
/// Move back to the beginning of the first tile.
|
||||
CUTLASS_DEVICE void rollback() {
|
||||
Base::rollback();
|
||||
in_residue_ = false;
|
||||
}
|
||||
|
||||
/// Are we in the residue?
|
||||
bool in_residue_;
|
||||
/// The mask to clean up the values.
|
||||
uint32_t mask_;
|
||||
};
|
||||
|
||||
@ -28,9 +28,9 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
#include "cutlass/fragment.h"
|
||||
|
||||
#include <cutlass/gemm/thread_multiply_add.h>
|
||||
#include "cutlass/gemm/thread_multiply_add.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -38,16 +38,18 @@ namespace gemm {
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Template performing matrix multiply-add operation within a thread
|
||||
template <typename AccumulatorsPerThread_, typename ThreadsPerWarp_>
|
||||
struct ThreadMultiplyAdd<AccumulatorsPerThread_, ThreadsPerWarp_, int8_t, int8_t, int> {
|
||||
template <typename ThreadGemmShape_, typename ThreadsPerWarp_>
|
||||
struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int> {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<4, 1, 1> InstructionShape;
|
||||
/// The number of accumulators per thread.
|
||||
typedef AccumulatorsPerThread_ AccumulatorsPerThread;
|
||||
/// Shape of the thread-level GEMM (K-by-N-by-M)
|
||||
typedef ThreadGemmShape_ ThreadGemmShape;
|
||||
/// Aliased for compatibility. Will be removed in CUTLASS v2.0
|
||||
typedef ThreadGemmShape AccumulatorsPerThread;
|
||||
/// The number of threads per warp.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of accumulators per warp.
|
||||
typedef typename ShapeMul<AccumulatorsPerThread, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
|
||||
typedef typename ShapeMul<ThreadGemmShape, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
|
||||
/// The type for A.
|
||||
typedef int8_t ScalarA;
|
||||
/// The fragment for A.
|
||||
|
||||
@ -27,7 +27,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
#include "cutlass/fragment.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -82,6 +82,11 @@ struct IgemmSwizzle {
|
||||
int a2 = src_int[i2];
|
||||
int a3 = src_int[i3];
|
||||
|
||||
// // DEBUG.
|
||||
// if (threadIdx.x == 0) {
|
||||
// printf("a=0x%08x 0x%08x 0x%08x 0x%08x\n", a0, a1, a2, a3);
|
||||
// }
|
||||
|
||||
int b0, b1, b2, b3, c0;
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(b0) : "r"(a0), "r"(a1));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(c0) : "r"(a2), "r"(a3));
|
||||
@ -99,6 +104,11 @@ struct IgemmSwizzle {
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(c0) : "r"(a2), "r"(a3));
|
||||
asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b3) : "r"(b3), "r"(c0));
|
||||
|
||||
// // DEBUG.
|
||||
// if (threadIdx.x == 0) {
|
||||
// printf("b=0x%08x 0x%08x 0x%08x 0x%08x\n", b0, b1, b2, b3);
|
||||
// }
|
||||
|
||||
dst_int[i0] = b0;
|
||||
dst_int[i1] = b1;
|
||||
dst_int[i2] = b2;
|
||||
|
||||
@ -29,18 +29,18 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/gemm_epilogue.h>
|
||||
#include <cutlass/gemm/gemm_epilogue_traits.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/gemm/gemm_shared_tile.h>
|
||||
#include <cutlass/gemm/gemm_traits.h>
|
||||
#include <cutlass/gemm/igemm_epilogue.h>
|
||||
#include <cutlass/gemm/igemm_global_tile.h>
|
||||
#include <cutlass/gemm/igemm_multiply_add.h>
|
||||
#include <cutlass/gemm/igemm_swizzle.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/gemm_epilogue.h"
|
||||
#include "cutlass/gemm/gemm_epilogue_traits.h"
|
||||
#include "cutlass/gemm/gemm_global_tile.h"
|
||||
#include "cutlass/gemm/gemm_shared_tile.h"
|
||||
#include "cutlass/gemm/gemm_traits.h"
|
||||
#include "cutlass/gemm/igemm_epilogue.h"
|
||||
#include "cutlass/gemm/igemm_global_tile.h"
|
||||
#include "cutlass/gemm/igemm_multiply_add.h"
|
||||
#include "cutlass/gemm/igemm_swizzle.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -52,49 +52,52 @@ template <
|
||||
typename OutputTile_,
|
||||
/// The output type.
|
||||
typename ScalarD_,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_>
|
||||
struct IgemmConfig
|
||||
: public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
int8_t,
|
||||
/// The scalar type for B.
|
||||
int8_t,
|
||||
/// The scalar type for C.
|
||||
ScalarD_,
|
||||
/// The scalar type for D.
|
||||
ScalarD_,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, int8_t, int8_t, int>,
|
||||
/// The number of scalars per LDG for A.
|
||||
4,
|
||||
/// The number of scalars per STS for A.
|
||||
4,
|
||||
/// The number of scalars per LDS for A.
|
||||
16,
|
||||
/// The number of scalars per LDG for B.
|
||||
4,
|
||||
/// The number of scalars per STS for B.
|
||||
4,
|
||||
/// The number of scalars per LDS for B.
|
||||
16,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
1,
|
||||
/// The number of scalars per STS for D.
|
||||
4,
|
||||
/// The number of scalars per LDS for D.
|
||||
1,
|
||||
/// The number of stages in shared memory.
|
||||
2,
|
||||
/// Enable the code path that deals with the residue in epilogue.
|
||||
true> {};
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_>
|
||||
struct IgemmConfig : public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
int8_t,
|
||||
/// The scalar type for B.
|
||||
int8_t,
|
||||
/// The scalar type for C.
|
||||
ScalarD_,
|
||||
/// The scalar type for D.
|
||||
ScalarD_,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, int8_t, int8_t, int>,
|
||||
/// The number of scalars per LDG for A.
|
||||
4,
|
||||
/// The number of scalars per STS for A.
|
||||
4,
|
||||
/// The number of scalars per LDS for A.
|
||||
16,
|
||||
/// The number of scalars per LDG for B.
|
||||
4,
|
||||
/// The number of scalars per STS for B.
|
||||
4,
|
||||
/// The number of scalars per LDS for B.
|
||||
16,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
1,
|
||||
/// The number of scalars per STS for D.
|
||||
4,
|
||||
/// The number of scalars per LDS for D.
|
||||
1,
|
||||
/// The number of stages in shared memory.
|
||||
2,
|
||||
/// kResidueSeparate
|
||||
false,
|
||||
/// kResidueInPrologue
|
||||
false,
|
||||
/// kLaunchBounds
|
||||
false> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename OutputTile_, typename AccumulatorsPerThread_>
|
||||
struct IgemmConfig<OutputTile_, int8_t, AccumulatorsPerThread_>
|
||||
template <typename OutputTile_, typename ThreadGemmShape_>
|
||||
struct IgemmConfig<OutputTile_, int8_t, ThreadGemmShape_>
|
||||
: public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
int8_t,
|
||||
@ -107,7 +110,7 @@ struct IgemmConfig<OutputTile_, int8_t, AccumulatorsPerThread_>
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, int8_t, int8_t, int>,
|
||||
ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, int8_t, int8_t, int>,
|
||||
/// The number of scalars per LDG for A.
|
||||
4,
|
||||
/// The number of scalars per STS for A.
|
||||
@ -128,8 +131,12 @@ struct IgemmConfig<OutputTile_, int8_t, AccumulatorsPerThread_>
|
||||
4,
|
||||
/// The number of stages in shared memory.
|
||||
2,
|
||||
/// Enable the code path that deals with the residue in epilogue.
|
||||
true> {};
|
||||
/// If true, separate mainloop is instantiated from residue
|
||||
false,
|
||||
/// Compute residue in prolog?
|
||||
true,
|
||||
/// Launch bounds?
|
||||
false> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -162,7 +169,7 @@ struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_, Index_>
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
// The iterator.
|
||||
/// The global load iterator.
|
||||
typedef GemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^N.
|
||||
@ -208,7 +215,7 @@ struct IgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Index_> {
|
||||
GemmConfig_::kScalarsPerLdgA>
|
||||
GlobalTileTraits;
|
||||
|
||||
// The iterator.
|
||||
/// The global load iterator.
|
||||
typedef IgemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^N.
|
||||
@ -281,7 +288,7 @@ struct IgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Index_> {
|
||||
GemmConfig_::kScalarsPerLdgB>
|
||||
GlobalTileTraits;
|
||||
|
||||
// The iterator.
|
||||
/// The global load iterator.
|
||||
typedef IgemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
@ -345,7 +352,7 @@ struct IgemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_, Index_>
|
||||
GemmConfig_::kScalarsPerLdgB>
|
||||
GlobalTileTraits;
|
||||
|
||||
// The iterator.
|
||||
/// The global load iterator.
|
||||
typedef GemmGlobalIteratorAb<GlobalTileTraits, Index_> GlobalLoadIterator;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
@ -404,13 +411,13 @@ template <
|
||||
typename ScalarD_,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_ = Shape<32, 8, 8>,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_ = Shape<32, 8, 8>,
|
||||
/// The index.
|
||||
typename Index_ = int>
|
||||
struct IgemmTraitsHelper {
|
||||
/// The IGEMM config.
|
||||
typedef IgemmConfig<OutputTile_, ScalarD_, AccumulatorsPerThread_> GemmConfig;
|
||||
typedef IgemmConfig<OutputTile_, ScalarD_, ThreadGemmShape_> GemmConfig;
|
||||
/// The GEMM config for A.
|
||||
typedef IgemmTileTraitsHelperA<kLayoutA_, GemmConfig, Index_> GemmTileTraitsHelperA;
|
||||
/// The GEMM config for B.
|
||||
@ -418,7 +425,6 @@ struct IgemmTraitsHelper {
|
||||
|
||||
/// The iterator to load A from global memory.
|
||||
typedef typename GemmTileTraitsHelperA::GlobalLoadIterator GlobalLoadIteratorA;
|
||||
|
||||
/// The default transformer for A.
|
||||
typedef typename IgemmTransformerA<GemmTileTraitsHelperA::kLayout,
|
||||
GlobalLoadIteratorA>::Transformer GlobalTransformerA;
|
||||
@ -429,12 +435,14 @@ struct IgemmTraitsHelper {
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorA;
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA>
|
||||
typedef GlobalLoadStream<GemmOperand::kA,
|
||||
GlobalLoadIteratorA,
|
||||
SharedStoreIteratorA,
|
||||
GlobalTransformerA>
|
||||
GlobalLoadStreamA;
|
||||
|
||||
/// The iterator to load B from global memory.
|
||||
typedef typename GemmTileTraitsHelperB::GlobalLoadIterator GlobalLoadIteratorB;
|
||||
|
||||
// The default transformer for B.
|
||||
typedef typename IgemmTransformerB<GemmTileTraitsHelperB::kLayout,
|
||||
GlobalLoadIteratorB>::Transformer GlobalTransformerB;
|
||||
@ -445,7 +453,10 @@ struct IgemmTraitsHelper {
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorB;
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB>
|
||||
typedef GlobalLoadStream<GemmOperand::kB,
|
||||
GlobalLoadIteratorB,
|
||||
SharedStoreIteratorB,
|
||||
GlobalTransformerB>
|
||||
GlobalLoadStreamB;
|
||||
|
||||
/// The iterator to load A from shared memory.
|
||||
@ -501,8 +512,8 @@ template <
|
||||
typename ScalarD_ = int,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<typename IgemmEpilogueScalar<ScalarD_>::Scalar>,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_ = Shape<32, 8, 8>,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_ = Shape<32, 8, 8>,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The helper class.
|
||||
@ -511,7 +522,7 @@ template <
|
||||
OutputTile_,
|
||||
ScalarD_,
|
||||
EpilogueFunctor_,
|
||||
AccumulatorsPerThread_,
|
||||
ThreadGemmShape_,
|
||||
Index_> >
|
||||
struct IgemmTraits : public GemmTraits<
|
||||
// The config.
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
@ -27,18 +28,31 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment_multiply_add.h>
|
||||
#include "cutlass/fragment_multiply_add.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
CUTLASS_DEVICE bool is_zero(T x) {
|
||||
return x == T(0);
|
||||
}
|
||||
|
||||
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
|
||||
CUTLASS_DEVICE bool is_zero(half x) { return reinterpret_cast<int16_t&>(x) == int16_t(0); }
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Functor to compute linear combination of fragments
|
||||
template <typename Scalar_, typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_> >
|
||||
template <typename Scalar_, typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_> >
|
||||
struct LinearScaling {
|
||||
// The scalar.
|
||||
typedef Scalar_ Scalar;
|
||||
// The accumulator Type
|
||||
typedef typename FragmentMultiplyAdd_::ScalarAccum ScalarAccum;
|
||||
// The adapater.
|
||||
typedef FragmentMultiplyAdd_ FragmentMultiplyAdd;
|
||||
|
||||
@ -47,6 +61,21 @@ struct LinearScaling {
|
||||
/// The alpha/beta scaling params.
|
||||
Scalar alpha, beta;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Scalar _alpha = 0, Scalar _beta = 0) : alpha(_alpha), beta(_beta) {}
|
||||
|
||||
/// Initialize the parameters
|
||||
CUTLASS_HOST_DEVICE int initialize(Scalar _alpha, Scalar _beta) {
|
||||
alpha = _alpha;
|
||||
beta = _beta;
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initialize the parameters.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
|
||||
@ -56,14 +85,53 @@ struct LinearScaling {
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
Params params;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE LinearScaling(Params const& params) : alpha(params.alpha), beta(params.beta) {}
|
||||
CUTLASS_DEVICE LinearScaling() { }
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE LinearScaling(Params const& _params) : params(_params) {}
|
||||
|
||||
/// Method to determine whether the source accumulator matrix C is ever needed. This method
|
||||
/// may always safely return true, though better performance is possible if the source accumulator
|
||||
/// matrix is never loaded unnecessarily.
|
||||
CUTLASS_DEVICE
|
||||
bool source_required() const {
|
||||
return !is_zero(params.beta);
|
||||
}
|
||||
|
||||
/// Evaluate the functor.
|
||||
template <typename FragmentA_, typename FragmentB_>
|
||||
CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_& output) {
|
||||
FragmentMultiplyAdd mad;
|
||||
mad.multiply(alpha, accum, output);
|
||||
mad.multiply(params.alpha, accum, output);
|
||||
|
||||
}
|
||||
|
||||
/// Evaluate the functor, without using fragment in the API
|
||||
template <typename ScalarAccum, typename ScalarOutput, int size>
|
||||
CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput *output) {
|
||||
Fragment<ScalarAccum, size> FragAccum;
|
||||
Fragment<ScalarOutput, size> FragOutput;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size; i++) {
|
||||
FragAccum[i] = accum[i];
|
||||
FragOutput[i] = output[i];
|
||||
}
|
||||
evaluate(FragAccum, FragOutput);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size; i++) {
|
||||
output[i] = FragOutput[i];
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate the functor.
|
||||
@ -71,12 +139,28 @@ struct LinearScaling {
|
||||
CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_ const& old, FragmentB_& output) {
|
||||
FragmentMultiplyAdd mad;
|
||||
FragmentB_ tmp;
|
||||
mad.multiply(beta, old, tmp);
|
||||
mad.multiply_add(alpha, accum, tmp, output);
|
||||
mad.multiply(params.beta, old, tmp);
|
||||
mad.multiply_add(params.alpha, accum, tmp, output);
|
||||
}
|
||||
|
||||
/// The alpha/beta scaling factors.
|
||||
Scalar alpha, beta;
|
||||
/// Evaluate the functor, without using fragment in the API
|
||||
template <typename ScalarAccum, typename ScalarOutput, int size>
|
||||
CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput const *old, ScalarOutput *output) {
|
||||
Fragment<ScalarAccum, size> FragAccum;
|
||||
Fragment<ScalarOutput, size> FragOutput;
|
||||
Fragment<ScalarOutput, size> FragOld;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size; i++) {
|
||||
FragAccum[i] = accum[i];
|
||||
FragOutput[i] = output[i];
|
||||
FragOld[i] = old[i];
|
||||
}
|
||||
evaluate(FragAccum, FragOld, FragOutput);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size; i++) {
|
||||
output[i] = FragOutput[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
149
cutlass/gemm/linear_scaling_device_ptr.h
Normal file
@ -0,0 +1,149 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements the BLAS linear scaling function alpha*AB + beta*C
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/scalar_or_pointer.h"
|
||||
#include "cutlass/gemm/linear_scaling.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Functor to compute linear combination of fragments. This is intended to support passing scalars
|
||||
/// either by value from the host or by reference to device-side scalar elements. This is inspired
|
||||
/// by cuBLAS's device pointer mode.
|
||||
template <typename Scalar_, typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_> >
|
||||
struct LinearScalingDevicePtr : public LinearScaling<Scalar_, FragmentMultiplyAdd_> {
|
||||
|
||||
/// Linear Scaling class used
|
||||
typedef LinearScaling<Scalar_, FragmentMultiplyAdd_> Base;
|
||||
|
||||
// The scalar.
|
||||
typedef typename Base::Scalar Scalar;
|
||||
|
||||
/// The parameters.
|
||||
class Params {
|
||||
private:
|
||||
/// Alpha scalar
|
||||
detail::ScalarOrPointer<Scalar> alpha_;
|
||||
|
||||
/// Beta sclaar
|
||||
detail::ScalarOrPointer<Scalar> beta_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Scalar alpha,
|
||||
Scalar beta
|
||||
):
|
||||
alpha_(alpha),
|
||||
beta_(beta) {}
|
||||
|
||||
// Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(
|
||||
Scalar const *alpha_ptr,
|
||||
Scalar const *beta_ptr
|
||||
):
|
||||
alpha_(alpha_ptr),
|
||||
beta_(alpha_ptr) {}
|
||||
|
||||
/// Initialize the parameters
|
||||
CUTLASS_HOST_DEVICE int initialize(
|
||||
Scalar alpha,
|
||||
Scalar beta) {
|
||||
|
||||
alpha_ = alpha;
|
||||
beta_ = beta;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initialize the parameters
|
||||
CUTLASS_HOST_DEVICE int initialize(
|
||||
Scalar const *alpha,
|
||||
Scalar const *beta) {
|
||||
|
||||
alpha_ = alpha;
|
||||
beta_= beta;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initialize the parameters.
|
||||
template <typename GemmDesc_>
|
||||
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
|
||||
|
||||
alpha_ = desc.alpha;
|
||||
beta_ = desc.beta;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Gets the alpha scalar
|
||||
CUTLASS_HOST_DEVICE
|
||||
Scalar alpha() const {
|
||||
return alpha_;
|
||||
}
|
||||
|
||||
/// Gets the beta scalar
|
||||
CUTLASS_HOST_DEVICE
|
||||
Scalar beta() const {
|
||||
return beta_;
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_HOST_DEVICE LinearScalingDevicePtr(Params const& _params) {
|
||||
this->params.alpha = _params.alpha();
|
||||
this->params.beta = _params.beta();
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
129
cutlass/gemm/scalar_or_pointer.h
Normal file
@ -0,0 +1,129 @@
|
||||
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements the BLAS linear scaling function alpha*AB + beta*C
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Helper class defines an object which operates as either a scalar or a pointer. If the pointer
|
||||
/// is non-null, it is dereferenced when the object is accessed.
|
||||
template <typename Scalar_>
|
||||
class ScalarOrPointer {
|
||||
public:
|
||||
/// Underlying scalar type
|
||||
typedef Scalar_ Scalar;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Scalar value
|
||||
Scalar scalar;
|
||||
|
||||
/// Pointer to use if non null
|
||||
Scalar const *ptr;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
ScalarOrPointer(): scalar(0), ptr(nullptr) {}
|
||||
|
||||
/// Object behaves as a scalar
|
||||
CUTLASS_HOST_DEVICE
|
||||
ScalarOrPointer(Scalar const &val): scalar(val), ptr(nullptr) {}
|
||||
|
||||
/// Object behaves as a scalar
|
||||
CUTLASS_HOST_DEVICE
|
||||
ScalarOrPointer(Scalar const *ptr_): scalar(0), ptr(ptr_) {}
|
||||
|
||||
/// Returns true if is pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool is_pointer() const {
|
||||
return bool(ptr);
|
||||
}
|
||||
|
||||
/// Gets the pointer value
|
||||
CUTLASS_HOST_DEVICE
|
||||
Scalar const *get_ptr() const {
|
||||
return ptr;
|
||||
}
|
||||
|
||||
/// Gets the pointer value
|
||||
CUTLASS_HOST_DEVICE
|
||||
Scalar get_scalar() const {
|
||||
return scalar;
|
||||
}
|
||||
|
||||
/// Assigns to a scalar and sets pointer to nullptr
|
||||
CUTLASS_HOST_DEVICE
|
||||
ScalarOrPointer &operator=(Scalar const &scalar_) {
|
||||
scalar = scalar_;
|
||||
ptr = nullptr;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Assigns to a pointer value
|
||||
CUTLASS_HOST_DEVICE
|
||||
ScalarOrPointer &operator=(Scalar const *ptr_) {
|
||||
ptr = ptr_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Access the element
|
||||
CUTLASS_HOST_DEVICE
|
||||
Scalar get() const {
|
||||
if (ptr) {
|
||||
return *ptr;
|
||||
}
|
||||
return scalar;
|
||||
}
|
||||
|
||||
/// Accesses the element
|
||||
CUTLASS_HOST_DEVICE
|
||||
operator Scalar() const {
|
||||
return get();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
@ -27,13 +27,13 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/gemm_epilogue.h>
|
||||
#include <cutlass/gemm/gemm_epilogue_traits.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/gemm/gemm_shared_tile.h>
|
||||
#include <cutlass/gemm/gemm_traits.h>
|
||||
#include <cutlass/gemm/thread_multiply_add.h>
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/gemm_epilogue.h"
|
||||
#include "cutlass/gemm/gemm_epilogue_traits.h"
|
||||
#include "cutlass/gemm/gemm_global_tile.h"
|
||||
#include "cutlass/gemm/gemm_shared_tile.h"
|
||||
#include "cutlass/gemm/gemm_traits.h"
|
||||
#include "cutlass/gemm/thread_multiply_add.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -43,46 +43,53 @@ namespace gemm {
|
||||
template <
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_,
|
||||
/// The number of scalars per LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of scalars per LDG for B.
|
||||
int kScalarsPerLdgB_ = 1>
|
||||
struct SgemmConfig
|
||||
: public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
float,
|
||||
/// The scalar type for B.
|
||||
float,
|
||||
/// The scalar type for C.
|
||||
float,
|
||||
/// The scalar type for D.
|
||||
float,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, float, float, float>,
|
||||
/// The number of scalars per LDG for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDS for A.
|
||||
4,
|
||||
/// The number of scalars per LDG for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
4,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
1,
|
||||
/// The number of scalars per STS for D.
|
||||
4,
|
||||
/// The number of scalars per LDS for D.
|
||||
1,
|
||||
/// The number of stages in shared memory.
|
||||
2> {};
|
||||
int kScalarsPerLdgB_ = 1,
|
||||
/// Whether to specify launch bounds
|
||||
bool kLaunchBounds = true>
|
||||
struct SgemmConfig : public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
float,
|
||||
/// The scalar type for B.
|
||||
float,
|
||||
/// The scalar type for C.
|
||||
float,
|
||||
/// The scalar type for D.
|
||||
float,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, float, float, float>,
|
||||
/// The number of scalars per LDG for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per STS for A.
|
||||
kScalarsPerLdgA_,
|
||||
/// The number of scalars per LDS for A.
|
||||
4,
|
||||
/// The number of scalars per LDG for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per STS for B.
|
||||
kScalarsPerLdgB_,
|
||||
/// The number of scalars per LDS for B.
|
||||
4,
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
1,
|
||||
/// The number of scalars per STS for D.
|
||||
4,
|
||||
/// The number of scalars per LDS for D.
|
||||
1,
|
||||
/// The number of stages in shared memory.
|
||||
2,
|
||||
/// kResidueSeparate
|
||||
false,
|
||||
/// kResidueInPrologue
|
||||
true,
|
||||
/// kLaunchBounds
|
||||
kLaunchBounds> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -95,8 +102,8 @@ template <
|
||||
typename OutputTile_ = Shape<8, 128, 128>,
|
||||
/// The functor to use in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<float>,
|
||||
/// The number of accumulators per thread.
|
||||
typename AccumulatorsPerThread_ = Shape<8, 8, 8>,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_ = Shape<8, 8, 8>,
|
||||
/// The number of floats loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of floats loaded in one LDG for B.
|
||||
@ -105,7 +112,7 @@ template <
|
||||
typename Index_ = int,
|
||||
/// The SGEMM config.
|
||||
typename GemmConfig_ =
|
||||
SgemmConfig<OutputTile_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_>,
|
||||
SgemmConfig<OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_, false>,
|
||||
/// The traits class for the epilogue.
|
||||
typename GemmEpilogueTraits_ =
|
||||
SimplifiedGemmEpilogueTraits<GemmConfig_, EpilogueFunctor_, Index_> >
|
||||
@ -123,5 +130,43 @@ struct SgemmTraits : public SimplifiedGemmTraits<
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to define SGEMM traits using Launch Bounds
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
/// The layout for B.
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
typename OutputTile_ = Shape<8, 128, 128>,
|
||||
/// The functor to use in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<float>,
|
||||
/// Tile size for thread-level GEMM (K-by-N-by-M)
|
||||
typename ThreadGemmShape_ = Shape<8, 8, 8>,
|
||||
/// The number of floats loaded in one LDG for A.
|
||||
int kScalarsPerLdgA_ = 1,
|
||||
/// The number of floats loaded in one LDG for B.
|
||||
int kScalarsPerLdgB_ = 1,
|
||||
/// The index.
|
||||
typename Index_ = int,
|
||||
/// The SGEMM config.
|
||||
typename GemmConfig_ =
|
||||
SgemmConfig<OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_, true>,
|
||||
/// The traits class for the epilogue.
|
||||
typename GemmEpilogueTraits_ =
|
||||
SimplifiedGemmEpilogueTraits<GemmConfig_, EpilogueFunctor_, Index_> >
|
||||
struct SgemmLBTraits : public SimplifiedGemmTraits<
|
||||
// The layout for A.
|
||||
kLayoutA_,
|
||||
// The layout for B.
|
||||
kLayoutB_,
|
||||
// The config.
|
||||
GemmConfig_,
|
||||
// The epilogue.
|
||||
GemmEpilogue<GemmEpilogueTraits_>,
|
||||
// The index.
|
||||
Index_> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
@ -27,7 +27,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
#include "cutlass/fragment.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -35,20 +35,23 @@ namespace gemm {
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Template performing matrix multiply-add operation within a thread
|
||||
template <typename AccumulatorsPerThread_,
|
||||
template <typename ThreadGemmShape_,
|
||||
typename ThreadsPerWarp_,
|
||||
typename ScalarA_,
|
||||
typename ScalarB_,
|
||||
typename ScalarC_>
|
||||
typename ScalarC_,
|
||||
MatrixLayout::Kind kLayout_ = MatrixLayout::kColumnMajor>
|
||||
struct ThreadMultiplyAdd {
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<1, 1, 1, 1> InstructionShape;
|
||||
/// The number of accumulators per thread.
|
||||
typedef AccumulatorsPerThread_ AccumulatorsPerThread;
|
||||
/// The shape of a thread-leveel matrix multiply accumulate.
|
||||
typedef ThreadGemmShape_ ThreadGemmShape;
|
||||
/// Aliased to "AccumulatorsPerThread" for compatibility. Expect to be renamed in CUTLASS v2.0
|
||||
typedef ThreadGemmShape AccumulatorsPerThread;
|
||||
/// The number of threads per warp.
|
||||
typedef ThreadsPerWarp_ ThreadsPerWarp;
|
||||
/// The number of accumulators per warp.
|
||||
typedef typename ShapeMul<AccumulatorsPerThread, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
|
||||
typedef typename ShapeMul<ThreadGemmShape, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
|
||||
/// The type for A.
|
||||
typedef ScalarA_ ScalarA;
|
||||
/// The fragment for A.
|
||||
@ -70,9 +73,18 @@ struct ThreadMultiplyAdd {
|
||||
FragmentB const& b,
|
||||
Accumulators const& c,
|
||||
Accumulators& d) {
|
||||
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
|
||||
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
|
||||
d[j * AccumulatorsPerThread::kW + i] = a[i] * b[j] + c[j * AccumulatorsPerThread::kW + i];
|
||||
if(kLayout_ == MatrixLayout::kColumnMajor) {
|
||||
for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
|
||||
for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
|
||||
d[j * AccumulatorsPerThread::kW + i] = a[i] * b[j] + c[j * AccumulatorsPerThread::kW + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
for(int i = 0; i < AccumulatorsPerThread::kW; ++i) {
|
||||
for(int j = 0; j < AccumulatorsPerThread::kH; ++j) {
|
||||
d[i * AccumulatorsPerThread::kH + j] = a[i] * b[j] + c[i * AccumulatorsPerThread::kH + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
387
cutlass/gemm/threadblock_swizzle.h
Normal file
@ -0,0 +1,387 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defies functors for mapping blockIdx to partitions of the GEMM computation.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/gemm/gemm_coord.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
struct swizzleDirection {
|
||||
enum Kind { Boustrophedon, OneDirection };
|
||||
};
|
||||
// helper template function
|
||||
template <enum swizzleDirection::Kind>
|
||||
CUTLASS_DEVICE int getLinearIdx(int groups) {
|
||||
// groupCols is not needed for OneDirection Swizzle
|
||||
return blockIdx.y * gridDim.x + blockIdx.x;
|
||||
}
|
||||
template <>
|
||||
CUTLASS_DEVICE int getLinearIdx<swizzleDirection::Boustrophedon>(int groups) {
|
||||
// reverse blockIdx.x for some columns
|
||||
if ((blockIdx.y / groups) % 2 == 1)
|
||||
return blockIdx.y * gridDim.x + (gridDim.x - blockIdx.x - 1);
|
||||
else
|
||||
return blockIdx.y * gridDim.x + blockIdx.x;
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*!@defgroup IdentityBlockSwizzle Identity Block Swizzle
|
||||
@{
|
||||
Block Swizzle provides the mapping logic between a block in the physical memory of Matrix C and
|
||||
Thread Block
|
||||
Identiy Block Swizzle effective maps blocks in leading dimension order (column major) with
|
||||
thread block
|
||||
in leading dimension order (blockIdx.x)
|
||||
blockIdx.z is mapped with batch_count for batched GEMM
|
||||
@}
|
||||
*/
|
||||
struct IdentityBlockSwizzle {
|
||||
/// Ctor. aka ColumnMajorBlockSwizzle<1>
|
||||
CUTLASS_HOST_DEVICE IdentityBlockSwizzle() {}
|
||||
|
||||
/// Swizzle the block index.
|
||||
CUTLASS_DEVICE dim3 swizzle() { return blockIdx; }
|
||||
|
||||
///
|
||||
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size,
|
||||
Coord<3> const &OutputTile) {
|
||||
/*OutputTile and problem_size are both in KNM order*/
|
||||
dim3 grid;
|
||||
grid.x = (problem_size.m() + OutputTile[2] - 1) / OutputTile[2];
|
||||
grid.y = (problem_size.n() + OutputTile[1] - 1) / OutputTile[1];
|
||||
grid.z = problem_size.batch();
|
||||
return grid;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) {
|
||||
dim3 block = swizzle();
|
||||
Coord<3> threadblock_offset =
|
||||
make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
|
||||
return threadblock_offset;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE int get_batch_id() {
|
||||
dim3 block = swizzle();
|
||||
return block.z;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
|
||||
ColumnMajorBlockSwizzle<1, OneDirection> is equivalent with IdentityBlockSwizzle
|
||||
groupCols has the effect of controlling the schedulling of thread blocks
|
||||
settings with different groupCols can contribute to the overall performance by affecting L2 cache
|
||||
hit rate
|
||||
|
||||
consider a regular thread block mapping btween matrix C and different thread blocks
|
||||
note that C is column major, and the leading dimension of thread block id is blockIdx.x
|
||||
|
||||
let's look at an example where gridIdx.x = 6, gridIdx.y = 7, gridIdx.z = 1
|
||||
(blockIdx.x, blockIdx.y)
|
||||
mapping between threadblockID and C matrix:
|
||||
-------------------------------------------------------
|
||||
(0,0) | (0,1) | (0,2) | (0,3) | (0,4) | (0,5) | (0,6) |
|
||||
-------------------------------------------------------
|
||||
(1,0) | (1,1) | (1,2) | (1,3) | (1,4) | (1,5) | (1,6) |
|
||||
-------------------------------------------------------
|
||||
(2,0) | (2,1) | (2,2) | (2,3) | (2,4) | (2,5) | (2,6) |
|
||||
-------------------------------------------------------
|
||||
(3,0) | (3,1) | (3,2) | (3,3) | (3,4) | (3,5) | (3,6) |
|
||||
-------------------------------------------------------
|
||||
(4,0) | (4,1) | (4,2) | (4,3) | (4,4) | (4,5) | (4,6) |
|
||||
-------------------------------------------------------
|
||||
(5,0) | (5,1) | (5,2) | (5,3) | (5,4) | (5,5) | (5,6) |
|
||||
-------------------------------------------------------
|
||||
|
||||
A ColumnMajorBlockSwizzle<1, OneDirection> will imply the above order where threadblocks are
|
||||
launched in a column major
|
||||
|
||||
A ColumnMajorBlockSwizzle<2, OneDirection> swizzles things a little,
|
||||
-------------------------------------------------------
|
||||
(0,0) | (3,0) | (0,2) | (3,2) | (0,4) | (3,4) | (0,6) |
|
||||
-------------------------------------------------------
|
||||
(0,1) | (3,1) | (0,3) | (3,3) | (0,5) | (3,5) | (1,6) |
|
||||
-------------------------------------------------------
|
||||
(1,0) | (4,0) | (1,2) | (4,2) | (1,4) | (4,4) | (2,6) |
|
||||
-------------------------------------------------------
|
||||
(1,1) | (4,1) | (1,3) | (4,3) | (1,5) | (4,5) | (3,6) |
|
||||
-------------------------------------------------------
|
||||
(2,0) | (5,0) | (2,2) | (5,2) | (2,4) | (5,4) | (4,6) |
|
||||
-------------------------------------------------------
|
||||
(2,1) | (5,1) | (2,3) | (5,3) | (2,5) | (5,5) | (5,6) |
|
||||
-------------------------------------------------------
|
||||
|
||||
so in memory, it would apprear that we work on 2 columns at a time rather than 1
|
||||
Note that the index here really represent how each block maps to memory
|
||||
|
||||
A ColumnMajorBlockSwizzle<1, Boustrophedon> is similar to ColumnMajorBlockSwizzle<1, OneDirection>
|
||||
except that every column flips the ordering against the previous one
|
||||
-------------------------------------------------------
|
||||
(0,0) | (5,1) | (0,2) | (5,3) | (0,4) | (5,5) | (0,6) |
|
||||
-------------------------------------------------------
|
||||
(1,0) | (4,1) | (1,2) | (4,3) | (1,4) | (4,5) | (1,6) |
|
||||
-------------------------------------------------------
|
||||
(2,0) | (3,1) | (2,2) | (3,3) | (2,4) | (3,5) | (2,6) |
|
||||
-------------------------------------------------------
|
||||
(3,0) | (2,1) | (3,2) | (2,3) | (3,4) | (2,5) | (3,6) |
|
||||
-------------------------------------------------------
|
||||
(4,0) | (1,1) | (4,2) | (1,3) | (4,4) | (1,5) | (4,6) |
|
||||
-------------------------------------------------------
|
||||
(5,0) | (0,1) | (5,2) | (0,3) | (5,4) | (0,5) | (5,6) |
|
||||
-------------------------------------------------------
|
||||
|
||||
similarily, A ColumnMajorBlockSwizzle<2, Boustrophedon> looks like
|
||||
-------------------------------------------------------
|
||||
(0,0) | (3,0) | (2,3) | (5,3) | (0,4) | (3,4) | (5,6) |
|
||||
-------------------------------------------------------
|
||||
(0,1) | (3,1) | (2,2) | (5,2) | (0,5) | (3,5) | (4,6) |
|
||||
-------------------------------------------------------
|
||||
(1,0) | (4,0) | (1,3) | (4,3) | (1,4) | (4,4) | (3,6) |
|
||||
-------------------------------------------------------
|
||||
(1,1) | (4,1) | (1,2) | (4,2) | (1,5) | (4,5) | (2,6) |
|
||||
-------------------------------------------------------
|
||||
(2,0) | (5,0) | (0,3) | (3,3) | (2,4) | (5,4) | (1,6) |
|
||||
-------------------------------------------------------
|
||||
(2,1) | (5,1) | (0,2) | (3,2) | (2,5) | (5,5) | (0,6) |
|
||||
-------------------------------------------------------
|
||||
|
||||
*/
|
||||
|
||||
template <int groupCols, enum swizzleDirection::Kind swDirection>
|
||||
struct ColumnMajorBlockSwizzle {
|
||||
/// Ctor.
|
||||
CUTLASS_HOST_DEVICE ColumnMajorBlockSwizzle() {}
|
||||
|
||||
/// Swizzle the block index.
|
||||
CUTLASS_DEVICE dim3 swizzle() {
|
||||
assert(gridDim.z == 1);
|
||||
int linearIdx = getLinearIdx<swDirection>(groupCols);
|
||||
dim3 swizzledBlockIdx;
|
||||
int currGroupCols = groupCols;
|
||||
int prevGroupCols = groupCols;
|
||||
|
||||
if ((gridDim.y % groupCols != 0) && ((blockIdx.y + (gridDim.y % groupCols)) >= gridDim.y)) {
|
||||
// last colmuns if gridDim.y is not divisble by groupCols
|
||||
currGroupCols = gridDim.y % groupCols;
|
||||
}
|
||||
|
||||
swizzledBlockIdx.x = (linearIdx / currGroupCols) % gridDim.x;
|
||||
swizzledBlockIdx.y =
|
||||
linearIdx % currGroupCols + prevGroupCols * (linearIdx / (prevGroupCols * gridDim.x));
|
||||
swizzledBlockIdx.z = blockIdx.z;
|
||||
|
||||
return swizzledBlockIdx;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size,
|
||||
Coord<3> const &OutputTile) {
|
||||
dim3 grid;
|
||||
grid.x = (problem_size.m() + OutputTile[2] - 1) / OutputTile[2];
|
||||
grid.y = (problem_size.n() + OutputTile[1] - 1) / OutputTile[1];
|
||||
grid.z = problem_size.batch();
|
||||
return grid;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) {
|
||||
dim3 block = swizzle();
|
||||
Coord<3> threadblock_offset =
|
||||
make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
|
||||
return threadblock_offset;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE int get_batch_id() {
|
||||
dim3 block = swizzle();
|
||||
return block.z;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
|
||||
|
||||
consider a regular thread block mapping btween matrix C and different thread blocks
|
||||
note that C is column major, and the leading dimension of thread block id is blockIdx.x
|
||||
|
||||
let's look at an example where gridIdx.x = 6, gridIdx.y = 7, gridIdx.z = 1
|
||||
(blockIdx.x, blockIdx.y)
|
||||
mapping between threadblockID and C matrix:
|
||||
-------------------------------------------------------
|
||||
(0,0) | (0,1) | (0,2) | (0,3) | (0,4) | (0,5) | (0,6) |
|
||||
-------------------------------------------------------
|
||||
(1,0) | (1,1) | (1,2) | (1,3) | (1,4) | (1,5) | (1,6) |
|
||||
-------------------------------------------------------
|
||||
(2,0) | (2,1) | (2,2) | (2,3) | (2,4) | (2,5) | (2,6) |
|
||||
-------------------------------------------------------
|
||||
(3,0) | (3,1) | (3,2) | (3,3) | (3,4) | (3,5) | (3,6) |
|
||||
-------------------------------------------------------
|
||||
(4,0) | (4,1) | (4,2) | (4,3) | (4,4) | (4,5) | (4,6) |
|
||||
-------------------------------------------------------
|
||||
(5,0) | (5,1) | (5,2) | (5,3) | (5,4) | (5,5) | (5,6) |
|
||||
-------------------------------------------------------
|
||||
|
||||
A RowMajorBlockSwizzle<1, OneDirection> will effectively transpose the map
|
||||
|
||||
-----------------------------------------------
|
||||
(0,0) | (1,0) | (2,0) | (3,0) | (4,0) | (5,0) |
|
||||
-----------------------------------------------
|
||||
(0,1) | (1,1) | (2,1) | (3,1) | (4,1) | (5,1) |
|
||||
-----------------------------------------------
|
||||
(0,2) | (1,2) | (2,2) | (3,2) | (4,2) | (5,2) |
|
||||
-----------------------------------------------
|
||||
(0,3) | (1,3) | (2,3) | (3,3) | (4,3) | (5,3) |
|
||||
-----------------------------------------------
|
||||
(0,4) | (1,4) | (2,4) | (3,4) | (4,4) | (5,4) |
|
||||
---------------------------------------------
|
||||
(0,5) | (1,5) | (2,5) | (3,5) | (4,5) | (5,5) |
|
||||
-----------------------------------------------
|
||||
(0,6) | (1,6) | (2,6) | (3,6) | (4,6) | (5,6) |
|
||||
-----------------------------------------------
|
||||
|
||||
It would aprear in memory we are working on 1 row at a time
|
||||
|
||||
A ColumnMajorBlockSwizzle<2, OneDirection> swizzles things a little bit more
|
||||
-----------------------------------------------
|
||||
(0,0) | (1,3) | (2,0) | (3,3) | (4,0) | (5,3) |
|
||||
-----------------------------------------------
|
||||
(1,0) | (0,4) | (3,0) | (2,4) | (5,0) | (4,4) |
|
||||
-----------------------------------------------
|
||||
(0,1) | (1,4) | (2,1) | (3,4) | (4,1) | (5,4) |
|
||||
-----------------------------------------------
|
||||
(1,1) | (0,5) | (3,1) | (2,5) | (5,1) | (4,5) |
|
||||
-----------------------------------------------
|
||||
(0,2) | (1,5) | (2,2) | (3,5) | (4,2) | (5,5) |
|
||||
---------------------------------------------
|
||||
(1,2) | (0,6) | (3,2) | (2,6) | (5,2) | (4,6) |
|
||||
-----------------------------------------------
|
||||
(0,3) | (1,6) | (2,3) | (3,6) | (4,3) | (5,6) |
|
||||
-----------------------------------------------
|
||||
|
||||
so in memory, it would apprear that we work on 2 rows at a time rather than 1 row
|
||||
Note that the index here really represent how each block maps to memory
|
||||
|
||||
A RowMajorBlockSwizzle<1, Boustrophedon> is similar to RowMajorBlockSwizzle<1, OneDirection>
|
||||
except that every column flips the ordering against the previous one
|
||||
|
||||
-----------------------------------------------
|
||||
(0,0) | (1,6) | (2,0) | (3,6) | (4,0) | (5,6) |
|
||||
-----------------------------------------------
|
||||
(0,1) | (1,5) | (2,1) | (3,5) | (4,1) | (5,5) |
|
||||
-----------------------------------------------
|
||||
(0,2) | (1,4) | (2,2) | (3,4) | (4,2) | (5,4) |
|
||||
-----------------------------------------------
|
||||
(0,3) | (1,3) | (2,3) | (3,3) | (4,3) | (5,3) |
|
||||
-----------------------------------------------
|
||||
(0,4) | (1,2) | (2,4) | (3,2) | (4,4) | (5,2) |
|
||||
---------------------------------------------
|
||||
(0,5) | (1,1) | (2,5) | (3,1) | (4,5) | (5,1) |
|
||||
-----------------------------------------------
|
||||
(0,6) | (1,0) | (2,6) | (3,0) | (4,6) | (5,0) |
|
||||
-----------------------------------------------
|
||||
|
||||
similarily, A RowMajorBlockSwizzle<2, Boustrophedon> looks like
|
||||
-----------------------------------------------
|
||||
(0,0) | (1,3) | (2,3) | (3,6) | (4,0) | (5,3) |
|
||||
-----------------------------------------------
|
||||
(1,0) | (0,4) | (3,2) | (2,6) | (5,0) | (4,4) |
|
||||
-----------------------------------------------
|
||||
(0,1) | (1,4) | (2,2) | (3,5) | (4,1) | (5,4) |
|
||||
-----------------------------------------------
|
||||
(1,1) | (0,5) | (3,1) | (2,5) | (5,1) | (4,5) |
|
||||
-----------------------------------------------
|
||||
(0,2) | (1,5) | (2,1) | (3,4) | (4,2) | (5,5) |
|
||||
---------------------------------------------
|
||||
(1,2) | (0,6) | (3,0) | (2,4) | (5,2) | (4,6) |
|
||||
-----------------------------------------------
|
||||
(0,3) | (1,6) | (2,0) | (3,3) | (4,3) | (5,6) |
|
||||
-----------------------------------------------
|
||||
|
||||
*/
|
||||
|
||||
template <int groupRows, enum swizzleDirection::Kind swDirection>
|
||||
struct RowMajorBlockSwizzle {
|
||||
/// Ctor.
|
||||
CUTLASS_HOST_DEVICE RowMajorBlockSwizzle() {}
|
||||
|
||||
/// Swizzle the block index.
|
||||
CUTLASS_DEVICE dim3 swizzle() {
|
||||
assert(gridDim.z == 1);
|
||||
int linearIdx = getLinearIdx<swDirection>(groupRows);
|
||||
dim3 swizzledBlockIdx;
|
||||
int currGroupRows = groupRows;
|
||||
int prevGroupRows = groupRows;
|
||||
|
||||
if ((gridDim.y % groupRows != 0) && ((blockIdx.y + (gridDim.y % groupRows)) >= gridDim.y)) {
|
||||
// last columns
|
||||
currGroupRows = gridDim.y % groupRows;
|
||||
}
|
||||
|
||||
swizzledBlockIdx.x =
|
||||
linearIdx % currGroupRows + prevGroupRows * (linearIdx / (prevGroupRows * gridDim.x));
|
||||
swizzledBlockIdx.y = (linearIdx / currGroupRows) % gridDim.x;
|
||||
swizzledBlockIdx.z = blockIdx.z;
|
||||
|
||||
return swizzledBlockIdx;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size,
|
||||
Coord<3> const &OutputTile) {
|
||||
dim3 grid;
|
||||
grid.x = (problem_size.n() + OutputTile[1] - 1) / OutputTile[1];
|
||||
grid.y = (problem_size.m() + OutputTile[2] - 1) / OutputTile[2];
|
||||
grid.z = problem_size.batch();
|
||||
return grid;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) {
|
||||
dim3 block = swizzle();
|
||||
Coord<3> threadblock_offset =
|
||||
make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
|
||||
return threadblock_offset;
|
||||
}
|
||||
|
||||
///
|
||||
CUTLASS_DEVICE int get_batch_id() {
|
||||
dim3 block = swizzle();
|
||||
return block.z;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -27,18 +27,18 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/wmma_matrix.h>
|
||||
#include "cutlass/wmma_matrix.h"
|
||||
#ifdef CUTLASS_USE_WMMA_API
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/gemm/gemm_global_stream.h>
|
||||
#include <cutlass/gemm/gemm_shared_stream.h>
|
||||
#include <cutlass/gemm/linear_scaling.h>
|
||||
#include <cutlass/gemm/wmma_gemm_global_tile.h>
|
||||
#include <cutlass/gemm/wmma_gemm_shared_tile.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include <cutlass/tile_iterator.h>
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/gemm/gemm_global_stream.h"
|
||||
#include "cutlass/gemm/gemm_shared_stream.h"
|
||||
#include "cutlass/gemm/linear_scaling.h"
|
||||
#include "cutlass/gemm/wmma_gemm_global_tile.h"
|
||||
#include "cutlass/gemm/wmma_gemm_shared_tile.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
#include "cutlass/tile_iterator.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -89,7 +89,7 @@ struct WmmaGemmEpilogueTraitsHelper {
|
||||
MemorySpace::kShared,
|
||||
Index_,
|
||||
WmmaMatrix,
|
||||
IteratorFragment::kWmmaMatrix>
|
||||
FragmentElementType::kWmmaMatrix>
|
||||
SharedStoreIteratorD;
|
||||
|
||||
/// The shared store transformer for D.
|
||||
@ -114,6 +114,9 @@ struct WmmaGemmEpilogueTraitsHelper {
|
||||
MemorySpace::kShared>
|
||||
SharedLoadIteratorD;
|
||||
|
||||
/// The stream to load D.
|
||||
typedef SharedLoadStream<SharedLoadIteratorD> SharedLoadStreamD;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for C^N.
|
||||
typedef WmmaGemmGlobalIteratorCdTraits<
|
||||
// The pointer is float const.
|
||||
|
||||
@ -27,7 +27,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include "cutlass/gemm/gemm_global_tile.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -68,22 +68,13 @@ struct WmmaGemmGlobalIteratorCdTraits : public GemmGlobalTileTraits<GemmOperand:
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename TileTraits_, typename Index_ = int>
|
||||
struct WmmaGemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
typename TileTraits_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kGlobal,
|
||||
Index_> {
|
||||
struct WmmaGemmGlobalIteratorCd : public GemmGlobalIteratorCd<TileTraits_, Index_> {
|
||||
/// This class.
|
||||
typedef WmmaGemmGlobalIteratorCd<TileTraits_, Index_> This_;
|
||||
/// The traits.
|
||||
typedef TileTraits_ Traits;
|
||||
/// The base class.
|
||||
typedef TileIteratorBase<Traits,
|
||||
typename TileTraits_::Scalar,
|
||||
IteratorAdvance::kH,
|
||||
MemorySpace::kGlobal,
|
||||
Index_>
|
||||
Base;
|
||||
typedef GemmGlobalIteratorCd<Traits, Index_> Base;
|
||||
/// Override the strides in each dimension between different loads/stores.
|
||||
typedef Shape<0, 0, Base::Delta::kW, Base::Delta::kC> ImmediateOffsetStrides;
|
||||
/// The layout.
|
||||
@ -99,47 +90,36 @@ struct WmmaGemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
typedef Index_ Index;
|
||||
/// The thread offset functor.
|
||||
typedef typename TileTraits_::ThreadOffset ThreadOffset;
|
||||
/// Base parameters.
|
||||
typedef typename Base::Params BaseParams;
|
||||
|
||||
/// The params.
|
||||
struct Params {
|
||||
/// The pointer.
|
||||
Pointer pointer;
|
||||
/// The stride in the H dimension to setup the thread in the block.
|
||||
Index stride_h;
|
||||
/// The strides to increment the pointer.
|
||||
Index inc_h, inc_advance;
|
||||
/// The column offset to compute the predicate for the columns.
|
||||
Index predicate_offset;
|
||||
/// The strides to increment the predicate offset.
|
||||
Index predicate_inc_h, predicate_inc_advance;
|
||||
|
||||
struct Params : public BaseParams {
|
||||
/// Setup the params.
|
||||
CUTLASS_HOST_DEVICE int initialize(
|
||||
Pointer pointer, Index ld, Index n, Index epilogue_stride_w, Index epilogue_delta_w) {
|
||||
CUTLASS_HOST_DEVICE int initialize(Pointer pointer,
|
||||
long long batch_stride,
|
||||
Index ldm,
|
||||
Index n,
|
||||
Index epilogue_stride_w,
|
||||
Index epilogue_delta_w) {
|
||||
// The pointer.
|
||||
this->pointer = pointer;
|
||||
BaseParams::pointer = pointer;
|
||||
// Stride between GEMMs
|
||||
BaseParams::stride_d = batch_stride;
|
||||
// Setup the base stride. One "group of threads" per column.
|
||||
stride_h = ld;
|
||||
BaseParams::stride_h = ldm;
|
||||
// Each thread output 1 column per iteration. .
|
||||
inc_h = ld * TileTraits_::Threads::kH;
|
||||
inc_advance = inc_h + epilogue_stride_w;
|
||||
BaseParams::inc_h = ldm * TileTraits_::Threads::kH;
|
||||
BaseParams::inc_advance = BaseParams::inc_h + epilogue_stride_w;
|
||||
|
||||
predicate_offset = n;
|
||||
predicate_inc_h = TileTraits_::Threads::kH;
|
||||
predicate_inc_advance = predicate_inc_h + epilogue_delta_w;
|
||||
BaseParams::predicate_offset = n;
|
||||
BaseParams::predicate_inc_h = TileTraits_::Threads::kH;
|
||||
BaseParams::predicate_inc_advance = BaseParams::predicate_inc_h + epilogue_delta_w;
|
||||
|
||||
// It worked.
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
Params params;
|
||||
|
||||
Coord<4> thread_offset;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE WmmaGemmGlobalIteratorCd() {}
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE WmmaGemmGlobalIteratorCd(Params const& params,
|
||||
const Coord<3>& bounds,
|
||||
@ -148,61 +128,37 @@ struct WmmaGemmGlobalIteratorCd : public TileIteratorBase<TileTraits_,
|
||||
int const pred_offset = 0,
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
|
||||
: params(params) {
|
||||
thread_offset = thread_offset_func();
|
||||
// Each warp works on a different column of the tile.
|
||||
int const h = thread_offset[1] + block[1];
|
||||
// Each lane writes a different element.
|
||||
int const w = thread_offset[2] + block[2];
|
||||
// Setup the pointer.
|
||||
this->params.pointer += ((h * params.stride_h + w) + pointer_offset);
|
||||
: Base(params, bounds, block, pointer_offset, pred_offset, thread_offset_func) {}
|
||||
|
||||
// Prepare the vector of predicates.
|
||||
for (int i = 0; i < Base::Iterations::kW; ++i) {
|
||||
predicates.set(i, w + i * Base::Delta::kW < bounds[2]);
|
||||
}
|
||||
this->params.predicate_offset -= (h + pred_offset);
|
||||
/// Loads a single fragment element from memory
|
||||
CUTLASS_DEVICE void load_element(
|
||||
typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
Base::load_element(value, d, h, w, c);
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(0, 0, w, c);
|
||||
Load<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::load(value, params.pointer, imm);
|
||||
}
|
||||
|
||||
/// Increment the pointer in the C dimension.
|
||||
CUTLASS_DEVICE void inc_c() {}
|
||||
/// Increment the pointer in the W dimension.
|
||||
CUTLASS_DEVICE void inc_w() {}
|
||||
/// Increment the pointer in the H dimension.
|
||||
CUTLASS_DEVICE void inc_h() {
|
||||
params.pointer += params.inc_h;
|
||||
params.predicate_offset -= params.predicate_inc_h;
|
||||
}
|
||||
/// Increment the pointer in the D dimension.
|
||||
CUTLASS_DEVICE void inc_d() {}
|
||||
/// Increment the pointer to move to the next iteration.
|
||||
CUTLASS_DEVICE void inc_advance() {
|
||||
params.pointer += params.inc_advance;
|
||||
params.predicate_offset -= params.predicate_inc_advance;
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void set(typename Base::AccessType const& value, int d, int h, int w, int c) {
|
||||
int const imm =
|
||||
/// Stores a single fragment element into memory
|
||||
CUTLASS_DEVICE void store_element(
|
||||
typename Base::AccessType const& value, int d, int h, int w, int c) {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, 0);
|
||||
Store<Scalar, TileTraits_::kAccessSize, MemorySpace::kGlobal>::store(
|
||||
value, params.pointer, imm);
|
||||
Store<Scalar,
|
||||
Base::kAccessSize,
|
||||
Base::kMemorySpace,
|
||||
Base::kFragmentElementType,
|
||||
typename Base::FragmentElement,
|
||||
Base::Tile::kW>::store(value, Base::params.pointer, offset);
|
||||
}
|
||||
|
||||
/// Test the predicate.
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
|
||||
return predicates.at(w) && params.predicate_offset > 0;
|
||||
public:
|
||||
template <typename Fragment>
|
||||
CUTLASS_DEVICE void load_post_increment(Fragment& fragment) {
|
||||
Base::load_post_increment(fragment);
|
||||
}
|
||||
|
||||
/// The predicates for the row.
|
||||
cutlass::PredicateVector<Base::Iterations::kW> predicates;
|
||||
template <typename Fragment>
|
||||
CUTLASS_DEVICE void store_post_increment(Fragment& fragment) {
|
||||
Base::store_post_increment(fragment);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -27,9 +27,9 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/wmma_matrix.h>
|
||||
#include "cutlass/wmma_matrix.h"
|
||||
#ifdef CUTLASS_USE_WMMA_API
|
||||
#include <cutlass/fragment.h>
|
||||
#include "cutlass/fragment.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -42,15 +42,17 @@ template <MatrixLayout::Kind kLayoutA_,
|
||||
typename ScalarB_,
|
||||
MatrixLayout::Kind kLayoutC_,
|
||||
typename ScalarC_,
|
||||
typename AccumulatorsPerWarp_,
|
||||
typename WarpGemmShape_,
|
||||
typename InstructionShape_>
|
||||
struct WmmaGemmMultiplyAdd {
|
||||
/// The shape of the instruction.
|
||||
typedef InstructionShape_ InstructionShape;
|
||||
/// The number of threads per warp. That's a dummy configuration.
|
||||
typedef Shape<1, InstructionShape_::kH, InstructionShape_::kW> ThreadsPerWarp;
|
||||
/// The dimensions.
|
||||
typedef AccumulatorsPerWarp_ AccumulatorsPerWarp;
|
||||
/// Dimensions of the warp-level GEMM (K-by-N-by-M)
|
||||
typedef WarpGemmShape_ WarpGemmShape;
|
||||
/// Aliased for compatibility. Will be removed in CUTLASS v2.0
|
||||
typedef WarpGemmShape_ AccumulatorsPerWarp;
|
||||
/// The type for A.
|
||||
typedef ScalarA_ ScalarA;
|
||||
/// The type for B.
|
||||
@ -102,6 +104,251 @@ struct WmmaGemmMultiplyAdd {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifdef CUTLASS_USE_SUBBYTE_WMMA
|
||||
/// Specialization for WMMA GEMM with binary operands
|
||||
template<typename WarpGemmShape_>
|
||||
struct WmmaGemmMultiplyAdd <MatrixLayout::kRowMajor,
|
||||
Vector<bin1_t, 32>,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Vector<bin1_t, 32>,
|
||||
MatrixLayout::kColumnMajor,
|
||||
int,
|
||||
WarpGemmShape_,
|
||||
Shape<128, 8, 8> >{
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<128, 8, 8> InstructionShape;
|
||||
/// The number of threads per warp. That's a dummy configuration.
|
||||
typedef Shape<1, 4, 8> ThreadsPerWarp;
|
||||
/// Dimensions of the warp-level GEMM (K-by-N-by-M)
|
||||
typedef WarpGemmShape_ WarpGemmShape;
|
||||
/// Aliased for compatibility. Will be removed in CUTLASS v2.0
|
||||
typedef WarpGemmShape_ AccumulatorsPerWarp;
|
||||
/// The type for A.
|
||||
typedef Vector<bin1_t, 32> ScalarA;
|
||||
/// The type for B.
|
||||
typedef Vector<bin1_t, 32> ScalarB;
|
||||
/// The type for C and D.
|
||||
typedef int ScalarC;
|
||||
/// The number of iterations.
|
||||
typedef typename ShapeDiv<AccumulatorsPerWarp, InstructionShape>::Shape Iterations;
|
||||
|
||||
/// The element for A.
|
||||
typedef WmmaMatrix<GemmOperand::kA,
|
||||
MatrixLayout::kRowMajor,
|
||||
Vector<bin1_t, 32>,
|
||||
InstructionShape> ElementA;
|
||||
/// The fragment for A.
|
||||
typedef Fragment<ElementA, Iterations::kW> FragmentA;
|
||||
|
||||
/// The element for B.
|
||||
typedef WmmaMatrix<GemmOperand::kB,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Vector<bin1_t, 32>,
|
||||
InstructionShape> ElementB;
|
||||
/// The fragment for B.
|
||||
typedef Fragment<ElementB, Iterations::kH> FragmentB;
|
||||
|
||||
/// The element for C.
|
||||
typedef WmmaMatrix<GemmOperand::kC,
|
||||
MatrixLayout::kColumnMajor,
|
||||
int,
|
||||
InstructionShape> ElementC;
|
||||
/// The fragment for C.
|
||||
typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b.
|
||||
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
|
||||
FragmentB const& b,
|
||||
Accumulators const& c,
|
||||
Accumulators& d) {
|
||||
for (int j = 0; j < Iterations::kH; ++j) {
|
||||
for (int i = 0; i < Iterations::kW; ++i) {
|
||||
// The input elements.
|
||||
ElementA const& elt_a = a[i];
|
||||
ElementB const& elt_b = b[j];
|
||||
ElementC const& elt_c = c[j * Iterations::kW + i];
|
||||
|
||||
// The output element.
|
||||
ElementC& elt_d = d[j * Iterations::kW + i];
|
||||
|
||||
// The wmma instruction.
|
||||
nvcuda::wmma::bmma_sync(elt_d,
|
||||
elt_a,
|
||||
elt_b,
|
||||
elt_c,
|
||||
nvcuda::wmma::experimental::bmmaBitOpXOR,
|
||||
nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifdef CUTLASS_USE_SUBBYTE_WMMA
|
||||
/// Specialization for WMMA GEMM with signed 4-bit integer operands
|
||||
template<typename WarpGemmShape_>
|
||||
struct WmmaGemmMultiplyAdd <MatrixLayout::kRowMajor,
|
||||
Vector<int4_t, 8>,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Vector<int4_t, 8>,
|
||||
MatrixLayout::kColumnMajor,
|
||||
int,
|
||||
WarpGemmShape_,
|
||||
Shape<32, 8, 8> >{
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<32, 8, 8> InstructionShape;
|
||||
/// The number of threads per warp. That's a dummy configuration.
|
||||
typedef Shape<1, 4, 8> ThreadsPerWarp;
|
||||
/// Dimensions of the warp-level GEMM (K-by-N-by-M)
|
||||
typedef WarpGemmShape_ WarpGemmShape;
|
||||
/// Aliased for compatibility. Will be removed in CUTLASS v2.0
|
||||
typedef WarpGemmShape_ AccumulatorsPerWarp;
|
||||
/// The type for A.
|
||||
typedef Vector<int4_t, 8> ScalarA;
|
||||
/// The type for B.
|
||||
typedef Vector<int4_t, 8> ScalarB;
|
||||
/// The type for C and D.
|
||||
typedef int ScalarC;
|
||||
/// The number of iterations.
|
||||
typedef typename ShapeDiv<AccumulatorsPerWarp, InstructionShape>::Shape Iterations;
|
||||
|
||||
/// The element for A.
|
||||
typedef WmmaMatrix<GemmOperand::kA,
|
||||
MatrixLayout::kRowMajor,
|
||||
Vector<int4_t, 8>,
|
||||
InstructionShape> ElementA;
|
||||
/// The fragment for A.
|
||||
typedef Fragment<ElementA, Iterations::kW> FragmentA;
|
||||
|
||||
/// The element for B.
|
||||
typedef WmmaMatrix<GemmOperand::kB,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Vector<int4_t, 8>,
|
||||
InstructionShape> ElementB;
|
||||
/// The fragment for B.
|
||||
typedef Fragment<ElementB, Iterations::kH> FragmentB;
|
||||
|
||||
/// The element for C.
|
||||
typedef WmmaMatrix<GemmOperand::kC,
|
||||
MatrixLayout::kColumnMajor,
|
||||
int,
|
||||
InstructionShape> ElementC;
|
||||
/// The fragment for C.
|
||||
typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b.
|
||||
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
|
||||
FragmentB const& b,
|
||||
Accumulators const& c,
|
||||
Accumulators& d) {
|
||||
for (int j = 0; j < Iterations::kH; ++j) {
|
||||
for (int i = 0; i < Iterations::kW; ++i) {
|
||||
// The input elements.
|
||||
ElementA const& elt_a = a[i];
|
||||
ElementB const& elt_b = b[j];
|
||||
ElementC const& elt_c = c[j * Iterations::kW + i];
|
||||
|
||||
// The output element.
|
||||
ElementC& elt_d = d[j * Iterations::kW + i];
|
||||
|
||||
// The wmma instruction.
|
||||
nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifdef CUTLASS_USE_SUBBYTE_WMMA
|
||||
/// Specialization for WMMA GEMM with unsigned 4-bit integer operands
|
||||
template<typename WarpGemmShape_>
|
||||
struct WmmaGemmMultiplyAdd <MatrixLayout::kRowMajor,
|
||||
Vector<uint4_t, 8>,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Vector<uint4_t, 8>,
|
||||
MatrixLayout::kColumnMajor,
|
||||
int,
|
||||
WarpGemmShape_,
|
||||
Shape<32, 8, 8> >{
|
||||
/// The shape of the instruction.
|
||||
typedef Shape<32, 8, 8> InstructionShape;
|
||||
/// The number of threads per warp. That's a dummy configuration.
|
||||
typedef Shape<1, 4, 8> ThreadsPerWarp;
|
||||
/// Dimensions of the warp-level GEMM (K-by-N-by-M)
|
||||
typedef WarpGemmShape_ WarpGemmShape;
|
||||
/// Aliased for compatibility. Will be removed in CUTLASS v2.0
|
||||
typedef WarpGemmShape_ AccumulatorsPerWarp;
|
||||
/// The type for A.
|
||||
typedef Vector<uint4_t, 8> ScalarA;
|
||||
/// The type for B.
|
||||
typedef Vector<uint4_t, 8> ScalarB;
|
||||
/// The type for C and D.
|
||||
typedef int ScalarC;
|
||||
/// The number of iterations.
|
||||
typedef typename ShapeDiv<AccumulatorsPerWarp, InstructionShape>::Shape Iterations;
|
||||
|
||||
/// The element for A.
|
||||
typedef WmmaMatrix<GemmOperand::kA,
|
||||
MatrixLayout::kRowMajor,
|
||||
Vector<uint4_t, 8>,
|
||||
InstructionShape> ElementA;
|
||||
/// The fragment for A.
|
||||
typedef Fragment<ElementA, Iterations::kW> FragmentA;
|
||||
|
||||
/// The element for B.
|
||||
typedef WmmaMatrix<GemmOperand::kB,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Vector<uint4_t, 8>,
|
||||
InstructionShape> ElementB;
|
||||
/// The fragment for B.
|
||||
typedef Fragment<ElementB, Iterations::kH> FragmentB;
|
||||
|
||||
/// The element for C.
|
||||
typedef WmmaMatrix<GemmOperand::kC,
|
||||
MatrixLayout::kColumnMajor,
|
||||
int,
|
||||
InstructionShape> ElementC;
|
||||
/// The fragment for C.
|
||||
typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}
|
||||
|
||||
/// Multiply : d = a*b.
|
||||
CUTLASS_DEVICE void multiply_add(FragmentA const& a,
|
||||
FragmentB const& b,
|
||||
Accumulators const& c,
|
||||
Accumulators& d) {
|
||||
for (int j = 0; j < Iterations::kH; ++j) {
|
||||
for (int i = 0; i < Iterations::kW; ++i) {
|
||||
// The input elements.
|
||||
ElementA const& elt_a = a[i];
|
||||
ElementB const& elt_b = b[j];
|
||||
ElementC const& elt_c = c[j * Iterations::kW + i];
|
||||
|
||||
// The output element.
|
||||
ElementC& elt_d = d[j * Iterations::kW + i];
|
||||
|
||||
// The wmma instruction.
|
||||
nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
|
||||
@ -28,18 +28,15 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/wmma_matrix.h>
|
||||
#include "cutlass/wmma_matrix.h"
|
||||
#ifdef CUTLASS_USE_WMMA_API
|
||||
|
||||
#include <cutlass/gemm/gemm_operand.h>
|
||||
#include <cutlass/reshape_tile.h>
|
||||
#include "cutlass/gemm/gemm_operand.h"
|
||||
#include "cutlass/reshape_tile.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
template <class>
|
||||
struct Debug {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <MatrixLayout::Kind kLayout_,
|
||||
|
||||
@ -27,19 +27,19 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/wmma_matrix.h>
|
||||
#include "cutlass/wmma_matrix.h"
|
||||
#ifdef CUTLASS_USE_WMMA_API
|
||||
|
||||
#include <cutlass/convert.h>
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/gemm_epilogue.h>
|
||||
#include <cutlass/gemm/gemm_epilogue_traits.h>
|
||||
#include <cutlass/gemm/gemm_global_tile.h>
|
||||
#include <cutlass/gemm/gemm_shared_tile.h>
|
||||
#include <cutlass/gemm/gemm_traits.h>
|
||||
#include <cutlass/gemm/wmma_gemm_epilogue_traits.h>
|
||||
#include <cutlass/gemm/wmma_gemm_global_tile.h>
|
||||
#include <cutlass/gemm/wmma_gemm_multiply_add.h>
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/gemm_epilogue.h"
|
||||
#include "cutlass/gemm/gemm_epilogue_traits.h"
|
||||
#include "cutlass/gemm/gemm_global_tile.h"
|
||||
#include "cutlass/gemm/gemm_shared_tile.h"
|
||||
#include "cutlass/gemm/gemm_traits.h"
|
||||
#include "cutlass/gemm/wmma_gemm_epilogue_traits.h"
|
||||
#include "cutlass/gemm/wmma_gemm_global_tile.h"
|
||||
#include "cutlass/gemm/wmma_gemm_multiply_add.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
@ -53,12 +53,16 @@ template <
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_,
|
||||
/// The input type.
|
||||
typename ScalarA_,
|
||||
/// The input type.
|
||||
typename ScalarB_,
|
||||
/// The output type.
|
||||
typename ScalarC_,
|
||||
/// The accumulator type.
|
||||
typename Accumulator_,
|
||||
/// The number of accumulators per warp.
|
||||
typename AccumulatorsPerWarp_,
|
||||
/// Tile size for warp-level GEMM (K-by-N-by-M)
|
||||
typename WarpGemmShape_,
|
||||
/// The shape of the WMMA instruction.
|
||||
typename InstructionShape_,
|
||||
/// The number of scalars per LDG for A.
|
||||
@ -67,9 +71,9 @@ template <
|
||||
int kScalarsPerLdgB_>
|
||||
struct WmmaGemmConfig : public GemmConfig<
|
||||
/// The scalar type for A.
|
||||
half,
|
||||
ScalarA_,
|
||||
/// The scalar type for B.
|
||||
half,
|
||||
ScalarB_,
|
||||
/// The scalar type for C.
|
||||
ScalarC_,
|
||||
/// The scalar type for D.
|
||||
@ -78,12 +82,12 @@ struct WmmaGemmConfig : public GemmConfig<
|
||||
OutputTile_,
|
||||
/// The functor to do the math in the main loop.
|
||||
WmmaGemmMultiplyAdd<kLayoutA_,
|
||||
half,
|
||||
ScalarA_,
|
||||
kLayoutB_,
|
||||
half,
|
||||
ScalarB_,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Accumulator_,
|
||||
AccumulatorsPerWarp_,
|
||||
WarpGemmShape_,
|
||||
InstructionShape_>,
|
||||
/// The number of scalars per LDG for A.
|
||||
kScalarsPerLdgA_,
|
||||
@ -100,21 +104,29 @@ struct WmmaGemmConfig : public GemmConfig<
|
||||
/// The number of scalars per LDG for C and STG for D.
|
||||
16 / sizeof(ScalarC_),
|
||||
/// The number of scalars per STS for D.
|
||||
16 / sizeof(ScalarC_),
|
||||
16 / sizeof(Accumulator_),
|
||||
/// The number of scalars per LDS for D.
|
||||
16 / sizeof(ScalarC_),
|
||||
16 / sizeof(Accumulator_),
|
||||
/// The number of stages in shared memory.
|
||||
1> {};
|
||||
1,
|
||||
/// If true, residue is computed in mainloop. If false, separate loops are instantiated.
|
||||
false,
|
||||
/// Is residue performed in prologue?
|
||||
true,
|
||||
/// If true, kernel is launched with CUDA launch bounds specified
|
||||
false> {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
|
||||
template <enum MatrixLayout::Kind kLayout_,
|
||||
typename GemmConfig_,
|
||||
typename ScalarA_>
|
||||
struct WmmaGemmTileTraitsHelperA {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_>
|
||||
template <typename GemmConfig_, typename ScalarA_>
|
||||
struct WmmaGemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_, ScalarA_>
|
||||
: public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
/// The base config.
|
||||
typedef GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> Base;
|
||||
@ -173,8 +185,8 @@ struct WmmaGemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
template <typename GemmConfig_, typename ScalarA_>
|
||||
struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, ScalarA_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
|
||||
|
||||
@ -251,13 +263,276 @@ struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
|
||||
#ifdef CUTLASS_USE_SUBBYTE_WMMA
|
||||
/// Specialization for WMMA GEMM with binary operands
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Vector<bin1_t, 32> > {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarA Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
|
||||
|
||||
/// GemmConfig_::OutputTile::kD is in number of 'bits'. TileTraits expects number of 'Scalar'.
|
||||
/// Divide by 'kBitsPerScalar' to get the number in 'Scalar'.
|
||||
static int const kBitsPerScalar = sizeof(Scalar) * 8;
|
||||
|
||||
/// WMMA matrix
|
||||
typedef WmmaMatrix<GemmOperand::kA,
|
||||
MatrixLayout::kRowMajor,
|
||||
Vector<bin1_t, 32>,
|
||||
typename GemmConfig_::InstructionShape>
|
||||
WmmaMatrix;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for A^T.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's A.
|
||||
GemmOperand::kA,
|
||||
// A is row-major.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer is float const.
|
||||
Scalar const,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kBitsPerScalar>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1,
|
||||
GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kBitsPerScalar),
|
||||
GemmConfig_::OutputTile::kD / kBitsPerScalar>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgA / kBitsPerScalar>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkew = 16 / sizeof(MultiplyAddScalar);
|
||||
/// The tile.
|
||||
typedef Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kW,
|
||||
GemmConfig_::OutputTile::kD / kBitsPerScalar + kSkew>
|
||||
Tile;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Tile,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsA / kBitsPerScalar>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The number of elements loaded in one LDG.
|
||||
static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
|
||||
/// The traits class to build the iterator to load from shared memory for A.
|
||||
typedef WmmaGemmSharedLoadTileATraits<
|
||||
// The layout of the matrix.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile in shared memory.
|
||||
Tile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The strides between warps.
|
||||
GemmConfig_::InstructionShape::kW * Tile::kW,
|
||||
// The number of iterations to load the data.
|
||||
Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
|
||||
// The stride between iterations.
|
||||
Shape<GemmConfig_::InstructionShape::kD / kBitsPerScalar, 0, kScalarsPerW * Tile::kW>,
|
||||
// The shape of the instruction.
|
||||
typename GemmConfig_::InstructionShape>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifdef CUTLASS_USE_SUBBYTE_WMMA
|
||||
/// Specialization for WMMA GEMM with unsigned 4-bit integer operands
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Vector<uint4_t, 8> > {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarA Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
|
||||
|
||||
/// GemmConfig_::OutputTile::kD is in number of 'int4'. TileTraits expects number of 'Scalar'.
|
||||
/// Divide by 'kInt4PerScalar' to get the number in 'Scalar'.
|
||||
static int const kInt4PerScalar = sizeof(Scalar) * 2;
|
||||
|
||||
/// WMMA matrix
|
||||
typedef WmmaMatrix<GemmOperand::kA,
|
||||
MatrixLayout::kRowMajor,
|
||||
Vector<uint4_t, 8>,
|
||||
typename GemmConfig_::InstructionShape>
|
||||
WmmaMatrix;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for A^T.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's A.
|
||||
GemmOperand::kA,
|
||||
// A is row-major.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer is float const.
|
||||
Scalar const,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1,
|
||||
GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
|
||||
GemmConfig_::OutputTile::kD / kInt4PerScalar>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgA / kInt4PerScalar>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkew = 16 / sizeof(MultiplyAddScalar);
|
||||
/// The tile.
|
||||
typedef Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kW,
|
||||
GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
|
||||
Tile;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Tile,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsA / kInt4PerScalar>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The number of elements loaded in one LDG.
|
||||
static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
|
||||
/// The traits class to build the iterator to load from shared memory for A.
|
||||
typedef WmmaGemmSharedLoadTileATraits<
|
||||
// The layout of the matrix.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile in shared memory.
|
||||
Tile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The strides between warps.
|
||||
GemmConfig_::InstructionShape::kW * Tile::kW,
|
||||
// The number of iterations to load the data.
|
||||
Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
|
||||
// The stride between iterations.
|
||||
Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
|
||||
// The shape of the instruction.
|
||||
typename GemmConfig_::InstructionShape>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifdef CUTLASS_USE_SUBBYTE_WMMA
|
||||
/// Specialization for WMMA GEMM with signed 4-bit integer operands
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Vector<int4_t, 8> > {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarA Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
|
||||
|
||||
/// GemmConfig_::OutputTile::kD is in number of 'int4'. TileTraits expects number of 'Scalar'.
|
||||
/// Divide by 'kInt4PerScalar' to get the number in 'Scalar'.
|
||||
static int const kInt4PerScalar = sizeof(Scalar) * 2;
|
||||
|
||||
/// WMMA matrix
|
||||
typedef WmmaMatrix<GemmOperand::kA,
|
||||
MatrixLayout::kRowMajor,
|
||||
Vector<int4_t, 8>,
|
||||
typename GemmConfig_::InstructionShape>
|
||||
WmmaMatrix;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for A^T.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's A.
|
||||
GemmOperand::kA,
|
||||
// A is row-major.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer is float const.
|
||||
Scalar const,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1,
|
||||
GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
|
||||
GemmConfig_::OutputTile::kD / kInt4PerScalar>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgA / kInt4PerScalar>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkew = 16 / sizeof(MultiplyAddScalar);
|
||||
/// The tile.
|
||||
typedef Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kW,
|
||||
GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
|
||||
Tile;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for A^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Tile,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsA / kInt4PerScalar>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The number of elements loaded in one LDG.
|
||||
static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW;
|
||||
/// The traits class to build the iterator to load from shared memory for A.
|
||||
typedef WmmaGemmSharedLoadTileATraits<
|
||||
// The layout of the matrix.
|
||||
MatrixLayout::kRowMajor,
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile in shared memory.
|
||||
Tile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The strides between warps.
|
||||
GemmConfig_::InstructionShape::kW * Tile::kW,
|
||||
// The number of iterations to load the data.
|
||||
Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>,
|
||||
// The stride between iterations.
|
||||
Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
|
||||
// The shape of the instruction.
|
||||
typename GemmConfig_::InstructionShape>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <enum MatrixLayout::Kind kLayout_,
|
||||
typename GemmConfig_,
|
||||
typename ScalarB_>
|
||||
struct WmmaGemmTileTraitsHelperB {};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_>
|
||||
template <typename GemmConfig_, typename ScalarB_>
|
||||
struct WmmaGemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_, ScalarB_>
|
||||
: public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
|
||||
/// The base config.
|
||||
typedef GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> Base;
|
||||
@ -316,8 +591,8 @@ struct WmmaGemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
template <typename GemmConfig_, typename ScalarB_>
|
||||
struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, ScalarB_> {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
|
||||
|
||||
@ -394,6 +669,267 @@ struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifdef CUTLASS_USE_SUBBYTE_WMMA
|
||||
/// Specialization for WMMA GEMM with binary operands
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Vector<bin1_t, 32> > {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarB Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
|
||||
|
||||
/// GemmConfig_::OutputTile::kD is in number of 'bits'. TileTraits expects number of 'Scalar'.
|
||||
/// Divide by 'kBitsPerScalar' to get the number in 'Scalar'.
|
||||
static int const kBitsPerScalar = sizeof(Scalar) * 8;
|
||||
|
||||
/// WMMA matrix
|
||||
typedef WmmaMatrix<GemmOperand::kB,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Vector<bin1_t, 32>,
|
||||
typename GemmConfig_::InstructionShape>
|
||||
WmmaMatrix;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for B^N.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's B.
|
||||
GemmOperand::kB,
|
||||
// A is row-major.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer is float const.
|
||||
Scalar const,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kBitsPerScalar>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1,
|
||||
GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kBitsPerScalar),
|
||||
GemmConfig_::OutputTile::kD / kBitsPerScalar>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgB / kBitsPerScalar>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkew = 16 / sizeof(MultiplyAddScalar);
|
||||
/// The tile.
|
||||
typedef Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kH,
|
||||
GemmConfig_::OutputTile::kD / kBitsPerScalar + kSkew>
|
||||
Tile;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Tile,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsB / kBitsPerScalar>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The number of elements loaded in one LDG.
|
||||
static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
|
||||
/// The traits class to build the iterator to load from shared memory for B.
|
||||
typedef WmmaGemmSharedLoadTileBTraits<
|
||||
// The layout of the matrix.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile in shared memory.
|
||||
Tile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The strides between warps.
|
||||
GemmConfig_::InstructionShape::kH * Tile::kW,
|
||||
// The number of iterations to load the data.
|
||||
Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
|
||||
// The stride between iterations.
|
||||
Shape<GemmConfig_::InstructionShape::kD / kBitsPerScalar, 0, kScalarsPerW * Tile::kW>,
|
||||
// The shape of the instruction.
|
||||
typename GemmConfig_::InstructionShape>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifdef CUTLASS_USE_SUBBYTE_WMMA
|
||||
/// Specialization for WMMA GEMM with unsigned 4-bit integer operands
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Vector<uint4_t, 8> > {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarB Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
|
||||
|
||||
/// GemmConfig_::OutputTile::kD is in number of 'int4'. TileTraits expects number of 'Scalar'.
|
||||
/// Divide by 'kInt4PerScalar' to get the number in 'Scalar'.
|
||||
static int const kInt4PerScalar = sizeof(Scalar) * 2;
|
||||
|
||||
/// WMMA matrix
|
||||
typedef WmmaMatrix<GemmOperand::kB,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Vector<uint4_t, 8>,
|
||||
typename GemmConfig_::InstructionShape>
|
||||
WmmaMatrix;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for B^N.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's B.
|
||||
GemmOperand::kB,
|
||||
// A is row-major.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer is float const.
|
||||
Scalar const,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1,
|
||||
GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
|
||||
GemmConfig_::OutputTile::kD / kInt4PerScalar>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgB / kInt4PerScalar>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkew = 16 / sizeof(MultiplyAddScalar);
|
||||
/// The tile.
|
||||
typedef Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kH,
|
||||
GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
|
||||
Tile;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Tile,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsB / kInt4PerScalar>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The number of elements loaded in one LDG.
|
||||
static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
|
||||
/// The traits class to build the iterator to load from shared memory for B.
|
||||
typedef WmmaGemmSharedLoadTileBTraits<
|
||||
// The layout of the matrix.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile in shared memory.
|
||||
Tile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The strides between warps.
|
||||
GemmConfig_::InstructionShape::kH * Tile::kW,
|
||||
// The number of iterations to load the data.
|
||||
Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
|
||||
// The stride between iterations.
|
||||
Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
|
||||
// The shape of the instruction.
|
||||
typename GemmConfig_::InstructionShape>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifdef CUTLASS_USE_SUBBYTE_WMMA
|
||||
/// Specialization for WMMA GEMM with signed 4-bit integer operands
|
||||
template <typename GemmConfig_>
|
||||
struct WmmaGemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Vector<int4_t, 8> > {
|
||||
/// The layout.
|
||||
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor;
|
||||
|
||||
/// The input scalar.
|
||||
typedef typename GemmConfig_::ScalarB Scalar;
|
||||
/// The scalar stored in shared memory.
|
||||
typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
|
||||
|
||||
/// GemmConfig_::OutputTile::kD is in number of 'int4'. TileTraits expects number of 'Scalar'.
|
||||
/// Divide by 'kInt4PerScalar' to get the number in 'Scalar'.
|
||||
static int const kInt4PerScalar = sizeof(Scalar) * 2;
|
||||
|
||||
/// WMMA matrix
|
||||
typedef WmmaMatrix<GemmOperand::kB,
|
||||
MatrixLayout::kColumnMajor,
|
||||
Vector<int4_t, 8>,
|
||||
typename GemmConfig_::InstructionShape>
|
||||
WmmaMatrix;
|
||||
|
||||
/// The traits class to build the iterator to load data from global memory for B^N.
|
||||
typedef GemmGlobalTileTraits<
|
||||
// That's B.
|
||||
GemmOperand::kB,
|
||||
// A is row-major.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer is float const.
|
||||
Scalar const,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kInt4PerScalar>,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
Shape<1,
|
||||
GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar),
|
||||
GemmConfig_::OutputTile::kD / kInt4PerScalar>,
|
||||
// The number of scalars per LDG (LDG.32 or LDG.128, etc).
|
||||
GemmConfig_::kScalarsPerLdgB / kInt4PerScalar>
|
||||
GlobalTileTraits;
|
||||
|
||||
/// The skew.
|
||||
static int const kSkew = 16 / sizeof(MultiplyAddScalar);
|
||||
/// The tile.
|
||||
typedef Shape<GemmConfig_::kStages,
|
||||
GemmConfig_::OutputTile::kH,
|
||||
GemmConfig_::OutputTile::kD / kInt4PerScalar + kSkew>
|
||||
Tile;
|
||||
|
||||
/// The traits class to build the iterator to store data to shared memory for B^N.
|
||||
typedef GemmSharedStoreTileAbTraits<
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile has size KxM in GEMM's terminology.
|
||||
Tile,
|
||||
// The threads are distributed as warps x 32 (the traits may reorganize).
|
||||
typename GlobalTileTraits::Threads,
|
||||
// The number of scalars per STS (STS.32 or STS.128, etc).
|
||||
GemmConfig_::kScalarsPerStsB / kInt4PerScalar>
|
||||
SharedStoreTileTraits;
|
||||
|
||||
/// The number of elements loaded in one LDG.
|
||||
static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH;
|
||||
/// The traits class to build the iterator to load from shared memory for B.
|
||||
typedef WmmaGemmSharedLoadTileBTraits<
|
||||
// The layout of the matrix.
|
||||
MatrixLayout::kColumnMajor,
|
||||
// The pointer.
|
||||
MultiplyAddScalar,
|
||||
// The tile in shared memory.
|
||||
Tile,
|
||||
// The number of warps.
|
||||
typename GemmConfig_::Warps,
|
||||
// The strides between warps.
|
||||
GemmConfig_::InstructionShape::kH * Tile::kW,
|
||||
// The number of iterations to load the data.
|
||||
Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>,
|
||||
// The stride between iterations.
|
||||
Shape<GemmConfig_::InstructionShape::kD / kInt4PerScalar, 0, kScalarsPerW * Tile::kW>,
|
||||
// The shape of the instruction.
|
||||
typename GemmConfig_::InstructionShape>
|
||||
SharedLoadTileTraits;
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// The layout for A.
|
||||
MatrixLayout::Kind kLayoutA_,
|
||||
@ -401,14 +937,18 @@ template <
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The output tile.
|
||||
typename OutputTile_,
|
||||
/// The input type.
|
||||
typename ScalarA_,
|
||||
/// The input type.
|
||||
typename ScalarB_,
|
||||
/// The output type.
|
||||
typename ScalarC_,
|
||||
/// The accumulator type.
|
||||
typename Accumulator_,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_,
|
||||
/// The number of accumulators per warp.
|
||||
typename AccumulatorsPerWarp_,
|
||||
/// Tile size for warp-level GEMM (K-by-N-by-M)
|
||||
typename WarpGemmShape_,
|
||||
/// The shape of the WMMA instruction.
|
||||
typename InstructionShape_,
|
||||
/// The number of halfs loaded in one LDG for A.
|
||||
@ -422,18 +962,20 @@ struct WmmaGemmTraitsHelper {
|
||||
typedef WmmaGemmConfig<kLayoutA_,
|
||||
kLayoutB_,
|
||||
OutputTile_,
|
||||
ScalarA_,
|
||||
ScalarB_,
|
||||
ScalarC_,
|
||||
Accumulator_,
|
||||
AccumulatorsPerWarp_,
|
||||
WarpGemmShape_,
|
||||
InstructionShape_,
|
||||
kScalarsPerLdgA_,
|
||||
kScalarsPerLdgB_>
|
||||
GemmConfig;
|
||||
|
||||
/// The GEMM config for A.
|
||||
typedef WmmaGemmTileTraitsHelperA<kLayoutA_, GemmConfig> GemmTileTraitsHelperA;
|
||||
typedef WmmaGemmTileTraitsHelperA<kLayoutA_, GemmConfig, ScalarA_> GemmTileTraitsHelperA;
|
||||
/// The GEMM config for B.
|
||||
typedef WmmaGemmTileTraitsHelperB<kLayoutB_, GemmConfig> GemmTileTraitsHelperB;
|
||||
typedef WmmaGemmTileTraitsHelperB<kLayoutB_, GemmConfig, ScalarB_> GemmTileTraitsHelperB;
|
||||
|
||||
/// The iterator to load A from global memory.
|
||||
typedef GemmGlobalIteratorAb<typename GemmTileTraitsHelperA::GlobalTileTraits, Index_>
|
||||
@ -447,7 +989,10 @@ struct WmmaGemmTraitsHelper {
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorA;
|
||||
/// The stream to load A from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA>
|
||||
typedef GlobalLoadStream<GemmOperand::kA,
|
||||
GlobalLoadIteratorA,
|
||||
SharedStoreIteratorA,
|
||||
GlobalTransformerA>
|
||||
GlobalLoadStreamA;
|
||||
|
||||
/// The iterator to load B from global memory.
|
||||
@ -462,7 +1007,10 @@ struct WmmaGemmTraitsHelper {
|
||||
MemorySpace::kShared>
|
||||
SharedStoreIteratorB;
|
||||
/// The stream to load B from global memory to shared memory.
|
||||
typedef GlobalLoadStream<GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB>
|
||||
typedef GlobalLoadStream<GemmOperand::kB,
|
||||
GlobalLoadIteratorB,
|
||||
SharedStoreIteratorB,
|
||||
GlobalTransformerB>
|
||||
GlobalLoadStreamB;
|
||||
|
||||
/// The iterator to load A from shared memory.
|
||||
@ -472,7 +1020,7 @@ struct WmmaGemmTraitsHelper {
|
||||
MemorySpace::kShared,
|
||||
Index_,
|
||||
typename GemmTileTraitsHelperA::WmmaMatrix,
|
||||
IteratorFragment::kWmmaMatrix>
|
||||
FragmentElementType::kWmmaMatrix>
|
||||
SharedLoadIteratorA;
|
||||
/// The stream to load A from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorA> SharedLoadStreamA;
|
||||
@ -483,7 +1031,7 @@ struct WmmaGemmTraitsHelper {
|
||||
MemorySpace::kShared,
|
||||
Index_,
|
||||
typename GemmTileTraitsHelperB::WmmaMatrix,
|
||||
IteratorFragment::kWmmaMatrix>
|
||||
FragmentElementType::kWmmaMatrix>
|
||||
SharedLoadIteratorB;
|
||||
/// The stream to load B from shared memory.
|
||||
typedef SharedLoadStream<SharedLoadIteratorB> SharedLoadStreamB;
|
||||
@ -518,14 +1066,18 @@ template <
|
||||
MatrixLayout::Kind kLayoutB_,
|
||||
/// The tile size for the GEMM KxNxM.
|
||||
typename OutputTile_ = Shape<64, 128, 128>,
|
||||
/// The input type.
|
||||
typename ScalarA_ = half,
|
||||
/// The input type.
|
||||
typename ScalarB_ = half,
|
||||
/// The output type.
|
||||
typename ScalarC_ = float,
|
||||
/// The functor to do the math in the epilogue.
|
||||
typename EpilogueFunctor_ = LinearScaling<ScalarC_>,
|
||||
/// The accumulator type.
|
||||
typename Accumulator_ = ScalarC_,
|
||||
/// The number of accumulators per warp.
|
||||
typename AccumulatorsPerWarp_ = typename WmmaGemmAccumulatorsPerWarp<OutputTile_>::Shape,
|
||||
/// Tile size for warp-level GEMM (K-by-N-by-M)
|
||||
typename WarpGemmShape_ = typename WmmaGemmAccumulatorsPerWarp<OutputTile_>::Shape,
|
||||
/// The shape of the WMMA instruction.
|
||||
typename InstructionShape_ = Shape<16, 16, 16>,
|
||||
/// The number of scalars per LDG for A.
|
||||
@ -538,10 +1090,12 @@ template <
|
||||
typename Helper_ = WmmaGemmTraitsHelper<kLayoutA_,
|
||||
kLayoutB_,
|
||||
OutputTile_,
|
||||
ScalarA_,
|
||||
ScalarB_,
|
||||
ScalarC_,
|
||||
Accumulator_,
|
||||
EpilogueFunctor_,
|
||||
AccumulatorsPerWarp_,
|
||||
WarpGemmShape_,
|
||||
InstructionShape_,
|
||||
kScalarsPerLdgA_,
|
||||
kScalarsPerLdgB_,
|
||||
|
||||
@ -27,16 +27,14 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment_load_store.h>
|
||||
#include <cutlass/load_store.h>
|
||||
#include <cutlass/predicate_vector.h>
|
||||
#include <cutlass/shape.h>
|
||||
#include "cutlass/load_store.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/shape.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Loads a fragment from an input iterator
|
||||
// Used by convolution
|
||||
template <typename InputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment) {
|
||||
typename InputIterator::FragmentIterator frag_iterator(fragment);
|
||||
@ -45,12 +43,12 @@ CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragme
|
||||
for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
|
||||
if (iterator.valid(d, h, w, c)) {
|
||||
iterator.get(reinterpret_cast<typename InputIterator::AccessType &>(
|
||||
frag_iterator.at(d, h, w, c)),
|
||||
d,
|
||||
h,
|
||||
w,
|
||||
c);
|
||||
iterator.load_element(reinterpret_cast<typename InputIterator::AccessType &>(
|
||||
frag_iterator.at(d, h, w, c)),
|
||||
d,
|
||||
h,
|
||||
w,
|
||||
c);
|
||||
}
|
||||
}
|
||||
if (w < InputIterator::Iterations::kW - 1) {
|
||||
@ -68,138 +66,21 @@ CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragme
|
||||
iterator.inc_advance();
|
||||
}
|
||||
|
||||
/// Loads a fragment from a shared memory input iterator
|
||||
template <typename InputIterator, typename Fragment>
|
||||
CUTLASS_DEVICE void shared_iterator_load(InputIterator &iterator, Fragment &fragment) {
|
||||
typename InputIterator::FragmentIterator frag_iterator(fragment);
|
||||
for (int d = 0; d < InputIterator::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename InputIterator::ImmediateOffsetStrides>::get(
|
||||
d, h, w, c);
|
||||
|
||||
FragmentLoad<InputIterator::kIteratorFragment,
|
||||
InputIterator::Tile::kC,
|
||||
typename InputIterator::Scalar,
|
||||
InputIterator::kMemorySpace,
|
||||
typename InputIterator::FragmentElement,
|
||||
InputIterator::Tile::kW>::load(frag_iterator.at(d, h, w, c),
|
||||
iterator.data(),
|
||||
offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from a shared memory input iterator
|
||||
template <typename InputIterator, typename Fragment>
|
||||
CUTLASS_DEVICE void shared_iterator_load(InputIterator &iterator, Fragment &fragment, int d) {
|
||||
typename InputIterator::FragmentIterator frag_iterator(fragment);
|
||||
for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename InputIterator::ImmediateOffsetStrides>::get(
|
||||
d, h, w, c);
|
||||
|
||||
FragmentLoad<InputIterator::kIteratorFragment,
|
||||
InputIterator::Tile::kC,
|
||||
typename InputIterator::Scalar,
|
||||
InputIterator::kMemorySpace,
|
||||
typename InputIterator::FragmentElement,
|
||||
InputIterator::Tile::kW>::load(frag_iterator.at(0, h, w, c),
|
||||
iterator.data(),
|
||||
offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from an input iterator, masked by a predicate iterator
|
||||
template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_load_post_increment(InputIterator &iterator,
|
||||
Fragment &fragment,
|
||||
typename InputIterator::Index offset,
|
||||
ConstPredicateAdapter predicate_adapter) {
|
||||
for (int d = 0; d < InputIterator::Iterations::kD; ++d, iterator.inc_d()) {
|
||||
for (int h = 0; h < InputIterator::Iterations::kH; ++h, iterator.inc_h()) {
|
||||
for (int w = 0; w < InputIterator::Iterations::kW; ++w, iterator.inc_w()) {
|
||||
if (predicate_adapter.at(d, h, w, 0)) {
|
||||
int idx = InputIterator::Tile::kC *
|
||||
(w + InputIterator::Iterations::kW * (h + InputIterator::Iterations::kH * d));
|
||||
|
||||
Load<typename Fragment::Element, InputIterator::Tile::kC, InputIterator::kMemorySpace>::
|
||||
load(reinterpret_cast<typename InputIterator::AccessType &>(fragment[idx]),
|
||||
iterator.data(),
|
||||
offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from an input iterator
|
||||
template <typename InputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_load_post_increment(InputIterator &iterator,
|
||||
Fragment &fragment,
|
||||
typename InputIterator::Index offset = 0) {
|
||||
TrivialPredicateTileAdapter pred;
|
||||
iterator_load_post_increment(iterator, fragment, offset, pred);
|
||||
}
|
||||
|
||||
/// Loads a fragment from an input iterator
|
||||
template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_load_post_increment(InputIterator &iterator,
|
||||
Fragment &fragment,
|
||||
ConstPredicateAdapter pred_it) {
|
||||
iterator_load_post_increment(iterator, fragment, 0, pred_it);
|
||||
}
|
||||
|
||||
template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &_iterator,
|
||||
Fragment &fragment,
|
||||
typename InputIterator::Index offset,
|
||||
ConstPredicateAdapter predicate_adapter) {
|
||||
InputIterator iterator(_iterator);
|
||||
iterator_load_post_increment(iterator, fragment, offset, predicate_adapter);
|
||||
}
|
||||
|
||||
/// Loads a fragment from an input iterator
|
||||
template <typename InputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &iterator,
|
||||
Fragment &fragment,
|
||||
typename InputIterator::Index offset = 0) {
|
||||
TrivialPredicateTileAdapter pred;
|
||||
iterator_load(iterator, fragment, offset, pred);
|
||||
}
|
||||
|
||||
/// Loads a fragment from an input iterator
|
||||
template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &iterator,
|
||||
Fragment &fragment,
|
||||
ConstPredicateAdapter pred_it) {
|
||||
iterator_load(iterator, fragment, 0, pred_it);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Stores a fragment to an output iterator
|
||||
template <typename OutputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment) {
|
||||
typename OutputIterator::FragmentIterator frag_iterator(fragment);
|
||||
for (int d = 0; d < OutputIterator::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < OutputIterator::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < OutputIterator::Iterations::kW; ++w) {
|
||||
if (iterator.valid(d, h, w, 0)) {
|
||||
iterator.set(reinterpret_cast<typename OutputIterator::AccessType const &>(
|
||||
frag_iterator.at(d, h, w, 0)),
|
||||
d,
|
||||
h,
|
||||
w,
|
||||
0);
|
||||
for (int c = 0; c < OutputIterator::Iterations::kC; ++c) {
|
||||
if (iterator.valid(d, h, w, c)) {
|
||||
iterator.store_element(reinterpret_cast<typename OutputIterator::AccessType &>(
|
||||
frag_iterator.at(d, h, w, c)),
|
||||
d,
|
||||
h,
|
||||
w,
|
||||
c);
|
||||
}
|
||||
}
|
||||
if (w < OutputIterator::Iterations::kW - 1) {
|
||||
iterator.inc_w();
|
||||
@ -215,104 +96,6 @@ CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &frag
|
||||
}
|
||||
iterator.inc_advance();
|
||||
}
|
||||
|
||||
/// Stores a fragment to a shared memory output iterator
|
||||
template <typename OutputIterator, typename Fragment>
|
||||
CUTLASS_DEVICE void shared_iterator_store(OutputIterator &iterator, Fragment const &fragment) {
|
||||
typename OutputIterator::FragmentConstIterator frag_iterator(fragment);
|
||||
for (int d = 0; d < OutputIterator::Iterations::kD; ++d) {
|
||||
for (int h = 0; h < OutputIterator::Iterations::kH; ++h) {
|
||||
for (int w = 0; w < OutputIterator::Iterations::kW; ++w) {
|
||||
for (int c = 0; c < OutputIterator::Iterations::kC; ++c) {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename OutputIterator::ImmediateOffsetStrides>::get(
|
||||
d, h, w, c);
|
||||
|
||||
FragmentStore<OutputIterator::kIteratorFragment,
|
||||
OutputIterator::Tile::kC,
|
||||
typename OutputIterator::Scalar,
|
||||
OutputIterator::kMemorySpace,
|
||||
typename OutputIterator::FragmentElement,
|
||||
OutputIterator::Tile::kW>::store(frag_iterator.at(d, h, w, c),
|
||||
iterator.data(),
|
||||
offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Stores a fragment to an output iterator, masked by a predicate iterator
|
||||
template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_store_post_increment(OutputIterator &iterator,
|
||||
Fragment const &fragment,
|
||||
typename OutputIterator::Index offset,
|
||||
ConstPredicateAdapter predicate_adapter) {
|
||||
for (int d = 0; d < OutputIterator::Iterations::kD; ++d, iterator.inc_d()) {
|
||||
for (int h = 0; h < OutputIterator::Iterations::kH; ++h, iterator.inc_h()) {
|
||||
for (int w = 0; w < OutputIterator::Iterations::kW; ++w, iterator.inc_w()) {
|
||||
if (predicate_adapter.at(d, h, w, 0)) {
|
||||
int idx = OutputIterator::Tile::kC *
|
||||
(w + OutputIterator::Iterations::kW * (h + OutputIterator::Iterations::kH * d));
|
||||
|
||||
Store<typename Fragment::Element,
|
||||
OutputIterator::Tile::kC,
|
||||
OutputIterator::kMemorySpace>::
|
||||
store(reinterpret_cast<typename OutputIterator::AccessType const &>(fragment[idx]),
|
||||
iterator.data(),
|
||||
offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores a fragment to an output iterator
|
||||
template <typename OutputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_store_post_increment(OutputIterator &iterator,
|
||||
Fragment const &fragment,
|
||||
typename OutputIterator::Index offset = 0) {
|
||||
TrivialPredicateTileAdapter pred;
|
||||
iterator_store_post_increment(iterator, fragment, offset, pred);
|
||||
}
|
||||
|
||||
/// Stores a fragment to an output iterator
|
||||
template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_store_post_increment(OutputIterator &iterator,
|
||||
Fragment const &fragment,
|
||||
ConstPredicateAdapter pred_it) {
|
||||
iterator_store_post_increment(iterator, fragment, 0, pred_it);
|
||||
}
|
||||
|
||||
/// Stores a fragment to an output iterator, masked by a predicate iterator
|
||||
template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &_iterator,
|
||||
Fragment const &fragment,
|
||||
typename OutputIterator::Index offset,
|
||||
ConstPredicateAdapter predicate_adapter) {
|
||||
OutputIterator iterator(_iterator);
|
||||
iterator_store_post_increment(iterator, fragment, offset, predicate_adapter);
|
||||
}
|
||||
|
||||
/// Stores a fragment to an output iterator
|
||||
template <typename OutputIterator, typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &iterator,
|
||||
Fragment const &fragment,
|
||||
typename OutputIterator::Index offset = 0) {
|
||||
TrivialPredicateTileAdapter pred;
|
||||
iterator_store(iterator, fragment, offset, pred);
|
||||
}
|
||||
|
||||
/// Stores a fragment to an output iterator
|
||||
template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
|
||||
CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &iterator,
|
||||
Fragment const &fragment,
|
||||
ConstPredicateAdapter pred_it) {
|
||||
iterator_store(iterator, fragment, 0, pred_it);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
67
cutlass/kernel_launch.h
Normal file
@ -0,0 +1,67 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines structures and helpers to launch CUDA kernels within CUTLASS.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure containing the basic launch configuration of a CUDA kernel.
|
||||
struct KernelLaunchConfiguration {
|
||||
|
||||
/// CUDA grid dimensions
|
||||
dim3 grid;
|
||||
|
||||
/// CUDA threablock dimensions
|
||||
dim3 block;
|
||||
|
||||
/// Bytes of dynamically allocated SMEM in addition to static SMEM
|
||||
size_t dynamic_smem;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructs a KernellaunchConfiguration object
|
||||
CUTLASS_HOST_DEVICE
|
||||
KernelLaunchConfiguration(
|
||||
dim3 _grid = dim3(1,1,1),
|
||||
dim3 _block = dim3(1,1,1),
|
||||
size_t _dynamic_smem = 0
|
||||
):
|
||||
grid(_grid),
|
||||
block(_block),
|
||||
dynamic_smem(_dynamic_smem) { }
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
@ -27,8 +27,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/vector.h>
|
||||
|
||||
#include "cutlass/vector.h"
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -44,45 +43,68 @@ struct MemorySpace {
|
||||
};
|
||||
};
|
||||
|
||||
/// Specifies whether iterator storage fragment consists of Scalar values or WMMA matrix
|
||||
struct FragmentElementType {
|
||||
enum Kind { kScalar, kWmmaMatrix };
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_,
|
||||
int Lanes_,
|
||||
int kAccessSize,
|
||||
MemorySpace::Kind Memory_,
|
||||
bool = (Lanes_ > 1),
|
||||
size_t = (sizeof(Scalar_) * Lanes_)>
|
||||
FragmentElementType::Kind kFragmentElementType = FragmentElementType::kScalar,
|
||||
typename FragmentElement_ = Scalar_,
|
||||
int kStride = 1,
|
||||
size_t size = (sizeof(Scalar_) * kAccessSize)>
|
||||
struct Load {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
|
||||
dst = reinterpret_cast<AccessType const*>(&pointer[offset])[0];
|
||||
static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
|
||||
dst = *reinterpret_cast<AccessType const*>(pointer + offset);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for 16b loads
|
||||
template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_>
|
||||
struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, 1, 2> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
|
||||
reinterpret_cast<uint16_t&>(dst) = reinterpret_cast<uint16_t const*>(&pointer[offset])[0];
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Load<Scalar_, Lanes_, Memory_, true, 4> {
|
||||
template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
|
||||
struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 4> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
|
||||
/// The load function.
|
||||
static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
|
||||
dst.registers[0] = reinterpret_cast<uint32_t const*>(&pointer[offset])[0];
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Load<Scalar_, Lanes_, Memory_, true, 8> {
|
||||
template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
|
||||
struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 8> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
|
||||
/// The load function.
|
||||
static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
|
||||
uint2 tmp = reinterpret_cast<uint2 const*>(&pointer[offset])[0];
|
||||
dst.registers[0] = tmp.x;
|
||||
dst.registers[1] = tmp.y;
|
||||
@ -91,13 +113,13 @@ struct Load<Scalar_, Lanes_, Memory_, true, 8> {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <MemorySpace::Kind Memory_>
|
||||
struct Load<double, 2, Memory_, true, 16> {
|
||||
template <MemorySpace::Kind Memory_, int kStride>
|
||||
struct Load<double, 2, Memory_, FragmentElementType::kScalar, double, kStride, 16> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<double, 2>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, double const* pointer, int offset) {
|
||||
/// The load function.
|
||||
static CUTLASS_HOST_DEVICE void load(AccessType& dst, double const* pointer, int offset) {
|
||||
double2 tmp = reinterpret_cast<double2 const*>(&pointer[offset])[0];
|
||||
dst[0] = tmp.x;
|
||||
dst[1] = tmp.y;
|
||||
@ -108,13 +130,13 @@ struct Load<double, 2, Memory_, true, 16> {
|
||||
|
||||
#if defined(__CUDACC_VERSION_MAJOR) && __CUDACC_VERSION_MAJOR < 10
|
||||
// WAR bug in NVCC where the upper and lower half of the register end up being the same
|
||||
template <MemorySpace::Kind Memory_>
|
||||
struct Load<half, 8, Memory_, true, 16> {
|
||||
template <MemorySpace::Kind Memory_, int kStride>
|
||||
struct Load<half, 8, Memory_, FragmentElementType::kScalar, half, kStride, 16> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<half, 8>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, half const* pointer, int offset) {
|
||||
/// The load function.
|
||||
static CUTLASS_HOST_DEVICE void load(AccessType& dst, half const* pointer, int offset) {
|
||||
int2 tmp = reinterpret_cast<int2 const*>(&pointer[offset])[0];
|
||||
dst.registers[0] = tmp.x;
|
||||
dst.registers[1] = tmp.y;
|
||||
@ -129,13 +151,13 @@ struct Load<half, 8, Memory_, true, 16> {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Load<Scalar_, Lanes_, Memory_, true, 16> {
|
||||
template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
|
||||
struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 16> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
|
||||
/// The load function.
|
||||
static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
|
||||
uint4 tmp = reinterpret_cast<uint4 const*>(&pointer[offset])[0];
|
||||
dst.registers[0] = tmp.x;
|
||||
dst.registers[1] = tmp.y;
|
||||
@ -147,29 +169,45 @@ struct Load<Scalar_, Lanes_, Memory_, true, 16> {
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_,
|
||||
int Lanes_,
|
||||
int kAccessSize,
|
||||
MemorySpace::Kind Memory_,
|
||||
bool = (Lanes_ > 1),
|
||||
size_t = (sizeof(Scalar_) * Lanes_)>
|
||||
FragmentElementType::Kind kFragmentElementType = FragmentElementType::kScalar,
|
||||
typename FragmentElement_ = Scalar_,
|
||||
int kStride = 1,
|
||||
size_t size = (sizeof(Scalar_) * kAccessSize)>
|
||||
struct Store {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
typedef typename Vectorize<FragmentElement_, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
|
||||
pointer[offset] = src;
|
||||
static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
|
||||
pointer[offset] = *reinterpret_cast<Scalar_ const*>(&src);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Store<Scalar_, Lanes_, Memory_, true, 4> {
|
||||
template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_>
|
||||
struct Store<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, 1, 2> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
|
||||
static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
|
||||
uint16_t* addr = reinterpret_cast<uint16_t*>(&pointer[offset]);
|
||||
addr[0] = reinterpret_cast<uint16_t const&>(src);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
|
||||
struct Store<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 4> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
|
||||
uint32_t* addr = reinterpret_cast<uint32_t*>(&pointer[offset]);
|
||||
addr[0] = src.registers[0];
|
||||
}
|
||||
@ -177,13 +215,13 @@ struct Store<Scalar_, Lanes_, Memory_, true, 4> {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Store<Scalar_, Lanes_, Memory_, true, 8> {
|
||||
template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
|
||||
struct Store<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 8> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
|
||||
static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
|
||||
uint2* addr = reinterpret_cast<uint2*>(&pointer[offset]);
|
||||
addr[0] = make_uint2(src.registers[0], src.registers[1]);
|
||||
}
|
||||
@ -191,13 +229,13 @@ struct Store<Scalar_, Lanes_, Memory_, true, 8> {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <MemorySpace::Kind Memory_>
|
||||
struct Store<double, 2, Memory_, true, 16> {
|
||||
template <MemorySpace::Kind Memory_, int kStride>
|
||||
struct Store<double, 2, Memory_, FragmentElementType::kScalar, double, kStride, 16> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<double, 2>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& src, double* pointer, int offset) {
|
||||
static CUTLASS_HOST_DEVICE void store(AccessType const& src, double* pointer, int offset) {
|
||||
double2* addr = reinterpret_cast<double2*>(&pointer[offset]);
|
||||
addr[0] = make_double2(src[0], src[1]);
|
||||
}
|
||||
@ -205,13 +243,13 @@ struct Store<double, 2, Memory_, true, 16> {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
|
||||
struct Store<Scalar_, Lanes_, Memory_, true, 16> {
|
||||
template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
|
||||
struct Store<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 16> {
|
||||
/// The output type.
|
||||
typedef typename Vectorize<Scalar_, Lanes_>::Type AccessType;
|
||||
typedef typename Vectorize<Scalar_, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
|
||||
static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
|
||||
uint4* addr = reinterpret_cast<uint4*>(&pointer[offset]);
|
||||
addr[0] = make_uint4(src.registers[0], src.registers[1], src.registers[2], src.registers[3]);
|
||||
}
|
||||
@ -219,4 +257,123 @@ struct Store<Scalar_, Lanes_, Memory_, true, 16> {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_,
|
||||
int kAccessSize,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride,
|
||||
size_t size>
|
||||
struct Load<Scalar_,
|
||||
kAccessSize,
|
||||
Memory_,
|
||||
FragmentElementType::kWmmaMatrix,
|
||||
FragmentElement_,
|
||||
kStride,
|
||||
size> {
|
||||
/// The output type.
|
||||
typedef FragmentElement_ AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_HOST_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) {
|
||||
value.load(&pointer[offset], kStride);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int kAccessSize,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride,
|
||||
size_t size>
|
||||
struct Load<Vector<bin1_t, 32>,
|
||||
kAccessSize,
|
||||
Memory_,
|
||||
FragmentElementType::kWmmaMatrix,
|
||||
FragmentElement_,
|
||||
kStride,
|
||||
size> {
|
||||
/// The output type.
|
||||
typedef FragmentElement_ AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_HOST_DEVICE void load(AccessType& value, Vector<bin1_t, 32> const* pointer,
|
||||
int offset) {
|
||||
value.load(&pointer[offset], kStride * 32);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int kAccessSize,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride,
|
||||
size_t size>
|
||||
struct Load<Vector<int4_t, 8>,
|
||||
kAccessSize,
|
||||
Memory_,
|
||||
FragmentElementType::kWmmaMatrix,
|
||||
FragmentElement_,
|
||||
kStride,
|
||||
size> {
|
||||
/// The output type.
|
||||
typedef FragmentElement_ AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_HOST_DEVICE void load(AccessType& value, Vector<int4_t, 8> const* pointer,
|
||||
int offset) {
|
||||
value.load(&pointer[offset], kStride * 8);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int kAccessSize,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride,
|
||||
size_t size>
|
||||
struct Load<Vector<uint4_t, 8>,
|
||||
kAccessSize,
|
||||
Memory_,
|
||||
FragmentElementType::kWmmaMatrix,
|
||||
FragmentElement_,
|
||||
kStride,
|
||||
size> {
|
||||
/// The output type.
|
||||
typedef FragmentElement_ AccessType;
|
||||
|
||||
/// The load function.
|
||||
static CUTLASS_HOST_DEVICE void load(AccessType& value, Vector<uint4_t, 8> const* pointer,
|
||||
int offset) {
|
||||
value.load(&pointer[offset], kStride * 8);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename Scalar_,
|
||||
int kAccessSize,
|
||||
MemorySpace::Kind Memory_,
|
||||
typename FragmentElement_,
|
||||
int kStride,
|
||||
size_t size>
|
||||
struct Store<Scalar_,
|
||||
kAccessSize,
|
||||
Memory_,
|
||||
FragmentElementType::kWmmaMatrix,
|
||||
FragmentElement_,
|
||||
kStride,
|
||||
size> {
|
||||
/// The input type.
|
||||
typedef FragmentElement_ AccessType;
|
||||
|
||||
/// The store function.
|
||||
static CUTLASS_HOST_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) {
|
||||
value.store(&pointer[offset], kStride);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
@ -27,13 +27,327 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Describes layouts of matrices
|
||||
/// MatrixCoord wraps Coord<2, int> to provide a helper for accessing named dimensions. Classes
|
||||
/// expecting a coordinate in the rank=2 index space of a matrix should use MatrixCoord.
|
||||
struct MatrixCoord : public Coord<2, int> {
|
||||
|
||||
/// Integer-valued index
|
||||
typedef int Index;
|
||||
|
||||
/// Base type is a Coord of rank=2
|
||||
typedef Coord<2, Index> Base;
|
||||
|
||||
/// Rows dimension
|
||||
static int const kRow = 0;
|
||||
|
||||
/// Columns dimension
|
||||
static int const kColumn = 1;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
MatrixCoord() { }
|
||||
|
||||
/// Constructs from Coord<2>
|
||||
CUTLASS_HOST_DEVICE
|
||||
MatrixCoord(Coord<2, Index> const &coord): Base(coord) { }
|
||||
|
||||
/// Helper to construct from a row and column
|
||||
CUTLASS_HOST_DEVICE
|
||||
MatrixCoord(Index row, Index column): Base(make_Coord(row, column)) { }
|
||||
|
||||
/// Returns the row of the coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index const & row() const { return this->at(kRow); }
|
||||
|
||||
/// Returns the row of the coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index & row() { return this->at(kRow); }
|
||||
|
||||
/// Returns the column of the coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index const & column() const { return this->at(kColumn); }
|
||||
|
||||
/// Returns the column of the coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index & column() { return this->at(kColumn); }
|
||||
|
||||
//
|
||||
// Coord operators
|
||||
//
|
||||
|
||||
/// Element-wise addition
|
||||
CUTLASS_HOST_DEVICE
|
||||
MatrixCoord operator+(Base const& b) const {
|
||||
return MatrixCoord(Base::operator+(b));
|
||||
}
|
||||
|
||||
/// Element-wise subtraction
|
||||
CUTLASS_HOST_DEVICE
|
||||
MatrixCoord operator-(Base const& b) const {
|
||||
return MatrixCoord(Base::operator-(b));
|
||||
}
|
||||
|
||||
/// Element-wise multiplication
|
||||
CUTLASS_HOST_DEVICE
|
||||
MatrixCoord operator*(Base const& b) const {
|
||||
return MatrixCoord(Base::operator*(b));
|
||||
}
|
||||
|
||||
/// Element-wise division
|
||||
CUTLASS_HOST_DEVICE
|
||||
MatrixCoord operator/(Base const& b) const {
|
||||
return MatrixCoord(Base::operator/(b));
|
||||
}
|
||||
|
||||
/// In-place addition
|
||||
CUTLASS_HOST_DEVICE
|
||||
MatrixCoord& operator+=(Base const& b) {
|
||||
Base::operator+=(b);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place subtraction
|
||||
CUTLASS_HOST_DEVICE
|
||||
MatrixCoord& operator-=(Base const& b) {
|
||||
Base::operator-=(b);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place multiplication
|
||||
CUTLASS_HOST_DEVICE
|
||||
MatrixCoord& operator*=(Base const& b) {
|
||||
Base::operator*=(b);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place division
|
||||
CUTLASS_HOST_DEVICE
|
||||
MatrixCoord& operator/=(Base const& b) {
|
||||
Base::operator/=(b);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines data layouts of various matrix formats usable by TensorRef and other classes.
|
||||
//
|
||||
// The following define classes satisfying the TensorRefMapFunc concept. These must support the
|
||||
// following operations, where func is an instance of type TensorRefMapFunc.
|
||||
//
|
||||
// Coord<TensorRefMapFunc::kStorageRank> = func(Coord<kRank>);
|
||||
//
|
||||
// Though not required to be usable by TensorRef, each of the following also define a helper
|
||||
// function to map the "leading dimension" to an appropriate stride vector. Implementations
|
||||
// following this convention should also implement the following static method:
|
||||
//
|
||||
// Coord<TensorRefMapFunc::kStorageRank> stride = TensorRefMapFunc::stride(leading_dim);
|
||||
//
|
||||
struct MatrixLayout {
|
||||
|
||||
/// Enumeration defining fundamental contiguous layouts.
|
||||
enum Kind { kRowMajor, kColumnMajor };
|
||||
|
||||
//
|
||||
// TensorRefMapFunc definitions for common layouts
|
||||
//
|
||||
|
||||
/// Mapping function for row-major matrices
|
||||
struct RowMajor {
|
||||
static int const kStorageRank = 2;
|
||||
/// Maps (i, j) to (i, j)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<kStorageRank> operator()(MatrixCoord const &coord) const {
|
||||
return coord;
|
||||
}
|
||||
};
|
||||
|
||||
/// Mapping function for column-major matrices
|
||||
struct ColumnMajor {
|
||||
static int const kStorageRank = 2;
|
||||
/// Maps (i, j) to (j, i)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<kStorageRank> operator()(MatrixCoord const &coord) const {
|
||||
return make_Coord(coord.column(), coord.row());
|
||||
}
|
||||
};
|
||||
|
||||
/// Mapping function for interleaved matrices. Matrix is structured
|
||||
/// as row-major arrangement of fixed-size columns.
|
||||
template <int Interleave>
|
||||
struct RowMajorInterleaved {
|
||||
|
||||
/// Rank of storage n-D array
|
||||
static int const kStorageRank = 3;
|
||||
|
||||
/// Interleaving size
|
||||
static int const kInterleave = Interleave;
|
||||
|
||||
/// Maps (row, col) to (row, col, row)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<kStorageRank> operator()(MatrixCoord const &coord) const {
|
||||
return make_Coord(
|
||||
coord.row() / kInterleave,
|
||||
coord.column(),
|
||||
coord.row() % kInterleave
|
||||
);
|
||||
}
|
||||
|
||||
/// Helper to compute stride vector from leading dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<kStorageRank> stride(int ldm) {
|
||||
return make_Coord(
|
||||
ldm * kInterleave,
|
||||
kInterleave,
|
||||
1
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
/// Mapping function for interleaved matrices. Matrix is structured
|
||||
/// as column-major arrangement of fixed-size rows.
|
||||
template <int Interleave>
|
||||
struct ColumnMajorInterleaved {
|
||||
|
||||
/// Rank of storage n-D array
|
||||
static int const kStorageRank = 3;
|
||||
|
||||
/// Interleaving size
|
||||
static int const kInterleave = Interleave;
|
||||
|
||||
/// Maps (row, col) to (col, row, col)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<kStorageRank> operator()(MatrixCoord const &coord) const {
|
||||
return make_Coord(
|
||||
coord.column() / kInterleave,
|
||||
coord.row(),
|
||||
coord.column() % kInterleave
|
||||
);
|
||||
}
|
||||
|
||||
/// Helper to compute stride vector from leading dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<kStorageRank> stride(int ldm) {
|
||||
return make_Coord(
|
||||
ldm * kInterleave,
|
||||
kInterleave,
|
||||
1
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
/// Mapping function for scenario in which layout is row-major or column-major but this information
|
||||
/// is only available at runtime.
|
||||
struct ContiguousLayout {
|
||||
/// Arbitrary storage rank
|
||||
static int const kStorageRank = 3;
|
||||
|
||||
/// Dimension of rows
|
||||
static int const kRow = 0;
|
||||
|
||||
/// Dimension of columns
|
||||
static int const kColumn = 1;
|
||||
|
||||
/// Mapping function defined by runtime variable. Returns coordinates in n-D storage array
|
||||
/// as (matrix row, matrix colum, 0)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<kStorageRank> operator()(MatrixCoord const &coord) const {
|
||||
return make_Coord(coord.row(), coord.column(), 0);
|
||||
}
|
||||
|
||||
/// Helper to construct a stride vector based on contiguous matrix layout and leading dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<kStorageRank> stride(MatrixLayout::Kind layout, int ldm) {
|
||||
if (layout == MatrixLayout::kRowMajor) {
|
||||
return make_Coord(ldm, 1, 1);
|
||||
}
|
||||
return make_Coord(1, ldm, 1);
|
||||
}
|
||||
};
|
||||
|
||||
/// Mapping function for block-linear matrices. Matrix is structured
|
||||
/// as column-major arrangement of 2D tiles (that are column-major).
|
||||
template <int BlockRows, int BlockColumns>
|
||||
struct ColumnMajorBlockLinear {
|
||||
|
||||
/// Rank of storage n-D array
|
||||
static int const kStorageRank = 4;
|
||||
|
||||
/// Interleaving size in rows dimension
|
||||
static int const kBlockRows = BlockRows;
|
||||
|
||||
/// Interleaving size in columns dimension
|
||||
static int const kBlockColumns = BlockColumns;
|
||||
|
||||
/// Maps (row, col) to (col, row, col, row)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<kStorageRank> operator()(MatrixCoord const &coord) const {
|
||||
return make_Coord(
|
||||
coord.column() / kBlockColumns,
|
||||
coord.row() / kBlockRows,
|
||||
coord.column() % kBlockColumns,
|
||||
coord.row() % kBlockRows
|
||||
);
|
||||
}
|
||||
|
||||
/// Helper to compute stride vector from leading dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<kStorageRank> stride(int ldm) {
|
||||
return make_Coord(
|
||||
ldm * kBlockRows * kBlockColumns,
|
||||
kBlockRows * kBlockColumns,
|
||||
kBlockRows,
|
||||
1
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
/// Mapping function for block-linear matrices. Matrix is structured
|
||||
/// as row-major arrangement of 2D tiles (that are row-major)
|
||||
template <int BlockRows, int BlockColumns>
|
||||
struct RowMajorBlockLinear {
|
||||
|
||||
/// Rank of storage n-D array
|
||||
static int const kStorageRank = 4;
|
||||
|
||||
/// Interleaving size in rows dimension
|
||||
static int const kBlockRows = BlockRows;
|
||||
|
||||
/// Interleaving size in columns dimension
|
||||
static int const kBlockColumns = BlockColumns;
|
||||
|
||||
/// Maps (row, col) to (row, col, row, col)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<kStorageRank> operator()(MatrixCoord const &coord) const {
|
||||
return make_Coord(
|
||||
coord.row() / kBlockRows,
|
||||
coord.column() / kBlockColumns,
|
||||
coord.row() % kBlockRows,
|
||||
coord.column() % kBlockColumns
|
||||
);
|
||||
}
|
||||
|
||||
/// Helper to compute stride vector from leading dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<kStorageRank> stride(int ldm) {
|
||||
return make_Coord(
|
||||
ldm * kBlockRows * kBlockColumns,
|
||||
kBlockRows * kBlockColumns,
|
||||
kBlockColumns,
|
||||
1
|
||||
);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -45,4 +359,14 @@ struct GemmOperand {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Transformation applied to matrix operands
|
||||
struct MatrixTransform {
|
||||
enum Kind {
|
||||
kNone, /// no operation
|
||||
kConjugate, /// conjugate
|
||||
};
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
@ -28,12 +28,13 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/shape.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/shape.h"
|
||||
|
||||
#include <cutlass/util/platform.h>
|
||||
#include "cutlass/util/platform.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@ -114,7 +115,7 @@ struct PredicateVector {
|
||||
// Make sure no one tries to put more than 8 bits in a byte :)
|
||||
static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte");
|
||||
// Make sure the "offsetted" bits fit in one byte.
|
||||
static_assert(kPredicateStart + kPredicatesPerByte < 8,
|
||||
static_assert(kPredicateStart + kPredicatesPerByte <= 8,
|
||||
"The offsetted predicates must fit within an actual byte.");
|
||||
|
||||
/// Storage type of individual elements
|
||||
|
||||
@ -27,7 +27,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/shape.h>
|
||||
#include "cutlass/shape.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@ -128,6 +128,17 @@ struct ShapeDiv {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename A_, typename B_>
|
||||
struct ShapeDivCeiling {
|
||||
typedef Shape<(A_::kD + B_::kD - 1) / B_::kD,
|
||||
(A_::kH + B_::kH - 1) / B_::kH,
|
||||
(A_::kW + B_::kW - 1) / B_::kW,
|
||||
(A_::kC + B_::kC - 1) / B_::kC>
|
||||
Shape;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename A_, typename B_>
|
||||
struct ShapeMax {
|
||||
typedef Shape<(A_::kD > B_::kD ? A_::kD : B_::kD),
|
||||
@ -150,12 +161,12 @@ struct ShapeMin {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Shape_, int kElementsPerAccess>
|
||||
template <typename Shape_, int elementsPerAccess>
|
||||
struct ShapeStrides {
|
||||
typedef Shape<Shape_::kH * Shape_::kW * Shape_::kC,
|
||||
Shape_::kW * Shape_::kC,
|
||||
Shape_::kC,
|
||||
kElementsPerAccess>
|
||||
elementsPerAccess>
|
||||
Shape;
|
||||
};
|
||||
|
||||
@ -167,7 +178,7 @@ struct ShapeStrides {
|
||||
*/
|
||||
template <typename Shape_>
|
||||
struct ComputeOffsetFromShape {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
|
||||
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c) {
|
||||
// clang-format off
|
||||
return d * Shape_::kH * Shape_::kW * Shape_::kC +
|
||||
h * Shape_::kW * Shape_::kC +
|
||||
@ -179,73 +190,19 @@ struct ComputeOffsetFromShape {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Compute the offset for the given coordinates in a cube with a depth of 1
|
||||
* @tparam kSh Elements in the H dimension
|
||||
* @tparam kSw Elements in the W dimension
|
||||
* @tparam kSc Separation between two elements in "elements"
|
||||
*/
|
||||
template <int kSh_, int kSw_, int kSc_>
|
||||
struct ComputeOffsetFromShape<Shape<1, kSh_, kSw_, kSc_> > {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
|
||||
return h * kSw_ * kSc_ + w * kSc_ + c;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Compute the offset for the given coordinates in a cube with one channel and a depth of 1
|
||||
* @tparam kSh Elements in the H dimension
|
||||
* @tparam kSw Elements in the W dimension
|
||||
*/
|
||||
template <int kSh_, int kSw_>
|
||||
struct ComputeOffsetFromShape<Shape<1, kSh_, kSw_, 1> > {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) { return h * kSw_ + w; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Compute the offset for the given coordinates in a cube
|
||||
* @tparam A \ref layout_concept where each dimension of the cube specifies the corresponding stride.
|
||||
*/
|
||||
template <typename Strides_>
|
||||
struct ComputeOffsetFromStrides {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
|
||||
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c) {
|
||||
return d * Strides_::kD + h * Strides_::kH + w * Strides_::kW + c * Strides_::kC;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Compute the offset for the given coordinates in a cube with a depth of 1
|
||||
* @tparam S_h Stride in the H dimension in scalars
|
||||
* @tparam S_w Stride in the W dimension in scalars
|
||||
* @tparam S_c Stride between two scalars.
|
||||
*/
|
||||
template <int S_h_, int S_w_, int S_c_>
|
||||
struct ComputeOffsetFromStrides<Shape<1, S_h_, S_w_, S_c_> > {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) {
|
||||
return h * S_h_ + w * S_w_ + c * S_c_;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Compute the offset for the given coordinates in a cube with one channel and a depth of 1
|
||||
* @tparam S_h Stride in the H dimension in scalars
|
||||
* @tparam S_w Stride in the W dimension in scalars
|
||||
*/
|
||||
template <int S_h_, int S_w_>
|
||||
struct ComputeOffsetFromStrides<Shape<1, S_h_, S_w_, 1> > {
|
||||
static CUTLASS_DEVICE int get(int d, int h, int w, int c) { return h * S_h_ + w * S_w_; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief Decompose threadId.x into coordinate of a cube whose dimensions are specified by Threads_.
|
||||
* Afterwards compute the offset of those coordinates using Strides_
|
||||
|
||||
@ -27,125 +27,613 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <typeinfo>
|
||||
|
||||
#include <cutlass/coord.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/vector.h>
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/vector.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure modeling a pointer and stride into a tensor
|
||||
template <typename Storage_, int Rank_>
|
||||
/// Default mapping function from coordinates in a tensor's index space into the n-D array held
|
||||
/// in memory. Assumes StorageRank = Rank
|
||||
template <int Rank>
|
||||
struct IdentityTensorMapFunc {
|
||||
static int const kStorageRank = Rank;
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<Rank> operator()(Coord<Rank> const &coord) const {
|
||||
return coord;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/* \brief Structure modeling a pointer and stride into a tensor.
|
||||
|
||||
A tensor consists of an index space with Rank_ dimensions. It is stored in memory modeled
|
||||
as an n-D array, where n = StorageRank_. A mapping function maps the logical coordinates of the
|
||||
tensor's index space into the n-D array, and a stride vector maps the n-D array to linear memory.
|
||||
|
||||
CUTLASS requires the n-D array's least significant, "fastest changing" dimension to
|
||||
be contiguous in memory. It therefore has a stride of 1 and is not stored. Construction is offered
|
||||
from vectors of full StorageRank and of the 'compact' rank, though it is in error to construct
|
||||
with the least significant stride != 1.
|
||||
|
||||
The requirement that the least significant dimension be consecutive enables numerous optimizations
|
||||
and assumptions about vectorizing memory accesses throughout CUTLASS. It also matches various
|
||||
BLAS conventions in which only the "leading dimension" or most significant stride of a rank=2
|
||||
matrix is provided.
|
||||
|
||||
This does affect the ability of constructing arbitrary "sparse" 2-D matrices in memory where all
|
||||
stride elements are > 1. This can be overcome by defining a custom mapping function and a
|
||||
StorageRank of 3 or more.
|
||||
|
||||
|
||||
Examples:
|
||||
|
||||
(These examples use helpers for matrix layouts defined in cutlass/matrix_traits.h)
|
||||
|
||||
1. Column-major matrix may be represented as a rank=2 tensor:
|
||||
|
||||
TensorRef<float, 2, MatrixLayout::ColumnMajor> A(ptr_A, make_Coord(ldm, 1));
|
||||
|
||||
2. Row-major matrix may be represented as a rank=2 tensor:
|
||||
|
||||
TensorRef<float, 2, MatrixLayout::RowMajor> B(ptr_A, ldm);
|
||||
|
||||
3. An interleaved matrix may be represented as a rank=2 tensor:
|
||||
|
||||
TensorRef<int8_t, 2, MatrixLayout::ColumnMajorInterleaved<32> > C;
|
||||
|
||||
4. Defining a sparse matrix with arbitrary strides in each dimension
|
||||
|
||||
struct ContiguousLayout {
|
||||
|
||||
/// Arbitrary storage rank
|
||||
static int const kStorageRank = 3;
|
||||
|
||||
/// Mapping function defined by runtime stride configuration
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<3> operator()(MatrixCoord const &coord) const {
|
||||
return make_Coord(coord.row(), coord.column(), 0);
|
||||
}
|
||||
};
|
||||
|
||||
typedef TensorRef<float, 2, ContiguousLayout> ContiguousTensorRef;
|
||||
|
||||
// Construct the TensorRef object from a pair of stride values
|
||||
ContiguousTensorRef D(ptr_D, make_Coord(row_stride, column_stride));
|
||||
|
||||
|
||||
5. A helper exists to define a TensorRef for a contiguous matrix whose layout
|
||||
is not known at compile time.
|
||||
|
||||
MatrixLayout::Kind layout; // Could be MatrixLayout::kRowMajor or MatrixLayout::kColumnMajor
|
||||
int ldm; // leading dimension
|
||||
|
||||
ContiguousTensorRef E(ptr_E, ContiguousLayout::stride(layout, ldm));
|
||||
|
||||
*/
|
||||
template <
|
||||
/// Data type of element stored within tensor
|
||||
typename Storage_,
|
||||
/// Rank of logical tensor
|
||||
int Rank_,
|
||||
/// Maps a Coord<Rank_> in the logical tensor index space to the internal n-D array
|
||||
typename MapFunc_ = IdentityTensorMapFunc<Rank_>,
|
||||
/// Rank of internal n-D array
|
||||
int StorageRank_ = MapFunc_::kStorageRank,
|
||||
/// Index type used for coordinates
|
||||
typename Index_ = int,
|
||||
/// Index type used for offsets and pointer differences
|
||||
typename LongIndex_ = long long
|
||||
>
|
||||
class TensorRef {
|
||||
public:
|
||||
/// Data type of individual access
|
||||
typedef Storage_ Storage;
|
||||
|
||||
/// Rank of tensor
|
||||
static int const Rank = Rank_;
|
||||
/// Logical rank of tensor index space
|
||||
static int const kRank = Rank_;
|
||||
|
||||
/// Mapping function from logical coordinate to internal n-D array
|
||||
typedef MapFunc_ MapFunc;
|
||||
|
||||
/// Rank of internal storage
|
||||
static int const kStorageRank = StorageRank_;
|
||||
|
||||
/// Index type
|
||||
typedef Index_ Index;
|
||||
|
||||
/// Typically, strides in memory can be very large
|
||||
typedef LongIndex_ LongIndex;
|
||||
|
||||
/// Coordinate in logical tensor space
|
||||
typedef Coord<kRank> TensorCoord;
|
||||
|
||||
/// Coordinate in storage n-D array
|
||||
typedef Coord<kStorageRank> StorageCoord;
|
||||
|
||||
/// Stride vector in storage coordinage space - assumes least significant stride
|
||||
/// is 1 and does not store it.
|
||||
typedef Coord<kStorageRank - 1> StrideVector;
|
||||
|
||||
/// Tensor reference to of constant value
|
||||
typedef TensorRef<
|
||||
typename platform::remove_const<Storage>::type const,
|
||||
Rank_,
|
||||
MapFunc_,
|
||||
StorageRank_,
|
||||
Index_,
|
||||
LongIndex_> ConstTensorRef;
|
||||
|
||||
/// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a
|
||||
/// scalar, but degenerate cases such as these are difficult to accommodate without
|
||||
/// extensive C++ metaprogramming or support for zero-length arrays.
|
||||
static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");
|
||||
|
||||
//
|
||||
// Definitions included for backwards compatibility - to be removed in next major release
|
||||
//
|
||||
|
||||
/// Coordinate in logical tensor space
|
||||
typedef TensorCoord Coord_t;
|
||||
|
||||
/// Logical rank of tensor index space
|
||||
static int const Rank = kRank;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Pointer to storage element
|
||||
/// Pointer
|
||||
Storage* ptr_;
|
||||
|
||||
/// Stride information
|
||||
Coord<Rank> stride_;
|
||||
/// Stride vector - fastest-changing stride assumed to be 1 and not stored
|
||||
StrideVector stride_;
|
||||
|
||||
/// Maps a logical coordinate to an n-D array's tensor space
|
||||
MapFunc coord_map_;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
/// Helper for 1-D memory. All higher ranks are projected onto the fastest changing rank.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef() : ptr_(nullptr) {}
|
||||
TensorRef(Storage *ptr = nullptr): ptr_(ptr) {
|
||||
for (int i = 0; i < kStorageRank - 1; ++i) {
|
||||
stride_[i] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Constructs from a pointer, size, and stride
|
||||
/// Helper to construct from a pointer and single stride element for 2-D pitch linear memory.
|
||||
// Higher ranks are projected onto the fastest-changing rank.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef(Storage* ptr, Coord<Rank> stride) : ptr_(ptr), stride_(stride) {}
|
||||
TensorRef(Storage* ptr, Index ldm) {
|
||||
ptr_ = ptr;
|
||||
for (int i = 0; i < kStorageRank - 1; ++i) {
|
||||
stride_[i] = ldm;
|
||||
}
|
||||
}
|
||||
|
||||
/// Constructs from a single pointer and stride vector
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef(Storage* ptr, StrideVector const& stride) : ptr_(ptr), stride_(stride) {
|
||||
|
||||
}
|
||||
|
||||
/// Constructs from a pointer and a stride vector of size kRank. If fastest changing
|
||||
/// stride is not 1, construction fails and subsequent calls to good() will return false.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef(Storage* ptr, StorageCoord const& stride) {
|
||||
// Fastest-changing stride must be one
|
||||
if (stride.at(kStorageRank - 1) == 1) {
|
||||
ptr_ = ptr;
|
||||
for (int i = 0; i < kStorageRank - 1; ++i) {
|
||||
stride_[i] = stride[i];
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Fastest-chaning stride must be 1.
|
||||
reset();
|
||||
}
|
||||
}
|
||||
|
||||
/// Enables conversion from TensorRef of non-const type
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef(
|
||||
TensorRef<
|
||||
typename platform::remove_const<Storage>::type,
|
||||
kRank,
|
||||
MapFunc,
|
||||
kStorageRank,
|
||||
Index,
|
||||
LongIndex> const &ref
|
||||
):
|
||||
ptr_(ref.data()) {
|
||||
for (int i = 0; i < kStorageRank - 1; ++i) {
|
||||
stride_[i] = ref.stride(i);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a reference to constant-valued tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstTensorRef const_ref() const {
|
||||
return ConstTensorRef(*this);
|
||||
}
|
||||
|
||||
/// Updates only the pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
void reset(Storage* ptr = nullptr) {
|
||||
ptr_ = ptr;
|
||||
}
|
||||
|
||||
/// Updates the pointer, stride, and location within a TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
void reset(Storage* ptr = nullptr, Coord<Rank> stride = Coord<Rank>(0)) {
|
||||
ptr_ = ptr;
|
||||
stride_ = stride;
|
||||
}
|
||||
|
||||
/// Conversion function
|
||||
template <typename T>
|
||||
TensorRef<T, Rank> convert() {
|
||||
Coord<Rank> converted_stride;
|
||||
for (int i = 0; i < Rank - 1; ++i) {
|
||||
converted_stride[i] = stride_[i] * Extent<Storage>::kValue / Extent<T>::kValue;
|
||||
void reset(Storage* ptr, StorageCoord const & stride) {
|
||||
// Fastest-changing stride must be one
|
||||
if (stride.at(kStorageRank - 1) == 1) {
|
||||
ptr_ = ptr;
|
||||
for (int i = 0; i < kStorageRank - 1; ++i) {
|
||||
stride_[i] = stride[i];
|
||||
}
|
||||
}
|
||||
else {
|
||||
// Fastest-changing stride must be 1 - this is an error.
|
||||
reset();
|
||||
}
|
||||
converted_stride[Rank - 1] = stride_[Rank - 1];
|
||||
|
||||
return TensorRef<T, Rank>(reinterpret_cast<T*>(ptr_), converted_stride);
|
||||
}
|
||||
|
||||
/// Returns true if the TensorRef may be safely accessed
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool good() const { return ptr_ != nullptr; }
|
||||
bool good() const {
|
||||
return ptr_ != nullptr;
|
||||
}
|
||||
|
||||
/// Returns the pointer to referenced data
|
||||
CUTLASS_HOST_DEVICE
|
||||
Storage* data() const { return ptr_; }
|
||||
Storage * data() const { return ptr_; }
|
||||
|
||||
/// Returns the stride of the tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<Rank> const& stride() const { return stride_; }
|
||||
StorageCoord stride() const {
|
||||
StorageCoord ld;
|
||||
for (int i = 0; i < kStorageRank - 1; ++i) {
|
||||
ld[i] = stride_[i];
|
||||
}
|
||||
ld[kStorageRank - 1] = 1;
|
||||
return ld;
|
||||
}
|
||||
|
||||
/// Returns the stride of the tensor in the given dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
int const& stride(int dim) const { return stride_.at(dim); }
|
||||
Index stride(int dim) const {
|
||||
// fastest-changing stride assumbed to be 1
|
||||
if (dim + 1 >= kStorageRank) {
|
||||
return 1;
|
||||
}
|
||||
return stride_.at(dim);
|
||||
}
|
||||
|
||||
/// Returns the maximum stride element as the 'leading dimension'
|
||||
CUTLASS_HOST_DEVICE
|
||||
int leading_dim() const { return __NV_STD_MAX(stride_[1], stride_[2]); }
|
||||
Index leading_dim(int idx = 0) const { return stride(idx); }
|
||||
|
||||
/// Maps a logical coordinate to an n-D array in memory
|
||||
CUTLASS_HOST_DEVICE
|
||||
StorageCoord map(TensorCoord const &coord) const {
|
||||
return coord_map_(coord);
|
||||
}
|
||||
|
||||
/// Computes the offset of an index from the origin of the tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
long long offset(Coord<Rank> const& coord) const {
|
||||
return stride_.template dot<long long>(coord);
|
||||
LongIndex offset(TensorCoord const& coord) const {
|
||||
return stride().template dot<LongIndex>(map(coord));
|
||||
}
|
||||
|
||||
/// Returns a reference to the element at a given Coord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Storage& at(Coord<Rank> const& coord) const { return ptr_[offset(coord)]; }
|
||||
Storage& at(TensorCoord const& coord) const {
|
||||
return ptr_[offset(coord)];
|
||||
}
|
||||
|
||||
/// Element-wise accessor
|
||||
Storage& operator[](Coord<Rank> const& coord) const { return at(coord); }
|
||||
/// Returns a reference to the element at a given linear index
|
||||
CUTLASS_HOST_DEVICE
|
||||
Storage& at(LongIndex idx) const { return ptr_[idx]; }
|
||||
|
||||
/// Returns a reference to the element at a given Coord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Storage& at(int idx) const { return ptr_[idx]; }
|
||||
Storage& operator[](TensorCoord const& coord) const {
|
||||
return ptr_[offset(coord)];
|
||||
}
|
||||
|
||||
/// Element-wise accessor
|
||||
Storage& operator[](int idx) const { return at(idx); }
|
||||
|
||||
/// Adds an offset to the pointer
|
||||
/// Returns a reference to the element at a given linear index
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef& advance(Coord<Rank> const& b) {
|
||||
ptr_ += offset(b);
|
||||
Storage& operator[](LongIndex idx) const { return ptr_[idx]; }
|
||||
|
||||
/// Adds an offset to each pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef & add_pointer_offset(LongIndex delta) {
|
||||
ptr_ += delta;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Returns a TensorRef offset by a given amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef operator+(Coord<Rank> const& b) const { return TensorRef(ptr_ + offset(b), stride_); }
|
||||
TensorRef operator+(TensorCoord const& b) const {
|
||||
TensorRef result(*this);
|
||||
result.add_pointer_offset(offset(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Returns a TensorRef offset by a given amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef operator-(Coord<Rank> const& b) const { return TensorRef(ptr_ - offset(b), stride_); }
|
||||
TensorRef& operator+=(TensorCoord const& b) {
|
||||
add_pointer_offset(offset(b));
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Returns a TensorRef offset by a given amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef operator-(TensorCoord const& b) const {
|
||||
TensorRef result(*this);
|
||||
result.add_pointer_offset(-offset(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Returns a TensorRef offset by a given amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef& operator-=(TensorCoord const& b) {
|
||||
add_pointer_offset(-offset(b));
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Partial specializations to handle degenerate cases.
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
/// Specialization for rank=1 case with no internal StrideVector
|
||||
template <
|
||||
/// Data type of element stored within tensor
|
||||
typename Storage_,
|
||||
/// Rank of logical tensor
|
||||
int Rank_,
|
||||
/// Maps a Coord<Rank_> in the logical tensor index space to the internal n-D array
|
||||
typename MapFunc_,
|
||||
/// Index type used for coordinates
|
||||
typename Index_,
|
||||
/// Index type used for offsets and pointer differences
|
||||
typename LongIndex_
|
||||
>
|
||||
class TensorRef<Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_> {
|
||||
public:
|
||||
/// Data type of individual access
|
||||
typedef Storage_ Storage;
|
||||
|
||||
/// Logical rank of tensor index space
|
||||
static int const kRank = Rank_;
|
||||
|
||||
/// Mapping function from logical coordinate to internal n-D array
|
||||
typedef MapFunc_ MapFunc;
|
||||
|
||||
/// Rank of internal storage
|
||||
static int const kStorageRank = 1;
|
||||
|
||||
/// Index type
|
||||
typedef Index_ Index;
|
||||
|
||||
/// Typically, strides in memory can be very large
|
||||
typedef LongIndex_ LongIndex;
|
||||
|
||||
/// Coordinate in logical tensor space
|
||||
typedef Coord<kRank> TensorCoord;
|
||||
|
||||
/// Coordinate in storage n-D array
|
||||
typedef Coord<kStorageRank> StorageCoord;
|
||||
|
||||
/// Stride vector in storage coordinage space - assumes least significant stride
|
||||
/// is 1 and does not store it.
|
||||
struct StrideVector { };
|
||||
|
||||
/// Tensor reference to of constant value
|
||||
typedef TensorRef<
|
||||
typename platform::remove_const<Storage>::type const,
|
||||
Rank_,
|
||||
MapFunc_,
|
||||
kStorageRank,
|
||||
Index_,
|
||||
LongIndex_> ConstTensorRef;
|
||||
|
||||
//
|
||||
// Definitions included for backwards compatibility - to be removed in next major release
|
||||
//
|
||||
|
||||
/// Coordinate in logical tensor space
|
||||
typedef TensorCoord Coord_t;
|
||||
|
||||
/// Logical rank of tensor index space
|
||||
static int const Rank = kRank;
|
||||
|
||||
private:
|
||||
|
||||
/// Pointer
|
||||
Storage* ptr_;
|
||||
|
||||
/// Maps a logical coordinate to an n-D array's tensor space
|
||||
MapFunc coord_map_;
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Helper for 1-D memory. All higher ranks are projected onto the fastest changing rank.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef(Storage *ptr = nullptr): ptr_(ptr) { }
|
||||
|
||||
/// Constructs from a single pointer and stride vector
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef(Storage* ptr, StrideVector const& stride) : ptr_(ptr) {
|
||||
|
||||
}
|
||||
|
||||
/// Constructs from a pointer and a stride vector of size kRank. If fastest changing
|
||||
/// stride is not 1, construction fails and subsequent calls to good() will return false.
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef(Storage* ptr, StorageCoord const& stride) {
|
||||
// Fastest-changing stride must be one
|
||||
if (stride.at(kStorageRank - 1) == 1) {
|
||||
ptr_ = ptr;
|
||||
}
|
||||
else {
|
||||
// Fastest-chaning stride must be 1.
|
||||
reset();
|
||||
}
|
||||
}
|
||||
|
||||
/// Enables conversion from TensorRef of non-const type
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef(
|
||||
TensorRef<
|
||||
typename platform::remove_const<Storage>::type,
|
||||
kRank,
|
||||
MapFunc,
|
||||
kStorageRank,
|
||||
Index,
|
||||
LongIndex> const &ref
|
||||
):
|
||||
ptr_(ref.data()) {
|
||||
}
|
||||
|
||||
/// Returns a reference to constant-valued tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstTensorRef const_ref() const {
|
||||
return ConstTensorRef(*this);
|
||||
}
|
||||
|
||||
/// Updates only the pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
void reset(Storage* ptr = nullptr) {
|
||||
ptr_ = ptr;
|
||||
}
|
||||
|
||||
/// Updates the pointer, stride, and location within a TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
void reset(Storage* ptr, StorageCoord const & stride) {
|
||||
// Fastest-changing stride must be one
|
||||
if (stride.at(kStorageRank - 1) == 1) {
|
||||
ptr_ = ptr;
|
||||
}
|
||||
else {
|
||||
// Fastest-changing stride must be 1 - this is an error.
|
||||
reset();
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the TensorRef may be safely accessed
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool good() const {
|
||||
return ptr_ != nullptr;
|
||||
}
|
||||
|
||||
/// Returns the pointer to referenced data
|
||||
CUTLASS_HOST_DEVICE
|
||||
Storage * data() const { return ptr_; }
|
||||
|
||||
/// Returns the stride of the tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
StorageCoord stride() const {
|
||||
StorageCoord ld;
|
||||
ld[kStorageRank - 1] = 1;
|
||||
return ld;
|
||||
}
|
||||
|
||||
/// Returns the stride of the tensor in the given dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index stride(int dim) const {
|
||||
// fastest-changing stride assumbed to be 1
|
||||
return 1;
|
||||
}
|
||||
|
||||
/// Returns the maximum stride element as the 'leading dimension'
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index leading_dim(int idx = 0) const { return 1; }
|
||||
|
||||
/// Maps a logical coordinate to an n-D array in memory
|
||||
CUTLASS_HOST_DEVICE
|
||||
StorageCoord map(TensorCoord const &coord) const {
|
||||
return coord_map_(coord);
|
||||
}
|
||||
|
||||
/// Computes the offset of an index from the origin of the tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex offset(TensorCoord const& coord) const {
|
||||
return stride().template dot<LongIndex>(map(coord));
|
||||
}
|
||||
|
||||
/// Returns a reference to the element at a given Coord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Storage& at(TensorCoord const& coord) const {
|
||||
return ptr_[offset(coord)];
|
||||
}
|
||||
|
||||
/// Returns a reference to the element at a given linear index
|
||||
CUTLASS_HOST_DEVICE
|
||||
Storage& at(LongIndex idx) const { return ptr_[idx]; }
|
||||
|
||||
/// Returns a reference to the element at a given Coord
|
||||
CUTLASS_HOST_DEVICE
|
||||
Storage& operator[](TensorCoord const& coord) const {
|
||||
return ptr_[offset(coord)];
|
||||
}
|
||||
|
||||
/// Returns a reference to the element at a given linear index
|
||||
CUTLASS_HOST_DEVICE
|
||||
Storage& operator[](LongIndex idx) const { return ptr_[idx]; }
|
||||
|
||||
/// Adds an offset to each pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef & add_pointer_offset(LongIndex delta) {
|
||||
ptr_ += delta;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Returns a TensorRef offset by a given amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef operator+(TensorCoord const& b) const {
|
||||
TensorRef result(*this);
|
||||
result.add_pointer_offset(offset(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Returns a TensorRef offset by a given amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef& operator+=(TensorCoord const& b) {
|
||||
add_pointer_offset(offset(b));
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Returns a TensorRef offset by a given amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef operator-(TensorCoord const& b) const {
|
||||
TensorRef result(*this);
|
||||
result.add_pointer_offset(-offset(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Returns a TensorRef offset by a given amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef& operator-=(TensorCoord const& b) {
|
||||
add_pointer_offset(-offset(b));
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
420
cutlass/tensor_ref_collection.h
Normal file
@ -0,0 +1,420 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Introduces TensorRefCollection concept and defines TensorRefBatch and TensorRefArray.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// TensorRefCollection is a concept for storing a logical collection of TensorRef objects. Classes
|
||||
// satisfying the TensorRefCollection concept must support the following:
|
||||
//
|
||||
// // Define storage type
|
||||
// typedef typename TensorRefCollection::Storage Storage;
|
||||
//
|
||||
// // Define a type for offsets in memory
|
||||
// typedef typename TensorRefCollection::LongIndex LongIndex;
|
||||
//
|
||||
// // Define a ConstIterator type satisfying TensorRefIterator
|
||||
// typedef typename TensorRefCollection::ConstIterator TensorRefIterator;
|
||||
//
|
||||
// // Implement a begin() method.
|
||||
// TensorRefIterator iterator = collection.begin();
|
||||
//
|
||||
//
|
||||
// TensorRefIterator is a concept for accessing an element in a TensorRefCollection. Classes
|
||||
// satisfying the TensorRefIterator concept must support the following:
|
||||
//
|
||||
// // Define a TensorRef type accessed by the iterator
|
||||
// typedef typename TensorRefIterator::TensorRef TensorRef;
|
||||
//
|
||||
// // Access the TensorRef
|
||||
// TensorRef ref = *iterator;
|
||||
//
|
||||
// // Pre-increment and post-increment
|
||||
// ++iterator;
|
||||
// iterator++;
|
||||
//
|
||||
// // Pre-decrement and post-decrement
|
||||
// --iterator;
|
||||
// iterator--;
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// This satisfies TensorRefCollection and stores a collection of TensorRef objects that
|
||||
/// have identical strides. TensorRef objects are separated by a linear stride.
|
||||
template <
|
||||
/// Data type of element stored within tensor
|
||||
typename Storage_,
|
||||
/// Rank of logical tensor
|
||||
int Rank_,
|
||||
/// Maps a Coord<Rank_> in the logical tensor index space to the internal n-D array
|
||||
typename MapFunc_ = IdentityTensorMapFunc<Rank_>,
|
||||
/// Rank of internal n-D array
|
||||
int StorageRank_ = MapFunc_::kStorageRank,
|
||||
/// Index type used for coordinates
|
||||
typename Index_ = int,
|
||||
/// Index type used for offsets and pointer differences
|
||||
typename LongIndex_ = long long
|
||||
>
|
||||
struct TensorRefBatchStrided:
|
||||
public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> {
|
||||
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Underlying TensorRef type
|
||||
typedef TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> Base;
|
||||
|
||||
/// Storage type
|
||||
typedef typename Base::Storage Storage;
|
||||
|
||||
/// Index type
|
||||
typedef Index_ Index;
|
||||
|
||||
/// Typically, strides in memory can be very large
|
||||
typedef LongIndex_ LongIndex;
|
||||
|
||||
/// Coordinate in logical tensor space
|
||||
typedef Coord<kRank> TensorCoord;
|
||||
|
||||
/// Tensor reference implied by the TensorRefBatchStrided
|
||||
typedef Base TensorRef;
|
||||
|
||||
/// Constant iterator over tensors implied by TensorRefBatchStrided
|
||||
class ConstIterator {
|
||||
public:
|
||||
/// TensorRef returned by the iterator
|
||||
typedef Base TensorRef;
|
||||
|
||||
private:
|
||||
|
||||
/// Reference to the parent TensorBatchRef object
|
||||
TensorRefBatchStrided const &ref_;
|
||||
|
||||
/// Offset from the base TensorRef pointer
|
||||
LongIndex offset_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs a ConstIterator from a parent TensorRefBatchStrided
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator(
|
||||
TensorRefBatchStrided const &ref,
|
||||
LongIndex offset = 0): ref_(ref), offset_(offset) { }
|
||||
|
||||
/// Obtains a TensorRef pointed to by the iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef *operator() const {
|
||||
TensorRef ref(ref_);
|
||||
ref.add_pointer_offset(offset_);
|
||||
return ref;
|
||||
}
|
||||
|
||||
/// Advances the iterator to point to the next tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator &operator++() {
|
||||
offset_ += ref_.tensor_stride;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances the iterator to point to the next tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator operator++(int) {
|
||||
ConstIterator ret(*this);
|
||||
offset_ += ref_.tensor_stride;
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Returns an iterator advanced by (idx) amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator operator+(Index idx) {
|
||||
return ConstIterator(ref, offset_ + ref_.tensor_stride * idx);
|
||||
}
|
||||
|
||||
/// Advances this iterator by (idx) and returns a reference to self
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator &operator+=(Index idx) {
|
||||
offset_ += ref_.tensor_stride * idx;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Moves to the previous tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator &operator--() {
|
||||
offset_ -= ref_.tensor_stride;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Moves to the previous tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator operator--(int) {
|
||||
ConstIterator ret(*this);
|
||||
offset_ -= ref_.tensor_stride;
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Returns an iterator moved forward by (idx) amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator operator-(Index idx) {
|
||||
return ConstIterator(ref_, offset_ - ref_.tensor_stride * idx);
|
||||
}
|
||||
|
||||
/// Moves this iterator by (idx) and returns a reference to self
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator &operator-=(Index idx) {
|
||||
offset_ -= ref_.tensor_stride * idx;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Returns the difference in offset between two iterators
|
||||
CUTLASS_HOST_DEVICE
|
||||
Stride operator-(ConstIterator const &it) {
|
||||
return offset_ - it.offset_;
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Stride between tensors
|
||||
LongIndex tensor_stride;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefBatchStrided(): tensor_stride(0) { }
|
||||
|
||||
// Constructs form a tensor reference and
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefBatchStrided(TensorRef const &ref, LongIndex _tensor_stride = 0):
|
||||
TensorRef(ref),
|
||||
tensor_stride(_tensor_stride) { }
|
||||
|
||||
/// Gets the pointer offset
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex get_pointer_offset(Index idx) const {
|
||||
return idx * tensor_stride;
|
||||
}
|
||||
|
||||
// Returns a reference
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef at(Index idx) const {
|
||||
TensorRef ref(*this);
|
||||
ref.add_pointer_offset(get_pointer_offset(idx));
|
||||
return ref;
|
||||
}
|
||||
|
||||
/// Returns an iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator begin() {
|
||||
return ConstIterator(*this);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// This satisfies TensorRefCollection and stores a collection of TensorRef objects. This is a
|
||||
/// structure of arrays in that the individual members of the TensorRef are held in distinct arrays.
|
||||
///
|
||||
/// Note, TensorRef maps a logical coordinate space to an n-D array with rank kStorageRank. It
|
||||
/// maintains a stride vector of similar rank, but the least significant rank is defined to be 1.
|
||||
///
|
||||
/// The least significant stride of 1 is not stored, and therefore the number of stride arrays is
|
||||
/// kStorageRank - 1.
|
||||
template <
|
||||
/// Data type of element stored within tensor
|
||||
typename Storage_,
|
||||
/// Rank of logical tensor
|
||||
int Rank_,
|
||||
/// Maps a Coord<Rank_> in the logical tensor index space to the internal n-D array
|
||||
typename MapFunc_ = IdentityTensorMapFunc<Rank_>,
|
||||
/// Rank of internal n-D array
|
||||
int StorageRank_ = MapFunc_::kStorageRank,
|
||||
/// Index type used for coordinates
|
||||
typename Index_ = int,
|
||||
/// Index type used for offsets and pointer differences
|
||||
typename LongIndex_ = long long
|
||||
>
|
||||
struct TensorRefArray {
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// TensorRef type obtained from the TensorRefArray
|
||||
typedef TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> TensorRef;
|
||||
|
||||
/// Element pointed to by the TensorRef
|
||||
typedef Storage_ Storage;
|
||||
|
||||
/// Index type
|
||||
typedef Index_ Index;
|
||||
|
||||
/// Typically, strides in memory can be very large
|
||||
typedef LongIndex_ LongIndex;
|
||||
|
||||
/// Rank of the stride vector
|
||||
static int const kStorageRank = TensorRef::kStorageRank;
|
||||
|
||||
/// TensorRefIterator over TensorRef objects in TensorRefArray
|
||||
class ConstIterator {
|
||||
public:
|
||||
|
||||
/// TensorRef returned by the iterator
|
||||
typedef Base TensorRef;
|
||||
|
||||
private:
|
||||
/// Reference to the TensorRefArray
|
||||
TensorRefArray const &ref_;
|
||||
|
||||
/// Index into TensorRefArray
|
||||
int idx_;
|
||||
|
||||
public:
|
||||
|
||||
/// Constructs a ConstIterator over the TensorRef objects
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator(TensorArrayRef const &ref, int idx = 0): ref_(ref), idx_(idx) { }
|
||||
|
||||
/// Obtains a TensorRef pointed to by this iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef *operator() const {
|
||||
return ref_.reference(idx_);
|
||||
}
|
||||
|
||||
/// Advances to next TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator &operator++() {
|
||||
++idx_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances to next TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator operator++(int) {
|
||||
ConstIterator ret(*this);
|
||||
idx_ ++;
|
||||
return ret;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator operator+(Index idx) {
|
||||
return ConstIterator(ref_, idx_ + idx);
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator &operator+=(Index idx) {
|
||||
idx_ += idx;
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator &operator--() {
|
||||
--idx_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Advances to next TensorRef
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator operator--(int) {
|
||||
ConstIterator ret(*this);
|
||||
--idx_;
|
||||
return ret;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator &operator-=(Index idx) {
|
||||
idx_ -= idx;
|
||||
return *this;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator operator-(Index idx) {
|
||||
return ConstIterator(ref_, idx_ + idx);
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Base addresses
|
||||
Storage **pointers;
|
||||
|
||||
/// Array of strides
|
||||
Index *strides[kStorageRank - 1];
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorArrayRef() { }
|
||||
|
||||
// Construct from pointers to arrays to strides
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorArrayRef(
|
||||
Storage **_pointers,
|
||||
Index _strides[kStorageRank - 1]): pointers(_pointers) {
|
||||
|
||||
// Copy pointers to strides arrays
|
||||
for (int i = 0; i < kStorageRank - 1; ++i) {
|
||||
strides[i] = _strides[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a TensorRef at the given index in the collection
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef at(Index idx) const {
|
||||
Coord<kStorageRank - 1, Index> stride;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kStorageRank - 1; ++i) {
|
||||
stride[i] = stride_[idx][i];
|
||||
}
|
||||
return TensorRef(pointers[idx], stride);
|
||||
}
|
||||
|
||||
/// Returns an TesnorRefIterator over the TensorRef objects in this collection
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstIterator begin() {
|
||||
return ConstIterator(*this);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
@ -24,51 +24,110 @@
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines a structure containing strides and a pointer to tensor data.
|
||||
|
||||
TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus,
|
||||
it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from
|
||||
data storage and is therefore lightweight and may be embedded in larger tensor objects or
|
||||
memory structures.
|
||||
|
||||
See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to
|
||||
linear memory.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/tensor_ref.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Host-side reference implementation of tensor operations
|
||||
template <typename T>
|
||||
class TensorView : public TensorRef<T, 4> {
|
||||
/// Defines a view into a logical tensor
|
||||
template <
|
||||
/// Data type of element stored within tensor
|
||||
typename Storage_,
|
||||
/// Rank of logical tensor
|
||||
int Rank_ = 4,
|
||||
/// Maps a Coord<Rank_> in the logical tensor index space to the internal n-D array
|
||||
typename MapFunc_ = IdentityTensorMapFunc<Rank_>,
|
||||
/// Rank of internal n-D array
|
||||
int StorageRank_ = MapFunc_::kStorageRank,
|
||||
/// Index type used for coordinates
|
||||
typename Index_ = int,
|
||||
/// Index type used for offsets and pointer differences
|
||||
typename LongIndex_ = long long
|
||||
>
|
||||
class TensorView : public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> {
|
||||
public:
|
||||
/// Reference and stride
|
||||
typedef TensorRef<T, 4> Base;
|
||||
/// Base tensor reference
|
||||
typedef TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> Base;
|
||||
|
||||
/// Reference and stride
|
||||
typedef Base TensorRef_t;
|
||||
/// Tensor reference to of constant value
|
||||
typedef TensorRef<
|
||||
typename platform::remove_const<Storage_>::type const,
|
||||
Rank_,
|
||||
MapFunc_,
|
||||
StorageRank_,
|
||||
Index_,
|
||||
LongIndex_> ConstTensorRef;
|
||||
|
||||
/// Reference to constant type
|
||||
typedef TensorRef<T const, 4> ConstTensorRef_t;
|
||||
/// Base tensor reference
|
||||
typedef Base TensorRef;
|
||||
|
||||
/// Rank of tensor
|
||||
static int const Rank = TensorRef_t::Rank;
|
||||
/// Storage type
|
||||
typedef typename Base::Storage Storage;
|
||||
|
||||
/// Index type
|
||||
typedef typename Base::Index Index;
|
||||
|
||||
/// Coordinate in logical tensor space
|
||||
typedef typename TensorRef::TensorCoord TensorCoord;
|
||||
|
||||
/// Coordinate in storage n-D array
|
||||
typedef typename TensorRef::StorageCoord StorageCoord;
|
||||
|
||||
/// Stride vector in storage coordinate space
|
||||
/// Least significant stride is = 1 and not stored
|
||||
typedef typename TensorRef::StrideVector StrideVector;
|
||||
|
||||
/// TensorView of constant value
|
||||
typedef TensorView<
|
||||
typename platform::remove_const<Storage>::type const,
|
||||
Rank_,
|
||||
MapFunc_,
|
||||
StorageRank_,
|
||||
Index_,
|
||||
LongIndex_> ConstTensorView;
|
||||
|
||||
//
|
||||
// Definitions included for backwards compatibility - to be removed in next major release
|
||||
//
|
||||
|
||||
/// Coordinate in logical tensor space
|
||||
typedef TensorCoord Coord_t;
|
||||
|
||||
/// Logical rank of tensor index space
|
||||
static int const Rank = Base::kRank;
|
||||
|
||||
/// Type used to compute the offset of an element to the base of a tensor
|
||||
typedef int Offset_t;
|
||||
typedef typename Base::LongIndex Offset_t;
|
||||
|
||||
/// Coordinate into tensor
|
||||
typedef Coord<Rank> Coord_t;
|
||||
/// Base class
|
||||
typedef TensorRef TensorRef_t;
|
||||
|
||||
/// TensorRef to const-valued type
|
||||
typedef typename TensorRef::ConstTensorRef ConstTensorRef_t;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Pointer to pitch-linear memory
|
||||
TensorRef_t ref_;
|
||||
|
||||
/// Dimensions of coordinate (independent of stride)
|
||||
Coord_t size_;
|
||||
TensorCoord size_;
|
||||
|
||||
public:
|
||||
//
|
||||
@ -79,91 +138,126 @@ class TensorView : public TensorRef<T, 4> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorView() {}
|
||||
|
||||
/// Constructs a Tensor_view from a TensorRef and size
|
||||
/// Constructs a TensorView from a TensorRef and size
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorView(TensorRef_t const& _ref, Coord_t const& _size) : Base(_ref), size_(_size) {}
|
||||
TensorView(Base const& _ref, TensorCoord const& _size) : Base(_ref), size_(_size) {}
|
||||
|
||||
/// Returns true if the Tensor_view is bound to some memory
|
||||
/// Constructs a TensorView from a pointer, a stride vector, and size
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool good() const { return ref().good(); }
|
||||
TensorView(
|
||||
Storage *ptr,
|
||||
StrideVector const &stride,
|
||||
TensorCoord const& size
|
||||
):
|
||||
Base(ptr, stride), size_(size) {}
|
||||
|
||||
/// Returns a pointer to data
|
||||
/// Constructs a TensorView from a pointer, a stride vector, and size
|
||||
CUTLASS_HOST_DEVICE
|
||||
T* data() const { return ref().data(); }
|
||||
TensorView(
|
||||
Storage *ptr,
|
||||
StorageCoord const &stride,
|
||||
TensorCoord const& size
|
||||
):
|
||||
Base(ptr, stride), size_(size) {}
|
||||
|
||||
/// Updates the reference and size of a Tensor_view object
|
||||
CUTLASS_HOST_DEVICE
|
||||
void reset(TensorRef_t const& _ref = TensorRef_t(0), Coord_t const& _size = Coord_t()) {
|
||||
void reset(Base const& _ref = Base(), TensorCoord const& _size = TensorCoord()) {
|
||||
Base::operator=(_ref);
|
||||
size_ = _size;
|
||||
}
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
/// Accesses the size
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef_t& ref() { return *this; }
|
||||
|
||||
///
|
||||
CUTLASS_HOST_DEVICE
|
||||
ConstTensorRef_t const_ref() { return ConstTensorRef_t(data(), stride()); }
|
||||
|
||||
/// Accesses the tensor reference pointing to data
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRef_t const& ref() const { return *this; }
|
||||
TensorCoord const& size() const { return size_; }
|
||||
|
||||
/// Accesses the size
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord_t const& size() const { return size_; }
|
||||
|
||||
/// Accesses the size
|
||||
CUTLASS_HOST_DEVICE
|
||||
int size(int dim) const { return size_.at(dim); }
|
||||
|
||||
/// Accesses the stride
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord_t const& stride() const { return ref().stride(); }
|
||||
|
||||
/// Accesses the stride
|
||||
CUTLASS_HOST_DEVICE
|
||||
int const& stride(int dim) const { return ref().stride(dim); }
|
||||
Index size(int dim) const { return size_.at(dim); }
|
||||
|
||||
/// Assigns the Tensor_view
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorView& operator=(TensorView const& _tensor) {
|
||||
Base::operator=(_tensor._ref);
|
||||
Base::operator=(_tensor);
|
||||
size_ = _tensor.size_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Returns the index of an element
|
||||
CUTLASS_HOST_DEVICE
|
||||
Offset_t offset(Coord_t const& coord) const { return ref().offset(coord); }
|
||||
|
||||
/// Determines whether a location is within a tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool contains(Coord_t const& coord) const {
|
||||
for (int dim = 0; dim < Rank; ++dim) {
|
||||
if (coord.at(dim) >= size_.at(dim)) {
|
||||
bool contains(TensorCoord const& coord) const {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int dim = 0; dim < Rank_; ++dim) {
|
||||
if (coord[dim] >= size_[dim]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Element-wise accessor
|
||||
/// Returns a TensorRef pointing to the first element of the tensor.
|
||||
CUTLASS_HOST_DEVICE
|
||||
T& at(Coord_t const& coord) const { return ref().at(coord); }
|
||||
TensorRef ref() const {
|
||||
return TensorRef(*this);
|
||||
}
|
||||
|
||||
/// Element-wise accessor
|
||||
T& operator[](Coord<Rank> const& coord) const { return at(coord); }
|
||||
|
||||
/// Element-wise accessor
|
||||
/// Returns a TensorRef pointing to the first element of the tensor.
|
||||
CUTLASS_HOST_DEVICE
|
||||
T& at(Offset_t idx) const { return ref().at(idx); }
|
||||
ConstTensorRef const_ref() const {
|
||||
return ConstTensorRef(*this);
|
||||
}
|
||||
|
||||
/// Returns a Tensor_view given location and size quantities
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorView<T> subview(Coord_t const& location, Coord_t size) const {
|
||||
return TensorView<T>(ref() + location, size.clamp(size_ - location));
|
||||
TensorView subview(TensorCoord const& location, TensorCoord size) const {
|
||||
return TensorView((*this) + location, size.clamp(size_ - location));
|
||||
}
|
||||
|
||||
/// Returns the number of scalar elements needed to store tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
size_t capacity() const {
|
||||
int max_rank = 0;
|
||||
|
||||
StorageCoord mapped_size(this->map(size()));
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < Base::kStorageRank; ++i) {
|
||||
if (!i ||
|
||||
this->stride(i) * mapped_size[i] > this->stride(max_rank) * mapped_size[max_rank]) {
|
||||
max_rank = i;
|
||||
}
|
||||
}
|
||||
return this->stride(max_rank) * mapped_size[max_rank];
|
||||
}
|
||||
|
||||
/// Returns a TensorView offset by a given amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorView operator+(TensorCoord const& b) const {
|
||||
TensorView result(*this);
|
||||
result.add_pointer_offset(this->offset(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Returns a TensorRef offset by a given amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorView& operator+=(TensorCoord const& b) {
|
||||
this->add_pointer_offset(this->offset(b));
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Returns a TensorRef offset by a given amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorView operator-(TensorCoord const& b) const {
|
||||
TensorRef result(*this);
|
||||
result.add_pointer_offset(-this->offset(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Returns a TensorRef offset by a given amount
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorView& operator-=(TensorCoord const& b) {
|
||||
this->add_pointer_offset(-this->offset(b));
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
143
cutlass/tile_allocation.h
Normal file
@ -0,0 +1,143 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines a fragment based on a Shape<> template.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/shape.h"
|
||||
#include "cutlass/fragment.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/zip_tensor_ref.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Class for storing a tile in memory and accessing it through a tensor ref
|
||||
template <typename Scalar_, typename Shape_>
|
||||
struct TileAllocation {
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Scalar element
|
||||
typedef Scalar_ Scalar;
|
||||
|
||||
/// The actual storage (may differ from the scalar type)
|
||||
typedef typename StorageType<sizeof(Scalar)>::Type Storage;
|
||||
|
||||
/// Size of the allocation in units of scalars
|
||||
typedef Shape_ Shape;
|
||||
|
||||
/// Strides
|
||||
typedef typename ShapeStrides<Shape, 1>::Shape Strides;
|
||||
|
||||
/// Defines the tensor reference for this allocation
|
||||
typedef TensorRef<Scalar const, 4> ConstTensorRef;
|
||||
|
||||
/// Defines the tensor reference for this allocation
|
||||
typedef TensorRef<Scalar, 4> TensorRef;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Storage
|
||||
Storage storage[Shape::kD][Shape::kH][Shape::kW][Shape::kC];
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Returns a pointer to the raw data
|
||||
CUTLASS_DEVICE
|
||||
Scalar *data() { return reinterpret_cast<Scalar *>(&storage[0][0][0][0]); }
|
||||
|
||||
/// Returns a const pointer to the raw data
|
||||
CUTLASS_DEVICE
|
||||
Scalar const *data() const { return reinterpret_cast<Scalar const *>(&storage[0][0][0][0]); }
|
||||
|
||||
/// Returns a TensorRef object pointing to the data
|
||||
CUTLASS_DEVICE
|
||||
TensorRef reference() {
|
||||
return TensorRef(data(), make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC));
|
||||
}
|
||||
|
||||
/// Returns a TensorRef object pointing to the data
|
||||
CUTLASS_DEVICE
|
||||
ConstTensorRef reference() const {
|
||||
return ConstTensorRef(data(), make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC));
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Manages a pair of tile allocations as if they are one allocation
|
||||
template <typename First_, typename Second_>
|
||||
struct ZipTileAllocation {
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// First tensor allocation
|
||||
typedef First_ First;
|
||||
|
||||
/// Second tensor allocation
|
||||
typedef Second_ Second;
|
||||
|
||||
/// Defines the tensor reference for this allocation
|
||||
typedef ZipTensorRef<typename First::TensorRef, typename Second::TensorRef> TensorRef;
|
||||
|
||||
/// Defines the tensor reference for this allocation
|
||||
typedef ZipTensorRef<typename First::ConstTensorRef, typename Second::ConstTensorRef>
|
||||
ConstTensorRef;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// First tensor allocation
|
||||
First first;
|
||||
|
||||
/// Second tensor allocation
|
||||
Second second;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Returns a TensorRef object pointing to the data
|
||||
CUTLASS_DEVICE
|
||||
TensorRef reference() { return TensorRef(first.reference(), second.reference()); }
|
||||
|
||||
/// Returns a TensorRef object pointing to the data
|
||||
CUTLASS_DEVICE
|
||||
ConstTensorRef reference() const { return ConstTensorRef(first.reference(), second.reference()); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
194
cutlass/tile_coord.h
Normal file
@ -0,0 +1,194 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines a coordinate used for the CUTLASS 4-D tile structure.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// TileCoord wraps Coord<4, int> to provide a helper for accessing named dimensions. Classes
|
||||
/// expecting a coordinate in the rank=4 index space of a CUTLASS tile structure should use TileCoord.
|
||||
template <typename Index_ = int>
|
||||
struct TileCoord : public Coord<4, Index_> {
|
||||
|
||||
/// Index type
|
||||
typedef Index_ Index;
|
||||
|
||||
/// Underlying Coord<4>
|
||||
typedef Coord<4, Index> Base;
|
||||
|
||||
/// D dimension
|
||||
static int kD = 0;
|
||||
|
||||
/// H dimension
|
||||
static int kH = 1;
|
||||
|
||||
/// W dimension
|
||||
static int kW = 2;
|
||||
|
||||
/// C dimension
|
||||
static int kC = 3;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileCoord() { }
|
||||
|
||||
/// Constructs from Coord<3> and infers coord[kC] = 0
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileCoord(Coord<3, Index> const &coord):
|
||||
Base(make_Coord(coord[0], coord[1], coord[2], 0)) { }
|
||||
|
||||
/// Constructs from Coord<4>
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileCoord(Coord<4, Index> const &coord): Base(coord) { }
|
||||
|
||||
/// Constructs from an array of coordinate elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileCoord(Index coord[4]): Base(coord) { }
|
||||
|
||||
/// Helper to construct from a row and column
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileCoord(Index d, Index h, Index w, Index c): Base(make_Coord(d, h, w, c)) { }
|
||||
|
||||
/// Returns the D element of the coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index const & d() const { return this->at(kD); }
|
||||
|
||||
/// Returns the D element of the coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index & d() { return this->at(kD); }
|
||||
|
||||
/// Returns the H element of the coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index const & h() const { return this->at(kH); }
|
||||
|
||||
/// Returns the H element of the coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index & h() { return this->at(kH); }
|
||||
|
||||
/// Returns the W element of the coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index const & w() const { return this->at(kW); }
|
||||
|
||||
/// Returns the W element of the coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index & w() { return this->at(kW); }
|
||||
|
||||
/// Returns the Celement of the coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index const & c() const { return this->at(kC); }
|
||||
|
||||
/// Returns the C element of the coordinate
|
||||
CUTLASS_HOST_DEVICE
|
||||
Index & c() { return this->at(kC); }
|
||||
|
||||
/// Gets H and W dimensions as a Coord<2>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<2> hw() const {
|
||||
return make_Coord(h(), w());
|
||||
}
|
||||
|
||||
/// Gets H, W, and C dimensions as a Coord<3>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<3> hwc() const {
|
||||
return make_Coord(h(), w(), c());
|
||||
}
|
||||
|
||||
/// Gets D, H, and W dimensions as a Coord<3>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<3> dhw() const {
|
||||
return make_Coord(d(), h(), w());
|
||||
}
|
||||
|
||||
//
|
||||
// Coord operators
|
||||
//
|
||||
|
||||
/// Element-wise addition
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileCoord operator+(Base const& b) const {
|
||||
return TileCoord(Base::operator+(b));
|
||||
}
|
||||
|
||||
/// Element-wise subtraction
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileCoord operator-(Base const& b) const {
|
||||
return TileCoord(Base::operator-(b));
|
||||
}
|
||||
|
||||
/// Element-wise multiplication
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileCoord operator*(Base const& b) const {
|
||||
return TileCoord(Base::operator*(b));
|
||||
}
|
||||
|
||||
/// Element-wise division
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileCoord operator/(Base const& b) const {
|
||||
return TileCoord(Base::operator/(b));
|
||||
}
|
||||
|
||||
/// In-place addition
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileCoord& operator+=(Base const& b) {
|
||||
Base::operator+=(b);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place subtraction
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileCoord& operator-=(Base const& b) {
|
||||
Base::operator-=(b);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place multiplication
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileCoord& operator*=(Base const& b) {
|
||||
Base::operator*=(b);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// In-place division
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileCoord& operator/=(Base const& b) {
|
||||
Base::operator/=(b);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
@ -28,10 +28,13 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/fragment.h>
|
||||
#include <cutlass/load_store.h>
|
||||
#include <cutlass/predicate_vector.h>
|
||||
#include <cutlass/vector.h>
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/fragment.h"
|
||||
#include "cutlass/load_store.h"
|
||||
#include "cutlass/predicate_vector.h"
|
||||
#include "cutlass/vector.h"
|
||||
#include <cstdio>
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@ -61,12 +64,6 @@ as a Coord<4>.
|
||||
struct IteratorAdvance {
|
||||
enum Kind { kD, kH, kW };
|
||||
};
|
||||
|
||||
/// Specifies whether iterator storage fragment consists of Scalar values or WMMA matrix
|
||||
struct IteratorFragment {
|
||||
enum Kind { kScalar, kWmmaMatrix };
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
@ -77,7 +74,7 @@ template <typename Tile_,
|
||||
typename Delta_,
|
||||
typename Iterations_,
|
||||
typename ThreadOffset_,
|
||||
int kAccessSize>
|
||||
int AccessSize>
|
||||
struct TileTraits {
|
||||
/// Shape of the tile
|
||||
typedef Tile_ Tile;
|
||||
@ -89,11 +86,52 @@ struct TileTraits {
|
||||
typedef Iterations_ Iterations;
|
||||
|
||||
/// Functor that returns the logical coordinate of each entity's initial offset in the tile
|
||||
//
|
||||
// ThreadOffset should be a functor defined like:
|
||||
//
|
||||
// struct ThreadOffsetExample {
|
||||
// CUTLASS_DEVICE
|
||||
// Coord<4> operator()() const {
|
||||
// return make_Coord(0, threadIdx.y, threadIdx.x, 0);
|
||||
// }
|
||||
// };
|
||||
//
|
||||
typedef ThreadOffset_ ThreadOffset;
|
||||
|
||||
/// Strides for immediate offset computation
|
||||
typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;
|
||||
|
||||
/// Access size
|
||||
static int const kAccessSize = AccessSize;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Functor computing a predicate given the logical position of an access
|
||||
template <typename Delta_>
|
||||
struct RegularTilePredicateFunctor {
|
||||
typedef Delta_ Delta;
|
||||
|
||||
/// Dimensions of the bounding volume
|
||||
Coord<3> bounds;
|
||||
|
||||
/// Constructs a predicate functor given the bounds of a tensor
|
||||
CUTLASS_HOST_DEVICE
|
||||
RegularTilePredicateFunctor(Coord<3> _bounds) : bounds(_bounds) {}
|
||||
|
||||
/// Computes the predicate given the logical position of an access
|
||||
CUTLASS_HOST_DEVICE
|
||||
bool operator()(Coord<3> iteration, Coord<3> offset) const {
|
||||
return (iteration[0] * Delta::kD + offset[0] < bounds[0]) &&
|
||||
(iteration[1] * Delta::kH + offset[1] < bounds[1]) &&
|
||||
(iteration[2] * Delta::kW + offset[2] < bounds[2]);
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
struct DumpType {};
|
||||
/// Iterator for accessing a stripmined tile in memory
|
||||
template <typename Traits_,
|
||||
typename Scalar_,
|
||||
@ -101,7 +139,7 @@ template <typename Traits_,
|
||||
MemorySpace::Kind MemorySpace = MemorySpace::kGeneric,
|
||||
typename Index_ = int,
|
||||
typename FragmentElement_ = Scalar_,
|
||||
IteratorFragment::Kind IteratorFragment_ = IteratorFragment::kScalar,
|
||||
FragmentElementType::Kind FragmentElementType_ = FragmentElementType::kScalar,
|
||||
typename Skew_ = Shape<0, 0, 0, 0> >
|
||||
struct TileIteratorBase {
|
||||
/// concept TileTraits
|
||||
@ -117,7 +155,7 @@ struct TileIteratorBase {
|
||||
static IteratorAdvance::Kind const kAdvance = Advance_;
|
||||
|
||||
/// Specifies iterator storage fragment type (Scalar or WmmaMatrix)
|
||||
static IteratorFragment::Kind const kIteratorFragment = IteratorFragment_;
|
||||
static FragmentElementType::Kind const kFragmentElementType = FragmentElementType_;
|
||||
|
||||
/// Source or destination memory space
|
||||
static MemorySpace::Kind const kMemorySpace = MemorySpace;
|
||||
@ -144,18 +182,19 @@ struct TileIteratorBase {
|
||||
typedef typename Traits::ThreadOffset ThreadOffset;
|
||||
|
||||
/// The number of scalars accessed per load/store.
|
||||
static int const kAccessSize = Tile::kC;
|
||||
static int const kAccessSize = Traits::kAccessSize;
|
||||
|
||||
/// The elements loaded/store by one instruction.
|
||||
typedef typename Vectorize<FragmentElement, kAccessSize>::Type AccessType;
|
||||
|
||||
/// The size of storage needed per fragment
|
||||
static int const kFragmentSize =
|
||||
(kIteratorFragment == IteratorFragment::kWmmaMatrix ? 16 : sizeof(AccessType));
|
||||
(kFragmentElementType == FragmentElementType::kWmmaMatrix ? 16 : sizeof(AccessType));
|
||||
/// The storage.
|
||||
typedef Fragment<Scalar, ShapeCount<Tile>::kCount, kFragmentSize> Storage;
|
||||
/// The fragment.
|
||||
typedef Fragment<FragmentElement, ShapeCount<Iterations>::kCount * kAccessSize> Fragment;
|
||||
|
||||
/// The fragment iterator.
|
||||
typedef FragmentIterator<Fragment, Iterations, AccessType> FragmentIterator;
|
||||
/// The fragment const iterator.
|
||||
@ -172,25 +211,61 @@ struct TileIteratorBase {
|
||||
|
||||
/// Parameters to the iterator
|
||||
struct Params {
|
||||
Index stride_d;
|
||||
|
||||
//
|
||||
// Dat members
|
||||
//
|
||||
|
||||
long long stride_d;
|
||||
Index stride_h;
|
||||
Index stride_w;
|
||||
|
||||
Index inc_d;
|
||||
long long inc_d;
|
||||
Index inc_h;
|
||||
Index inc_w;
|
||||
|
||||
Index inc_advance;
|
||||
long long inc_advance;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructs params
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() : stride_d(0), stride_h(0), stride_w(0), inc_d(0), inc_h(0), inc_w(0) {}
|
||||
|
||||
/// Constructs params
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(long long _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
long long _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
long long _inc_advance)
|
||||
: stride_d(_stride_d),
|
||||
stride_h(_stride_h),
|
||||
stride_w(_stride_w),
|
||||
inc_d(_inc_d),
|
||||
inc_h(_inc_h),
|
||||
inc_w(_inc_w),
|
||||
inc_advance(_inc_advance) {}
|
||||
|
||||
/// Constructs params with a stride vector
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Coord<4> const &stride) {
|
||||
initialize(stride);
|
||||
}
|
||||
|
||||
/// Initializes params
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Index _stride_d,
|
||||
int initialize(long long _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
Index _inc_d,
|
||||
long long _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
Index _inc_advance) {
|
||||
long long _inc_advance) {
|
||||
stride_d = _stride_d;
|
||||
stride_h = _stride_h;
|
||||
stride_w = _stride_w;
|
||||
@ -203,61 +278,79 @@ struct TileIteratorBase {
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initializes the parameters object from a vector of strides
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Index _stride_d, Index _stride_h, Index _stride_w) {
|
||||
int initialize(Coord<4> const &stride) {
|
||||
return initialize(stride[0], stride[1], stride[2]);
|
||||
}
|
||||
|
||||
/// Initializes the parameters object from a vector of strides
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(long long _stride_d, Index _stride_h, Index _stride_w) {
|
||||
stride_d = _stride_d;
|
||||
stride_h = _stride_h;
|
||||
stride_w = _stride_w;
|
||||
|
||||
inc_w = stride_w * Delta::kW;
|
||||
inc_h = stride_h * Delta::kH - stride_w * Delta::kW * (Iterations::kW - 1);
|
||||
inc_d = stride_d * Delta::kD - stride_h * Delta::kH * (Iterations::kH - 1) -
|
||||
stride_w * Delta::kW * (Iterations::kW - 1);
|
||||
|
||||
inc_advance = 0;
|
||||
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
// Advance in the H dimension.
|
||||
inc_d = 0;
|
||||
inc_advance = Tile::kH * stride_h;
|
||||
} else if (kAdvance == IteratorAdvance::kW) {
|
||||
// Advance in the W dimension.
|
||||
inc_d = stride_w * Tile::kW - stride_h * Tile::kH;
|
||||
inc_advance = Tile::kW * stride_w;
|
||||
|
||||
} else {
|
||||
// Advance in the D dimension.
|
||||
inc_d = stride_d;
|
||||
inc_advance = Tile::kD * stride_d;
|
||||
}
|
||||
|
||||
inc_advance = 0;
|
||||
inc_advance -= stride_d * Delta::kD * (Iterations::kD - 1) +
|
||||
stride_h * Delta::kH * (Iterations::kH - 1) +
|
||||
stride_w * Delta::kW * (Iterations::kW - 1);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Gotta have this
|
||||
CUTLASS_HOST_DEVICE int initialize() {
|
||||
stride_d = 0;
|
||||
stride_h = 0;
|
||||
stride_w = 1;
|
||||
|
||||
inc_d = inc_h = inc_w = inc_advance = 0;
|
||||
inc_advance = 0;
|
||||
inc_d = inc_h = inc_w = 0;
|
||||
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
/// Is the iterator valid?
|
||||
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
|
||||
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
|
||||
|
||||
//
|
||||
// Static function members
|
||||
//
|
||||
|
||||
/// Initializes a predicate vector
|
||||
template <typename PredicateIterator>
|
||||
CUTLASS_DEVICE static void initialize_predicates(PredicateIterator predicate_it,
|
||||
Coord<3> const &bounds,
|
||||
Coord<3> const &offset = make_Coord(0, 0, 0)) {
|
||||
template <typename PredicateIterator, typename PredicateFunctor>
|
||||
CUTLASS_HOST_DEVICE static void initialize_predicates(PredicateIterator predicate_it,
|
||||
PredicateFunctor const &predicate_func,
|
||||
Coord<3> const &offset) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
bool enable_d = (d * Delta::kD + offset[0] < bounds[0]);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
bool enable_h = (h * Delta::kH + offset[1] < bounds[1]);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
bool enable_w = (w * Tile::kC * Delta::kW + offset[2] < bounds[2]);
|
||||
predicate_it.set(d, h, w, 0, enable_d && enable_h && enable_w);
|
||||
bool enable = predicate_func(make_Coord(d, h, w), offset);
|
||||
predicate_it.set(enable);
|
||||
++predicate_it;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -301,7 +394,7 @@ template <typename Traits_,
|
||||
MemorySpace::Kind MemorySpace = MemorySpace::kGeneric,
|
||||
typename Index_ = int,
|
||||
typename FragmentElement_ = Scalar_,
|
||||
IteratorFragment::Kind IteratorFragment_ = IteratorFragment::kScalar,
|
||||
FragmentElementType::Kind FragmentElementType_ = FragmentElementType::kScalar,
|
||||
typename Skew_ = Shape<0, 0, 0, 0> >
|
||||
struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
Scalar_,
|
||||
@ -309,7 +402,7 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
MemorySpace,
|
||||
Index_,
|
||||
FragmentElement_,
|
||||
IteratorFragment_,
|
||||
FragmentElementType_,
|
||||
Skew_> {
|
||||
/// Base class
|
||||
typedef TileIteratorBase<Traits_,
|
||||
@ -318,7 +411,7 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
MemorySpace,
|
||||
Index_,
|
||||
FragmentElement_,
|
||||
IteratorFragment_,
|
||||
FragmentElementType_,
|
||||
Skew_>
|
||||
Base;
|
||||
|
||||
@ -329,13 +422,13 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
typedef typename Base::Scalar Scalar;
|
||||
|
||||
/// Fragment element
|
||||
typedef typename Base::FragmentElement FragmentElement;
|
||||
typedef FragmentElement_ FragmentElement;
|
||||
|
||||
/// Specifies in which dimension post-increment accesses advance.
|
||||
static IteratorAdvance::Kind const kAdvance = Base::kAdvance;
|
||||
|
||||
/// Specifies type of iterator fragment storage (Salar or WmmaMatrix)
|
||||
static IteratorFragment::Kind const kIteratorFragment = Base::kIteratorFragment;
|
||||
static FragmentElementType::Kind const kFragmentElementType = FragmentElementType_;
|
||||
|
||||
/// Source or destination memory space
|
||||
static MemorySpace::Kind const kMemorySpace = Base::kMemorySpace;
|
||||
@ -364,6 +457,9 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
/// Memory access type
|
||||
typedef typename Base::AccessType AccessType;
|
||||
|
||||
/// The number of scalars accessed per load/store.
|
||||
static int const kAccessSize = Base::kAccessSize;
|
||||
|
||||
/// Fragment definition
|
||||
typedef typename Base::Fragment Fragment;
|
||||
|
||||
@ -388,21 +484,80 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
/// The pointer type
|
||||
typedef Scalar const *Pointer;
|
||||
|
||||
/// Tensor reference for the load iterator
|
||||
typedef TensorRef<Scalar const, 4> TensorRef;
|
||||
|
||||
/// Parameters
|
||||
struct Params : public BaseParams {
|
||||
/// Pointer to memory
|
||||
Scalar const *pointer;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Initialize params to access storage object
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() : pointer(0){ Base::Params::initialize(); }
|
||||
|
||||
/// Initialize params to access storage object
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Scalar const *ptr) : pointer(ptr) { Base::Params::initialize(); }
|
||||
|
||||
/// Constructs with a CompactTensorRef<>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(TensorRef const &ref): pointer(ref.data()) {
|
||||
Base::Params::initialize(ref.stride());
|
||||
}
|
||||
|
||||
/// Initialize params to access storage object
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Scalar const *ptr,
|
||||
long long _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
long long _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
Index _inc_advance)
|
||||
: pointer(ptr) {
|
||||
Base::Params::initialize(
|
||||
_stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
|
||||
}
|
||||
|
||||
/// Initialize params to access storage object
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Scalar const *ptr, long long stride_d, Index stride_h, Index stride_w)
|
||||
: pointer(ptr) {
|
||||
Base::Params::initialize(stride_d, stride_h, stride_w);
|
||||
}
|
||||
|
||||
/// Initializes params to access a raw pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(TensorRef const &ref) {
|
||||
pointer = ref.data();
|
||||
return Base::Params::initialize(ref.stride());
|
||||
}
|
||||
|
||||
/// Initialize params to access storage object
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(SharedStorage const &storage) {
|
||||
pointer = &storage[0];
|
||||
Base::Params::initialize();
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initialize params to access storage object
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar const *ptr) {
|
||||
pointer = ptr;
|
||||
Base::Params::initialize();
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Initializes params to access a raw pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar const *ptr, Index stride_d, Index stride_h, Index stride_w) {
|
||||
int initialize(Scalar const *ptr, long long stride_d, Index stride_h, Index stride_w) {
|
||||
Base::Params::initialize(stride_d, stride_h, stride_w);
|
||||
pointer = ptr;
|
||||
return 0;
|
||||
@ -411,10 +566,10 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
/// Initializes params
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar const *ptr,
|
||||
Index _stride_d,
|
||||
long long _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
Index _inc_d,
|
||||
long long _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
Index _inc_advance) {
|
||||
@ -443,11 +598,13 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
int stage;
|
||||
|
||||
//
|
||||
// Static member functions
|
||||
// Predicate initialization
|
||||
//
|
||||
|
||||
/// Initializes a predicate vector
|
||||
template <typename PredicateIterator>
|
||||
/// Initializes a predicate vector using a RegularTilePredicateFunctor
|
||||
template <
|
||||
/// Predicate iterator
|
||||
typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
|
||||
Coord<3> const &bounds,
|
||||
Coord<3> const &block_offset = make_Coord(0,
|
||||
@ -455,8 +612,23 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
0)) {
|
||||
Base::initialize_predicates(
|
||||
predicate_it,
|
||||
bounds,
|
||||
block_offset + make_Coord(0, thread_offset[1], thread_offset[2] * Tile::kC));
|
||||
RegularTilePredicateFunctor<typename Traits::Delta>(bounds),
|
||||
block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2]));
|
||||
}
|
||||
|
||||
/// Initializes a predicate vector using an arbitrary predicate functor
|
||||
template <
|
||||
/// Predicate iterator
|
||||
typename PredicateIterator,
|
||||
/// Functor computing predicates
|
||||
typename PredicateFunctor>
|
||||
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
|
||||
PredicateFunctor const &functor,
|
||||
Coord<3> const &block_offset) {
|
||||
Base::initialize_predicates(
|
||||
predicate_it,
|
||||
functor,
|
||||
block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2]));
|
||||
}
|
||||
|
||||
//
|
||||
@ -475,41 +647,27 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
: params(_params), stage(0) {
|
||||
thread_offset = thread_offset_func();
|
||||
|
||||
Index block_offset_h = 0;
|
||||
Index block_offset_w = 0;
|
||||
if (kAdvance == IteratorAdvance::kH) {
|
||||
block_offset_h = block_offset[1];
|
||||
block_offset_w = block_offset[2];
|
||||
} else {
|
||||
block_offset_h = block_offset[2];
|
||||
block_offset_w = block_offset[1];
|
||||
}
|
||||
Index pointer_offset = Index((block_offset[0] + thread_offset[0]) * params.stride_d) +
|
||||
Index((block_offset[1] + thread_offset[1]) * params.stride_h) +
|
||||
Index((block_offset[2] + thread_offset[2]) * params.stride_w);
|
||||
|
||||
params.pointer += block_offset[0] * params.stride_d +
|
||||
(block_offset_h + thread_offset[1]) * params.stride_h +
|
||||
(block_offset_w + thread_offset[2] * Tile::kC) / Tile::kC * params.stride_w;
|
||||
params.pointer += pointer_offset;
|
||||
}
|
||||
|
||||
/// Constructs a tile load iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileLoadIterator(Params const &,
|
||||
SharedStorage &shared_storage,
|
||||
Scalar const *ptr,
|
||||
Coord<3> const &block_offset = make_Coord(0, 0, 0),
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: stage(0) {
|
||||
int const offset = thread_offset_func()[2];
|
||||
params.pointer = &shared_storage[offset];
|
||||
}
|
||||
params.pointer = ptr + thread_offset_func()[2];
|
||||
|
||||
/// Returns the current pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
Scalar const *data() const { return params.pointer; }
|
||||
params.stride_d = 0;
|
||||
params.stride_h = 0;
|
||||
params.stride_w = 1;
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void get(AccessType &value, int d, int h, int w, int c) const {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
|
||||
Load<Scalar, Base::kAccessSize, kMemorySpace>::load(value, params.pointer, imm);
|
||||
params.inc_d = params.inc_h = params.inc_w = params.inc_advance = 0;
|
||||
}
|
||||
|
||||
/// Increment in the D dimension
|
||||
@ -524,8 +682,21 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
/// Increment in the next dimension
|
||||
CUTLASS_HOST_DEVICE void inc_advance() { params.pointer += params.inc_advance; }
|
||||
|
||||
/// Loads a single fragment element from memory
|
||||
CUTLASS_HOST_DEVICE void load_element(AccessType &value, int d, int h, int w, int c) const {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
|
||||
Load<Scalar,
|
||||
kAccessSize,
|
||||
kMemorySpace,
|
||||
kFragmentElementType,
|
||||
FragmentElement,
|
||||
Tile::kW,
|
||||
sizeof(FragmentElement) * kAccessSize>::load(value, params.pointer, offset);
|
||||
}
|
||||
|
||||
/// Increment the stage.
|
||||
CUTLASS_DEVICE void inc_stage() {
|
||||
CUTLASS_HOST_DEVICE void inc_stage() {
|
||||
if (Tile::kD > 1) {
|
||||
int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
|
||||
if (stage == Tile::kD - 1) {
|
||||
@ -538,7 +709,27 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
/// Adds a vector offset to the iterator
|
||||
CUTLASS_HOST_DEVICE TileLoadIterator & operator+=(Coord<3> const &offset) {
|
||||
long long _offset = offset.template dot<long long>(
|
||||
make_Coord(params.stride_d, params.stride_h, params.stride_w)
|
||||
);
|
||||
|
||||
params.pointer += _offset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Adds a raw offset to the pointer
|
||||
CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset) { params.pointer += offset; }
|
||||
|
||||
CUTLASS_HOST_DEVICE Index stride_advance(void) {
|
||||
Index stride = params.stride_h;
|
||||
if (kAdvance == IteratorAdvance::kW) {
|
||||
stride = params.stride_w;
|
||||
}
|
||||
return stride;
|
||||
}
|
||||
|
||||
/// Loads a fragment and advances the iterator to the next tile.
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
|
||||
@ -547,11 +738,12 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
|
||||
if (*pred_it) {
|
||||
Load<typename Fragment::Element, Tile::kC, kMemorySpace>::load(
|
||||
reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)), data(), 0);
|
||||
for (int c = 0; c < Iterations::kC; ++c) {
|
||||
if (*pred_it) {
|
||||
load_element(
|
||||
reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, c)), d, h, w, c);
|
||||
}
|
||||
}
|
||||
|
||||
if (w < Iterations::kW - 1) {
|
||||
inc_w();
|
||||
}
|
||||
@ -587,6 +779,19 @@ struct TileLoadIterator : public TileIteratorBase<Traits_,
|
||||
typename PredicateVector::TrivialIterator pred_it;
|
||||
load(fragment, pred_it);
|
||||
}
|
||||
|
||||
/// Loads a fragment without advancing the iterator..
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d) {
|
||||
FragmentIterator frag_iterator(fragment);
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
for (int c = 0; c < Iterations::kC; ++c) {
|
||||
load_element(reinterpret_cast<AccessType &>(frag_iterator.at(0, h, w, c)), d, h, w, c);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -626,7 +831,7 @@ template <typename Traits_,
|
||||
MemorySpace::Kind MemorySpace = MemorySpace::kGeneric,
|
||||
typename Index_ = int,
|
||||
typename FragmentElement_ = Scalar_,
|
||||
IteratorFragment::Kind IteratorFragment_ = IteratorFragment::kScalar,
|
||||
FragmentElementType::Kind FragmentElementType_ = FragmentElementType::kScalar,
|
||||
typename Skew_ = Shape<0, 0, 0, 0> >
|
||||
struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
Scalar_,
|
||||
@ -634,7 +839,7 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
MemorySpace,
|
||||
Index_,
|
||||
FragmentElement_,
|
||||
IteratorFragment_,
|
||||
FragmentElementType_,
|
||||
Skew_> {
|
||||
/// Base class
|
||||
typedef TileIteratorBase<Traits_,
|
||||
@ -643,7 +848,7 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
MemorySpace,
|
||||
Index_,
|
||||
FragmentElement_,
|
||||
IteratorFragment_,
|
||||
FragmentElementType_,
|
||||
Skew_>
|
||||
Base;
|
||||
|
||||
@ -660,11 +865,14 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
static IteratorAdvance::Kind const kAdvance = Base::kAdvance;
|
||||
|
||||
/// Specifies type of iterator fragment storage (Salar or WmmaMatrix)
|
||||
static IteratorFragment::Kind const kIteratorFragment = Base::kIteratorFragment;
|
||||
static FragmentElementType::Kind const kFragmentElementType = Base::kFragmentElementType;
|
||||
|
||||
/// Source or destination memory space
|
||||
static MemorySpace::Kind const kMemorySpace = Base::kMemorySpace;
|
||||
|
||||
/// The number of scalars accessed per load/store.
|
||||
static int const kAccessSize = Base::kAccessSize;
|
||||
|
||||
/// Index type
|
||||
typedef typename Base::Index Index;
|
||||
|
||||
@ -707,21 +915,71 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
/// IteratorBase parameters
|
||||
typedef typename Base::Params BaseParams;
|
||||
|
||||
/// Pointer to underlying type
|
||||
typedef Scalar *Pointer;
|
||||
|
||||
/// Tensor reference for the store iterator
|
||||
typedef TensorRef<Scalar, 4> TensorRef;
|
||||
|
||||
/// Parameters
|
||||
struct Params : public BaseParams {
|
||||
/// Pointer to memory
|
||||
Scalar *pointer;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() : pointer(0) {}
|
||||
|
||||
// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Scalar *ptr) : pointer(ptr) { Base::Params::initialize(); }
|
||||
|
||||
/// Constructs with a CompactTensorRef<>
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(TensorRef const &ref): pointer(ref.data()) {
|
||||
Base::Params::initialize(ref.stride());
|
||||
}
|
||||
|
||||
// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Scalar *ptr, long long stride_d, Index stride_h, Index stride_w) {
|
||||
initialize(ptr, stride_d, stride_h, stride_w);
|
||||
}
|
||||
|
||||
// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Scalar *ptr,
|
||||
long long _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
long long _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
Index _inc_advance) {
|
||||
initialize(ptr, _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance);
|
||||
}
|
||||
|
||||
/// Initialize params to access storage object
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(SharedStorage &storage) {
|
||||
pointer = &storage[0];
|
||||
return 0;
|
||||
return Base::Params::initialize();
|
||||
}
|
||||
|
||||
/// Initialize params to access storage object
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar *ptr) {
|
||||
pointer = ptr;
|
||||
return Base::Params::initialize();
|
||||
}
|
||||
|
||||
/// Initializes params to access a raw pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar *ptr, Index stride_d, Index stride_h, Index stride_w) {
|
||||
int initialize(Scalar *ptr, long long stride_d, Index stride_h, Index stride_w) {
|
||||
Base::Params::initialize(stride_d, stride_h, stride_w);
|
||||
pointer = ptr;
|
||||
return 0;
|
||||
@ -730,10 +988,10 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
/// Initializes params
|
||||
CUTLASS_HOST_DEVICE
|
||||
int initialize(Scalar *ptr,
|
||||
Index _stride_d,
|
||||
long long _stride_d,
|
||||
Index _stride_h,
|
||||
Index _stride_w,
|
||||
Index _inc_d,
|
||||
long long _inc_d,
|
||||
Index _inc_h,
|
||||
Index _inc_w,
|
||||
Index _inc_advance) {
|
||||
@ -762,11 +1020,13 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
int stage;
|
||||
|
||||
//
|
||||
// Static member functions
|
||||
// Predicate initialization
|
||||
//
|
||||
|
||||
/// Initializes a predicate vector
|
||||
template <typename PredicateIterator>
|
||||
/// Initializes a predicate vector using a RegularTilePredicateFunctor
|
||||
template <
|
||||
/// Predicate iterator
|
||||
typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
|
||||
Coord<3> const &bounds,
|
||||
Coord<3> const &block_offset = make_Coord(0,
|
||||
@ -774,8 +1034,23 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
0)) {
|
||||
Base::initialize_predicates(
|
||||
predicate_it,
|
||||
bounds,
|
||||
block_offset + make_Coord(0, thread_offset[1], thread_offset[2] * Tile::kC));
|
||||
RegularTilePredicateFunctor<typename Traits::Delta>(bounds),
|
||||
block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2]));
|
||||
}
|
||||
|
||||
/// Initializes a predicate vector using an arbitrary predicate functor
|
||||
template <
|
||||
/// Predicate iterator
|
||||
typename PredicateIterator,
|
||||
/// Functor computing predicates
|
||||
typename PredicateFunctor>
|
||||
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
|
||||
PredicateFunctor const &functor,
|
||||
Coord<3> const &block_offset) {
|
||||
Base::initialize_predicates(
|
||||
predicate_it,
|
||||
functor,
|
||||
block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2]));
|
||||
}
|
||||
|
||||
//
|
||||
@ -794,25 +1069,22 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
: params(_params), stage(0) {
|
||||
thread_offset = thread_offset_func();
|
||||
|
||||
params.pointer += block_offset[0] * params.stride_d +
|
||||
params.pointer += (block_offset[0] + thread_offset[0]) * params.stride_d +
|
||||
(block_offset[1] + thread_offset[1]) * params.stride_h +
|
||||
(block_offset[2] + thread_offset[2] * Tile::kC) / Tile::kC * params.stride_w;
|
||||
(block_offset[2] + thread_offset[2]) * params.stride_w;
|
||||
}
|
||||
|
||||
/// Constructs a tile store iterator
|
||||
CUTLASS_HOST_DEVICE
|
||||
TileStoreIterator(Params const &,
|
||||
SharedStorage &shared_storage,
|
||||
Coord<3> const &block_offset = make_Coord(0, 0, 0),
|
||||
ThreadOffset thread_offset_func = ThreadOffset())
|
||||
TileStoreIterator(Params const &, Scalar *ptr, ThreadOffset thread_offset_func = ThreadOffset())
|
||||
: stage(0) {
|
||||
int const offset = thread_offset_func()[2];
|
||||
params.pointer = &shared_storage[offset];
|
||||
}
|
||||
params.pointer = ptr + thread_offset_func()[2];
|
||||
params.stride_d = 0;
|
||||
params.stride_h = 0;
|
||||
params.stride_w = 1;
|
||||
|
||||
/// Returns the current pointer
|
||||
CUTLASS_HOST_DEVICE
|
||||
Scalar *data() const { return params.pointer; }
|
||||
params.inc_d = params.inc_h = params.inc_w = params.inc_advance = 0;
|
||||
}
|
||||
|
||||
/// Increment in the D dimension
|
||||
CUTLASS_HOST_DEVICE void inc_d() { params.pointer += params.inc_d; }
|
||||
@ -827,7 +1099,7 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
CUTLASS_HOST_DEVICE void inc_advance() {}
|
||||
|
||||
/// Increment the stage.
|
||||
CUTLASS_DEVICE void inc_stage() {
|
||||
CUTLASS_HOST_DEVICE void inc_stage() {
|
||||
if (Tile::kD > 1) {
|
||||
int const kStageSize = Tile::kH * Tile::kW * Tile::kC;
|
||||
if (stage == Tile::kD - 1) {
|
||||
@ -840,25 +1112,43 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
}
|
||||
}
|
||||
|
||||
/// The accessor.
|
||||
CUTLASS_DEVICE void set(AccessType const &value, int d, int h, int w, int c) {
|
||||
int const imm =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
|
||||
Store<Scalar, Base::kAccessSize, kMemorySpace>::store(value, params.pointer, imm);
|
||||
/// Adds a vector offset to the iterator
|
||||
CUTLASS_HOST_DEVICE TileStoreIterator & operator+=(Coord<3> const &offset) {
|
||||
params.pointer += offset.template dot<long long>(
|
||||
make_Coord(params.stride_d, params.stride_h, params.stride_w)
|
||||
);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Adds a raw offset to the pointer
|
||||
CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset) { params.pointer += offset; }
|
||||
|
||||
/// Stores a single fragment element into memory.
|
||||
CUTLASS_HOST_DEVICE void store_element(AccessType const &value, int d, int h, int w, int c) {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
|
||||
Store<Scalar,
|
||||
kAccessSize,
|
||||
kMemorySpace,
|
||||
kFragmentElementType,
|
||||
FragmentElement,
|
||||
Tile::kW,
|
||||
sizeof(FragmentElement) * kAccessSize>::store(value, params.pointer, offset);
|
||||
}
|
||||
|
||||
public:
|
||||
/// Stores a fragment and advances to the next tile.
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment, PredicateIterator pred_it) {
|
||||
FragmentIterator frag_iterator(fragment);
|
||||
CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment, PredicateIterator pred_it) {
|
||||
FragmentConstIterator frag_iterator(fragment);
|
||||
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
|
||||
if (*pred_it) {
|
||||
Store<typename Fragment::Element, Tile::kC, kMemorySpace>::store(
|
||||
reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, 0)), data(), 0);
|
||||
for (int c = 0; c < Iterations::kC; ++c) {
|
||||
if (*pred_it) {
|
||||
store_element(
|
||||
reinterpret_cast<AccessType const &>(frag_iterator.at(d, h, w, c)), d, h, w, c);
|
||||
}
|
||||
}
|
||||
if (w < Iterations::kW - 1) {
|
||||
inc_w();
|
||||
@ -877,23 +1167,103 @@ struct TileStoreIterator : public TileIteratorBase<Traits_,
|
||||
|
||||
/// Stores a fragment and advances to the next tile.
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment) {
|
||||
CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment) {
|
||||
typename PredicateVector::TrivialIterator pred_it;
|
||||
store_post_increment(fragment, pred_it);
|
||||
}
|
||||
|
||||
/// Stores a fragment without advancing the iterator.
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void store(Fragment &fragment, PredicateIterator pred_it) const {
|
||||
CUTLASS_HOST_DEVICE void store(Fragment const &fragment, PredicateIterator pred_it) const {
|
||||
TileStoreIterator _store_it(*this);
|
||||
_store_it.store_post_increment(fragment, pred_it);
|
||||
}
|
||||
|
||||
/// Stores a fragment without advancing the iterator.
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void store(Fragment &fragment) const {
|
||||
CUTLASS_HOST_DEVICE void store(Fragment const &fragment) const {
|
||||
typename PredicateVector::TrivialIterator pred_it;
|
||||
store(fragment, pred_it);
|
||||
}
|
||||
|
||||
/// Loads a single fragment element from memory
|
||||
CUTLASS_HOST_DEVICE void load_element(AccessType &value, int d, int h, int w, int c) const {
|
||||
int const offset =
|
||||
ComputeOffsetFromStrides<typename Base::ImmediateOffsetStrides>::get(d, h, w, c);
|
||||
|
||||
Load<Scalar,
|
||||
kAccessSize,
|
||||
kMemorySpace,
|
||||
kFragmentElementType,
|
||||
FragmentElement,
|
||||
Tile::kW,
|
||||
sizeof(FragmentElement) * kAccessSize>::load(value, params.pointer, offset);
|
||||
}
|
||||
|
||||
/// Loads a fragment and advances the iterator to the next tile.
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
|
||||
FragmentIterator frag_iterator(fragment);
|
||||
|
||||
for (int d = 0; d < Iterations::kD; ++d) {
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Iterations::kW; ++w, ++pred_it) {
|
||||
for (int c = 0; c < Iterations::kC; ++c) {
|
||||
if (*pred_it) {
|
||||
load_element(
|
||||
reinterpret_cast<AccessType &>(frag_iterator.at(d, h, w, c)), d, h, w, c);
|
||||
}
|
||||
}
|
||||
if (w < Iterations::kW - 1) {
|
||||
inc_w();
|
||||
}
|
||||
}
|
||||
if (h < Iterations::kH - 1) {
|
||||
inc_h();
|
||||
}
|
||||
}
|
||||
if (d < Iterations::kD - 1) {
|
||||
inc_d();
|
||||
}
|
||||
}
|
||||
inc_advance();
|
||||
}
|
||||
|
||||
/// Loads a fragment and advances the iterator to the next tile.
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment) {
|
||||
typename PredicateVector::TrivialIterator pred_it;
|
||||
load_post_increment(fragment, pred_it);
|
||||
}
|
||||
|
||||
/// Loads a fragment without advancing the iterator..
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const {
|
||||
TileStoreIterator _load_it(*this);
|
||||
_load_it.load_post_increment(fragment, pred_it);
|
||||
}
|
||||
|
||||
/// Loads a fragment without advancing the iterator..
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void load(Fragment &fragment) const {
|
||||
typename PredicateVector::TrivialIterator pred_it;
|
||||
load(fragment, pred_it);
|
||||
}
|
||||
|
||||
/// Loads a fragment without advancing the iterator..
|
||||
template <typename Fragment>
|
||||
CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d) {
|
||||
FragmentIterator frag_iterator(fragment);
|
||||
for (int h = 0; h < Iterations::kH; ++h) {
|
||||
for (int w = 0; w < Iterations::kW; ++w) {
|
||||
for (int c = 0; c < Iterations::kC; ++c) {
|
||||
load_element(reinterpret_cast<AccessType &>(frag_iterator.at(0, h, w, c)), d, h, w, c);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
378
cutlass/tile_stream.h
Normal file
@ -0,0 +1,378 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Implements the tile stream concept, composing an iterator with a transformation. Offers
|
||||
split-phase semantics, separating the initiation of an asynchronous memory operation with a
|
||||
fence forcing it to complete.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/convert.h"
|
||||
#include "cutlass/tile_iterator.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Generic stream for loading and transforming fragments
|
||||
template <typename Iterator_, typename Transformer_ = Copy<typename Iterator_::Fragment> >
|
||||
struct TileLoadStream {
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// TileLoadIterator
|
||||
typedef Iterator_ Iterator;
|
||||
|
||||
/// Transformer
|
||||
typedef Transformer_ Transformer;
|
||||
|
||||
/// Fragment fetched from source memory
|
||||
typedef typename Iterator::Fragment Fragment;
|
||||
|
||||
/// Output fragment from transformer
|
||||
typedef typename Transformer::OutputFragment TransformedFragment;
|
||||
|
||||
/// Tensor reference expected by the stream
|
||||
typedef typename Iterator::TensorRef TensorRef;
|
||||
|
||||
/// Empty predicate vector struct
|
||||
struct PredicateVector {};
|
||||
|
||||
/// Index type
|
||||
typedef typename Iterator::Index Index;
|
||||
|
||||
/// Parameters object used to construct generic load stream
|
||||
struct Params {
|
||||
/// Parameters to the iterator
|
||||
typename Iterator::Params iterator;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
/// Constructor with iterator params
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(typename Iterator::Params const &_iterator) : iterator(_iterator) {}
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to load tiles
|
||||
Iterator iterator;
|
||||
|
||||
/// Fragment loaded via iterator
|
||||
Fragment fetched_fragment;
|
||||
|
||||
/// Transformation applied to fragments
|
||||
Transformer transformer;
|
||||
|
||||
/// Transformed fragment from transformer
|
||||
TransformedFragment transformed_fragment;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_DEVICE
|
||||
TileLoadStream(Params const &_params, TensorRef const &_ref)
|
||||
: iterator(_params.iterator, _ref) {}
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_DEVICE
|
||||
TileLoadStream(Params const &_params,
|
||||
Coord<3> const &threadblock_offset = make_Coord(0, 0, 0)
|
||||
): iterator(_params.iterator, threadblock_offset) { }
|
||||
|
||||
/// Loads a tile and increments the iterator
|
||||
CUTLASS_DEVICE
|
||||
void copy() { iterator.load_post_increment(fetched_fragment); }
|
||||
|
||||
/// Commits the fetched fragment and applies a transformation
|
||||
CUTLASS_DEVICE
|
||||
void commit() { transformer.transform(fetched_fragment, transformed_fragment); }
|
||||
|
||||
/// Accesses the loaded, transformed fragment
|
||||
CUTLASS_DEVICE
|
||||
Fragment &intermediate_fragment() { return fetched_fragment; }
|
||||
|
||||
/// Accesses the loaded, transformed fragment
|
||||
CUTLASS_DEVICE
|
||||
TransformedFragment &fragment() { return transformed_fragment; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Generic stream for transforming and storing fragments
|
||||
template <typename Iterator_, typename Transformer_ = Copy<typename Iterator_::Fragment> >
|
||||
struct TileStoreStream {
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// TileLoadIterator
|
||||
typedef Iterator_ Iterator;
|
||||
|
||||
/// Transformer
|
||||
typedef Transformer_ Transformer;
|
||||
|
||||
/// Source fragment
|
||||
typedef typename Transformer::InputFragment Fragment;
|
||||
|
||||
/// Transformed fragment, compatible with Iterator::Fragment
|
||||
typedef typename Transformer::OutputFragment TransformedFragment;
|
||||
|
||||
/// Tensor reference expected by the underlying iterator
|
||||
typedef typename Iterator::TensorRef TensorRef;
|
||||
|
||||
/// Empty predicate vector struct
|
||||
struct PredicateVector {};
|
||||
|
||||
/// Index type
|
||||
typedef typename Iterator::Index Index;
|
||||
|
||||
/// Parameters used to construct the stream
|
||||
struct Params {
|
||||
/// Parameters to the iterator
|
||||
typename Iterator::Params iterator;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
/// Constructor with iterator params
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(typename Iterator::Params const &_iterator) : iterator(_iterator) {}
|
||||
};
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to store tiles
|
||||
Iterator iterator;
|
||||
|
||||
/// Transformation applied to inputs
|
||||
Transformer transformer;
|
||||
|
||||
/// Source fragment
|
||||
Fragment source_fragment;
|
||||
|
||||
/// Transformed fragment from transformer
|
||||
TransformedFragment transformed_fragment;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_DEVICE
|
||||
TileStoreStream(Params const &_params, TensorRef const &_ref)
|
||||
: iterator(_params.iterator, _ref) {}
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_DEVICE
|
||||
TileStoreStream(Params const &_params,
|
||||
Coord<3> const &threadblock_offset = make_Coord(0, 0, 0)
|
||||
): iterator(_params.iterator, threadblock_offset) { }
|
||||
|
||||
/// Stores a fragment and increments the iterator
|
||||
CUTLASS_DEVICE
|
||||
void copy() {
|
||||
|
||||
transformer.transform(source_fragment, transformed_fragment);
|
||||
iterator.store_post_increment(transformed_fragment);
|
||||
}
|
||||
|
||||
/// Stores a fragment and increments the iterator
|
||||
CUTLASS_DEVICE
|
||||
void copy(Fragment const &frag) {
|
||||
source_fragment = frag;
|
||||
copy();
|
||||
}
|
||||
|
||||
/// Commits the store operation
|
||||
CUTLASS_DEVICE
|
||||
void commit() {}
|
||||
|
||||
/// Accesses the transformed fragment
|
||||
CUTLASS_DEVICE
|
||||
Fragment &fragment() { return source_fragment; }
|
||||
|
||||
/// Accesses the fragment after trasnforming
|
||||
CUTLASS_DEVICE
|
||||
TransformedFragment &intermediate_fragment() { return transformed_fragment; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Generic stream for loading and transforming fragments
|
||||
template <typename Iterator_,
|
||||
typename PredicateFunctor_ =
|
||||
RegularTilePredicateFunctor<typename Iterator_::Traits::Delta>,
|
||||
typename Transformer_ = Copy<typename Iterator_::Fragment> >
|
||||
struct PredicatedTileLoadStream : public TileLoadStream<Iterator_, Transformer_> {
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
typedef TileLoadStream<Iterator_, Transformer_> Base;
|
||||
|
||||
/// TileLoadIterator
|
||||
typedef Iterator_ Iterator;
|
||||
|
||||
/// Predicate functor
|
||||
typedef PredicateFunctor_ PredicateFunctor;
|
||||
|
||||
/// Transformer
|
||||
typedef Transformer_ Transformer;
|
||||
|
||||
/// Fragment fetched from source memory
|
||||
typedef typename Base::Fragment Fragment;
|
||||
|
||||
/// Output fragment from transformer
|
||||
typedef typename Base::TransformedFragment TransformedFragment;
|
||||
|
||||
/// Parameters object used to construct generic load stream
|
||||
typedef typename Base::Params Params;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Predicates
|
||||
typename Iterator::PredicateVector predicates;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_DEVICE
|
||||
PredicatedTileLoadStream(Params const &_params,
|
||||
Coord<3> const &bounds,
|
||||
Coord<3> const &threadblock_offset = make_Coord(0, 0, 0))
|
||||
: Base(_params, threadblock_offset) {
|
||||
this->iterator.initialize_predicates(
|
||||
predicates.begin(), PredicateFunctor(bounds), threadblock_offset);
|
||||
}
|
||||
|
||||
/// Loads a tile and increments the iterator
|
||||
CUTLASS_DEVICE
|
||||
void copy() { this->iterator.load_post_increment(this->fetched_fragment, predicates.begin()); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Generic stream for transforming and storing fragments
|
||||
template <typename Iterator_,
|
||||
typename PredicateFunctor_ =
|
||||
RegularTilePredicateFunctor<typename Iterator_::Traits::Delta>,
|
||||
typename Transformer_ = Copy<typename Iterator_::Fragment> >
|
||||
struct PredicatedTileStoreStream : public TileStoreStream<Iterator_, Transformer_> {
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
typedef TileStoreStream<Iterator_, Transformer_> Base;
|
||||
|
||||
/// TileLoadIterator
|
||||
typedef Iterator_ Iterator;
|
||||
|
||||
/// Predicate functor
|
||||
typedef PredicateFunctor_ PredicateFunctor;
|
||||
|
||||
/// Transformer
|
||||
typedef Transformer_ Transformer;
|
||||
|
||||
/// Fragment fetched from source memory
|
||||
typedef typename Base::Fragment Fragment;
|
||||
|
||||
/// Output fragment from transformer
|
||||
typedef typename Base::TransformedFragment TransformedFragment;
|
||||
|
||||
/// Parameters object used to construct generic load stream
|
||||
typedef typename Base::Params Params;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Predicates
|
||||
typename Iterator::PredicateVector predicates;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_DEVICE
|
||||
PredicatedTileStoreStream(Params const &_params,
|
||||
Coord<3> const &bounds,
|
||||
Coord<3> const &threadblock_offset = make_Coord(0, 0, 0))
|
||||
: Base(_params, threadblock_offset) {
|
||||
this->iterator.initialize_predicates(
|
||||
predicates.begin(), PredicateFunctor(bounds), threadblock_offset);
|
||||
}
|
||||
|
||||
/// Stores the fragment and increments the iterator
|
||||
CUTLASS_DEVICE
|
||||
void copy() {
|
||||
this->transformer.transform(this->source_fragment, this->transformed_fragment);
|
||||
this->iterator.store_post_increment(this->transformed_fragment, predicates.begin());
|
||||
}
|
||||
|
||||
/// Stores the fragment and increments the iterator
|
||||
CUTLASS_DEVICE
|
||||
void copy(Fragment const &frag) {
|
||||
this->source_fragment = frag;
|
||||
copy();
|
||||
}
|
||||
|
||||
/// Commits the store operation
|
||||
CUTLASS_DEVICE
|
||||
void commit() {}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
// clang-format on
|
||||
@ -28,7 +28,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/tile_iterator.h>
|
||||
#include "cutlass/tile_iterator.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@ -204,6 +204,9 @@ struct TileTraitsStandard {
|
||||
/// Number of participating warps
|
||||
static int const kWarpCount = kThreads / kWarpSize;
|
||||
|
||||
/// By default, do not do scalar loads
|
||||
static int const kAccessSize = 1;
|
||||
|
||||
// Static assertions
|
||||
static_assert(!(ShapeCount<Tile>::kDhw % kThreads),
|
||||
"Tiling undefined if elements not divisible by threads.");
|
||||
@ -223,8 +226,7 @@ struct TileTraitsStandard {
|
||||
typedef typename Traits::Delta Delta;
|
||||
|
||||
/// Delta between each thread's access
|
||||
/// TODO MTA this is wrong for sure, but Delta is used for stride computation at the moment
|
||||
typedef Delta ImmediateOffsetStrides;
|
||||
typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides;
|
||||
|
||||
/// Number of accesses
|
||||
typedef typename Traits::Iterations Iterations;
|
||||
|
||||
457
cutlass/util/complex.h
Normal file
@ -0,0 +1,457 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cuComplex.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include <iosfwd>
|
||||
|
||||
namespace cutlass {
|
||||
namespace platform {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//
|
||||
// Accessors for CUDA complex types
|
||||
//
|
||||
|
||||
/// Returns the real part of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
float const &real(cuFloatComplex const &z) { return z.x; }
|
||||
|
||||
/// Returns the real part of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
float &real(cuFloatComplex &z) { return z.x; }
|
||||
|
||||
/// Returns the real part of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
double const &real(cuDoubleComplex const &z) { return z.x; }
|
||||
|
||||
/// Returns the real part of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
double &real(cuDoubleComplex &z) { return z.x; }
|
||||
|
||||
/// Returns the imaginary part of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
float const &imag(cuFloatComplex const &z) { return z.y; }
|
||||
|
||||
/// Returns the imaginary part of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
float &imag(cuFloatComplex &z) { return z.y; }
|
||||
|
||||
/// Returns the imaginary part of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
double const &imag(cuDoubleComplex const &z) { return z.y; }
|
||||
|
||||
/// Returns the imaginary part of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
double &imag(cuDoubleComplex &z) { return z.y; }
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Class for representing and manipulating complex numbers with conversions from built-in CUDA
|
||||
/// complex types.
|
||||
template <typename T>
|
||||
class complex {
|
||||
public:
|
||||
/// Type alias for scalar type
|
||||
typedef T value_type;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Real part
|
||||
T _real;
|
||||
|
||||
/// Imaginary part
|
||||
T _imag;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
complex(T r = T(0), T i = T(0)) : _real(r), _imag(i) {}
|
||||
|
||||
/// Conversion from cuFloatComplex
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
complex(cuFloatComplex const &z) : _real(platform::real(z)), _imag(platform::imag(z)) {}
|
||||
|
||||
/// Conversion from cuDoubleComplex
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
complex(cuDoubleComplex const &z) : _real(platform::real(z)), _imag(platform::imag(z)) {}
|
||||
|
||||
/// Accesses the real part of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
T const &real() const { return _real; }
|
||||
|
||||
/// Accesses the real part of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
T &real() { return _real; }
|
||||
|
||||
/// Accesses the imaginary part of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
T const &imag() const { return _imag; }
|
||||
|
||||
/// Accesses the imaginary part of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
T &imag() { return _imag; }
|
||||
|
||||
/// Converts to cuFloatComplex
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
operator cuFloatComplex() const { return make_cuFloatComplex(real(), imag()); }
|
||||
|
||||
/// Converts to cuDoubleComplex
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
CUTLASS_HOST_DEVICE
|
||||
operator cuDoubleComplex() const { return make_cuDoubleComplex(real(), imag()); }
|
||||
};
|
||||
|
||||
//
|
||||
// Accessors for complex template
|
||||
//
|
||||
|
||||
/// Returns the real part of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE T const &real(complex<T> const &z) {
|
||||
return z.real();
|
||||
}
|
||||
|
||||
/// Returns the real part of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE T &real(complex<T> &z) {
|
||||
return z.real();
|
||||
}
|
||||
|
||||
/// Returns the imaginary part of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE T const &imag(complex<T> const &z) {
|
||||
return z.imag();
|
||||
}
|
||||
|
||||
/// Returns the imaginary part of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE T &imag(complex<T> &z) {
|
||||
return z.imag();
|
||||
}
|
||||
|
||||
//
|
||||
// Output operators
|
||||
//
|
||||
|
||||
template <typename T>
|
||||
std::ostream &operator<<(std::ostream &out, complex<T> const &z) {
|
||||
T _r = real(z);
|
||||
T _i = imag(z);
|
||||
return out << _r << "+i" << _i;
|
||||
}
|
||||
|
||||
//
|
||||
// Non-member operators defined for complex types
|
||||
//
|
||||
|
||||
/// Equality operator
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE bool operator==(complex<T> const &lhs, complex<T> const &rhs) {
|
||||
return real(lhs) == (rhs) && imag(lhs) == imag(rhs);
|
||||
}
|
||||
|
||||
/// Inequality operator
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE bool operator!=(complex<T> const &lhs, complex<T> const &rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
/// Addition
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> operator+(complex<T> const &lhs, complex<T> const &rhs) {
|
||||
return complex<T>(real(lhs) + real(rhs), imag(lhs) + imag(rhs));
|
||||
}
|
||||
|
||||
/// Subtraction
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> operator-(complex<T> const &lhs, complex<T> const &rhs) {
|
||||
return complex<T>(real(lhs) - real(rhs), imag(lhs) - imag(rhs));
|
||||
}
|
||||
|
||||
/// Multiplication
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> operator*(complex<T> const &lhs, complex<T> const &rhs) {
|
||||
return complex<T>(real(lhs) * real(rhs) - imag(lhs) * imag(rhs),
|
||||
real(lhs) * imag(rhs) + imag(lhs) * real(rhs));
|
||||
}
|
||||
|
||||
/// Scalar Multiplication
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> operator*(complex<T> const &lhs, T const &s) {
|
||||
return complex<T>(real(lhs) * s, imag(lhs) * s);
|
||||
}
|
||||
|
||||
/// Scalar Multiplication
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> operator*(T const &s, complex<T> const &rhs) {
|
||||
return complex<T>(s * real(rhs), s * imag(rhs));
|
||||
}
|
||||
|
||||
/// Division
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> operator/(complex<T> const &lhs, complex<T> const &rhs) {
|
||||
T d = (real(rhs) * (rhs) + imag(rhs) * imag(rhs));
|
||||
|
||||
return complex<T>((real(lhs) * (rhs) + imag(lhs) * imag(rhs)) / d,
|
||||
(imag(lhs) * (rhs)-real(lhs) * imag(rhs)) / d);
|
||||
}
|
||||
|
||||
/// Scalar Division
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> operator/(complex<T> const &lhs, T const &s) {
|
||||
return complex<T>(real(lhs) / s, imag(lhs) / s);
|
||||
}
|
||||
|
||||
/// Scalar divided by complex
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> operator/(T const &s, complex<T> const &rhs) {
|
||||
T d = (real(rhs) * (rhs) + imag(rhs) * imag(rhs));
|
||||
|
||||
return complex<T>((s * (rhs)) / d, -(s * imag(rhs)) / d);
|
||||
}
|
||||
|
||||
/// Addition
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> &operator+=(complex<T> &lhs, complex<T> const &rhs) {
|
||||
lhs = (lhs + rhs);
|
||||
return lhs;
|
||||
}
|
||||
|
||||
/// Subtraction
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> &operator-=(complex<T> &lhs, complex<T> const &rhs) {
|
||||
lhs = (lhs - rhs);
|
||||
return lhs;
|
||||
}
|
||||
|
||||
/// Multiplication
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> &operator*=(complex<T> &lhs, complex<T> const &rhs) {
|
||||
lhs = (lhs * rhs);
|
||||
return lhs;
|
||||
}
|
||||
|
||||
/// Scalar multiplication
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> &operator*=(complex<T> &lhs, T s) {
|
||||
lhs = (lhs * s);
|
||||
return lhs;
|
||||
}
|
||||
|
||||
/// Division
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> &operator/=(complex<T> &lhs, complex<T> const &rhs) {
|
||||
lhs = (lhs / rhs);
|
||||
return lhs;
|
||||
}
|
||||
|
||||
//
|
||||
// Non-member functions defined for complex numbers
|
||||
//
|
||||
|
||||
/// Returns the magnitude of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE T abs(complex<T> const &z) {
|
||||
return sqrt(norm(z));
|
||||
}
|
||||
|
||||
/// Returns the magnitude of the complex number
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE T arg(complex<T> const &z) {
|
||||
return atan2(imag(z), real(z));
|
||||
}
|
||||
|
||||
/// Returns the squared magnitude
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE T norm(complex<T> const &z) {
|
||||
return real(z) * real(z) + imag(z) * imag(z);
|
||||
}
|
||||
|
||||
/// Returns the complex conjugate
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> conj(complex<T> const &z) {
|
||||
return complex<T>(real(z), -imag(z));
|
||||
}
|
||||
|
||||
/// Projects the complex number z onto the Riemann sphere
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> proj(complex<T> const &z) {
|
||||
T d = real(z) * real(z) + imag(z) * imag(z) + T(1);
|
||||
return complex<T>((T(2) * real(z)) / d, (T(2) * imag(z)) / d);
|
||||
}
|
||||
|
||||
/// Returns a complex number with magnitude r and phase theta
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> polar(T const &r, T const &theta = T()) {
|
||||
return complex<T>(r * cos(theta), r * sin(theta));
|
||||
}
|
||||
|
||||
/// Computes the complex exponential of z.
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> exp(complex<T> const &z) {
|
||||
return complex<T>(real(z) * cos(imag(z)), real(z) * sin(imag(z)));
|
||||
}
|
||||
|
||||
/// Computes the complex exponential of z.
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> log(complex<T> const &z) {
|
||||
return complex<T>(log(abs(z)), arg(z));
|
||||
}
|
||||
|
||||
/// Computes the complex exponential of z.
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> log10(complex<T> const &z) {
|
||||
return log(z) / T(log(T(10)));
|
||||
}
|
||||
|
||||
/// Computes the square root of complex number z
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> sqrt(complex<T> const &z) {
|
||||
return sqrt(T(2)) / T(2) *
|
||||
complex<T>(sqrt(sqrt(norm(z)) + real(z)),
|
||||
(imag(z) < 0 ? T(-1) : T(1)) * sqrt(sqrt(norm(z)) - real(z)));
|
||||
}
|
||||
|
||||
/// Computes the cosine of complex z.
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> cos(complex<T> const &z) {
|
||||
return (exp(z) + exp(-z)) / T(2);
|
||||
}
|
||||
|
||||
/// Computes the sin of complex z.
|
||||
#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
|
||||
// host-only type
|
||||
template <typename T>
|
||||
CUTLASS_HOST_DEVICE complex<T> sin(complex<T> const &z) {
|
||||
return (exp(-z) - exp(z)) * complex<T>(T(0), T(1) / T(2));
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace platform
|
||||
} // namespace cutlass
|
||||
@ -30,7 +30,7 @@
|
||||
* \brief Math utilities
|
||||
*/
|
||||
|
||||
#include <cutlass/util/platform.h>
|
||||
#include "cutlass/util/platform.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@ -128,4 +128,38 @@ CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b) {
|
||||
return temp ? (a / temp * b) : 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* log2 computation, what's the
|
||||
* difference between the below codes and
|
||||
* log2_up/down codes?
|
||||
*/
|
||||
template <typename value_t>
|
||||
CUTLASS_HOST_DEVICE value_t clz(value_t x) {
|
||||
for (int i = 31; i >= 0; --i) {
|
||||
if ((1 << i) & x) return 31 - i;
|
||||
}
|
||||
return 32;
|
||||
}
|
||||
|
||||
template <typename value_t>
|
||||
CUTLASS_HOST_DEVICE value_t find_log2(value_t x) {
|
||||
int a = 31 - clz(x);
|
||||
a += (x & (x - 1)) != 0; // Round up, add 1 if not a power of 2.
|
||||
return a;
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
* Min/Max
|
||||
******************************************************************************/
|
||||
|
||||
template <int A, int B>
|
||||
struct Min {
|
||||
static int const kValue = (A < B) ? A : B;
|
||||
};
|
||||
|
||||
template <int A, int B>
|
||||
struct Max {
|
||||
static int const kValue = (A > B) ? A : B;
|
||||
};
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -22,27 +22,26 @@
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defies functors for mapping blockIdx to partitions of the GEMM computation.
|
||||
|
||||
Currently, we only implement an identity mapping.
|
||||
/*!
|
||||
\file
|
||||
\brief
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct IdentityBlockSwizzle {
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE IdentityBlockSwizzle() {}
|
||||
//
|
||||
// Definitions for 1-bit binary and 4-bit integer types
|
||||
//
|
||||
|
||||
/// Swizzle the block index.
|
||||
CUTLASS_DEVICE dim3 swizzle() { return blockIdx; }
|
||||
};
|
||||
struct bin1_t {}; // 1-bit binary type
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
struct int4_t {}; // 4-bit signed integer type
|
||||
|
||||
struct uint4_t {}; // 4-bit unsigned integer type
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -110,9 +110,17 @@
|
||||
#include <type_traits> // For integral constants, conditional metaprogramming, and type traits
|
||||
#endif
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#endif
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// OS
|
||||
//-----------------------------------------------------------------------------
|
||||
#if defined(WIN32) || defined(_WIN32) || defined(__WIN32) && !defined(__CYGWIN__)
|
||||
#define CUTLASS_OS_WINDOWS
|
||||
#endif
|
||||
|
||||
/******************************************************************************
|
||||
* Macros
|
||||
******************************************************************************/
|
||||
|
||||
170
cutlass/vector.h
@ -31,7 +31,8 @@
|
||||
#include <cuda_fp16.h>
|
||||
#endif
|
||||
|
||||
#include <cutlass/util/platform.h>
|
||||
#include "cutlass/util/numeric_types.h"
|
||||
#include "cutlass/util/platform.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@ -80,13 +81,43 @@ union Vector {
|
||||
uint32_t registers[kRegisters];
|
||||
|
||||
/// Accessor to the ith lane.
|
||||
CUTLASS_DEVICE Scalar const& operator[](uint32_t i) const { return scalars[i]; }
|
||||
CUTLASS_HOST_DEVICE Scalar const& operator[](uint32_t i) const { return scalars[i]; }
|
||||
/// Accessor to the ith lane.
|
||||
CUTLASS_DEVICE Scalar& operator[](uint32_t i) { return scalars[i]; }
|
||||
CUTLASS_HOST_DEVICE Scalar& operator[](uint32_t i) { return scalars[i]; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <>
|
||||
union Vector<half, 1> {
|
||||
/// The scalar type.
|
||||
typedef half Scalar;
|
||||
|
||||
/// The number of elements in the vector.
|
||||
enum { kLanes = 1 };
|
||||
/// The size of the vector.
|
||||
enum { kVectorSize = kLanes * (int)sizeof(Scalar) };
|
||||
/// The number of registers needed to store the vector.
|
||||
enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 };
|
||||
|
||||
// Make sure that the vector type makes sense.
|
||||
static_assert(kVectorSize <= 16, "Vector type is too large");
|
||||
|
||||
/// The aligned storage to make sure we have good alignment.
|
||||
AlignedStruct<kVectorSize> aligned_;
|
||||
/// The associated array of scalars.
|
||||
uint16_t scalars[kLanes];
|
||||
|
||||
/// Accessor to the ith lane.
|
||||
CUTLASS_HOST_DEVICE Scalar const& operator[](uint32_t i) const {
|
||||
return reinterpret_cast<Scalar const&>(scalars[i]);
|
||||
}
|
||||
/// Accessor to the ith lane.
|
||||
CUTLASS_HOST_DEVICE Scalar& operator[](uint32_t i) {
|
||||
return reinterpret_cast<Scalar&>(scalars[i]);
|
||||
}
|
||||
};
|
||||
|
||||
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
|
||||
|
||||
template <int kLanes_>
|
||||
@ -112,19 +143,124 @@ union Vector<half, kLanes_> {
|
||||
uint32_t registers[kRegisters];
|
||||
|
||||
/// Accessor to the ith lane.
|
||||
CUTLASS_DEVICE Scalar const& operator[](uint32_t i) const {
|
||||
CUTLASS_HOST_DEVICE Scalar const& operator[](uint32_t i) const {
|
||||
return reinterpret_cast<Scalar const&>(scalars[i]);
|
||||
}
|
||||
/// Accessor to the ith lane.
|
||||
CUTLASS_DEVICE Scalar& operator[](uint32_t i) { return reinterpret_cast<Scalar&>(scalars[i]); }
|
||||
CUTLASS_HOST_DEVICE Scalar& operator[](uint32_t i) {
|
||||
return reinterpret_cast<Scalar&>(scalars[i]);
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Vector definition for 1-bit binary datatype
|
||||
template <int kLanes_>
|
||||
union Vector<bin1_t, kLanes_> {
|
||||
/// The scalar type.
|
||||
typedef bin1_t Scalar;
|
||||
|
||||
/// The number of elements in the vector.
|
||||
enum { kLanes = kLanes_ };
|
||||
/// The size of the vector.
|
||||
enum { kVectorSize = kLanes / 8 };
|
||||
/// The number of registers needed to store the vector.
|
||||
enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 };
|
||||
|
||||
static_assert((kLanes >= 8) && !(kLanes % 8),
|
||||
"May only construct vectors of bin1_t that are multiples of 8 bits.");
|
||||
|
||||
/// The aligned storage to make sure we have good alignment.
|
||||
AlignedStruct<kVectorSize> aligned_;
|
||||
/// The data in registers.
|
||||
uint32_t registers[kRegisters];
|
||||
|
||||
/// Default Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Vector() {}
|
||||
/// Constructor to convert from uint32_t type
|
||||
CUTLASS_HOST_DEVICE Vector(uint32_t value) { registers[0] = value; }
|
||||
/// Accessor to the ith lane.
|
||||
CUTLASS_HOST_DEVICE bool operator[](uint32_t i) const {
|
||||
return ( (registers[i / 32] & (1 << (i % 32))) != 0 );
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Vector definition for 4-bit signed integer datatype
|
||||
template <int kLanes_>
|
||||
union Vector<int4_t, kLanes_> {
|
||||
/// The scalar type.
|
||||
typedef int4_t Scalar;
|
||||
|
||||
/// The number of elements in the vector.
|
||||
enum { kLanes = kLanes_ };
|
||||
/// The size of the vector.
|
||||
enum { kVectorSize = kLanes / 2 };
|
||||
/// The number of registers needed to store the vector.
|
||||
enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 };
|
||||
|
||||
static_assert((kLanes >= 2) && !(kLanes % 2),
|
||||
"May only construct vectors of int4_t that are multiples of 8 bits.");
|
||||
|
||||
/// The aligned storage to make sure we have good alignment.
|
||||
AlignedStruct<kVectorSize> aligned_;
|
||||
/// The data in registers.
|
||||
uint32_t registers[kRegisters];
|
||||
|
||||
/// Default Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Vector() {}
|
||||
/// Constructor to convert from uint32_t type
|
||||
CUTLASS_HOST_DEVICE Vector(uint32_t value) { registers[0] = value; }
|
||||
/// Accessor to the ith lane.
|
||||
CUTLASS_HOST_DEVICE int operator[](uint32_t i) const {
|
||||
return (registers[i / 8] >> (i % 8 * 4) & 0x0f)
|
||||
- 16 * (registers[i / 8] >> (i % 8 * 4 + 3) & 0x01);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Vector definition for 4-bit unsigned integer datatype
|
||||
template <int kLanes_>
|
||||
union Vector<uint4_t, kLanes_> {
|
||||
/// The scalar type.
|
||||
typedef uint4_t Scalar;
|
||||
|
||||
/// The number of elements in the vector.
|
||||
enum { kLanes = kLanes_ };
|
||||
/// The size of the vector.
|
||||
enum { kVectorSize = kLanes / 2 };
|
||||
/// The number of registers needed to store the vector.
|
||||
enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 };
|
||||
|
||||
static_assert((kLanes >= 2) && !(kLanes % 2),
|
||||
"May only construct vectors of uint4_t that are multiples of 8 bits.");
|
||||
|
||||
/// The aligned storage to make sure we have good alignment.
|
||||
AlignedStruct<kVectorSize> aligned_;
|
||||
/// The data in registers.
|
||||
uint32_t registers[kRegisters];
|
||||
|
||||
/// Default Constructor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Vector() {}
|
||||
/// Constructor to convert from uint32_t type
|
||||
CUTLASS_HOST_DEVICE Vector(uint32_t value) { registers[0] = value; }
|
||||
/// Accessor to the ith lane.
|
||||
CUTLASS_HOST_DEVICE int operator[](uint32_t i) const {
|
||||
return registers[i / 8] >> (i % 8 * 4) & 0x0f;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_>
|
||||
CUTLASS_DEVICE void make_zero(Scalar_& x) {
|
||||
CUTLASS_HOST_DEVICE void make_zero(Scalar_& x) {
|
||||
x = Scalar_(0);
|
||||
}
|
||||
|
||||
@ -137,15 +273,29 @@ struct Vectorize {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Element_>
|
||||
struct Vectorize<Element_, 1> {
|
||||
typedef Element_ Type;
|
||||
template <int kLanes_>
|
||||
struct Vectorize<Vector<bin1_t, 32>, kLanes_> {
|
||||
typedef Vector<bin1_t, kLanes_ * 32> Type;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int kLanes_>
|
||||
struct Vectorize<Vector<int4_t, 8>, kLanes_> {
|
||||
typedef Vector<int4_t, kLanes_ * 8> Type;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int kLanes_>
|
||||
struct Vectorize<Vector<uint4_t, 8>, kLanes_> {
|
||||
typedef Vector<uint4_t, kLanes_ * 8> Type;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Scalar_, int kLanes_>
|
||||
CUTLASS_DEVICE void make_zero(Vector<Scalar_, kLanes_>& vec) {
|
||||
CUTLASS_HOST_DEVICE void make_zero(Vector<Scalar_, kLanes_>& vec) {
|
||||
for (int i = 0; i < Vector<Scalar_, kLanes_>::kRegisters; ++i) {
|
||||
vec.registers[i] = 0;
|
||||
}
|
||||
|
||||
@ -28,20 +28,23 @@
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700)
|
||||
|
||||
// Dependent header files should use the following macro to guard all code using
|
||||
// nvcuda::wmma:: to enable compilation for CUDA Compute Capabilities < sm_70.
|
||||
// Earlier shader models not support Tensor Cores.
|
||||
#define CUTLASS_USE_WMMA_API
|
||||
|
||||
#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 10) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750)
|
||||
#define CUTLASS_USE_SUBBYTE_WMMA
|
||||
#endif
|
||||
|
||||
#include "stdio.h"
|
||||
|
||||
#if __CUDACC_VER_MAJOR__ >= 10
|
||||
#include <mma.h>
|
||||
#else
|
||||
#include <crt/mma.h>
|
||||
#include <cutlass/fragment.h>
|
||||
#include <cutlass/load_store.h>
|
||||
#include <cutlass/matrix_traits.h>
|
||||
#include <cutlass/shape.h>
|
||||
#include <cutlass/vector.h>
|
||||
#endif
|
||||
#include "cutlass/fragment.h"
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "cutlass/shape.h"
|
||||
#include "cutlass/vector.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
@ -61,6 +64,34 @@ struct WmmaLayout<MatrixLayout::kRowMajor> {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Statically maps cutlass types to nvcuda::wmma datatypes
|
||||
template <typename Type_>
|
||||
struct WmmaDataType{
|
||||
typedef Type_ Type;
|
||||
};
|
||||
|
||||
#ifdef CUTLASS_USE_SUBBYTE_WMMA
|
||||
/// Statically maps cutlass::Vector<bin1_t, 32> to nvcuda::wmma::experimental::precision::b1
|
||||
template<>
|
||||
struct WmmaDataType<Vector<bin1_t, 32> > {
|
||||
typedef nvcuda::wmma::experimental::precision::b1 Type;
|
||||
};
|
||||
|
||||
/// Statically maps cutlass::Vector<int4_t, 8> to nvcuda::wmma::experimental::precision::s4
|
||||
template<>
|
||||
struct WmmaDataType<Vector<int4_t, 8> > {
|
||||
typedef nvcuda::wmma::experimental::precision::s4 Type;
|
||||
};
|
||||
|
||||
/// Statically maps cutlass::Vector<uint4_t, 8> to nvcuda::wmma::experimental::precision::u4
|
||||
template<>
|
||||
struct WmmaDataType<Vector<uint4_t, 8> > {
|
||||
typedef nvcuda::wmma::experimental::precision::u4 Type;
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Adapter to nvcuda::wmma fragment load and store operations
|
||||
template <GemmOperand::Kind kOperand_,
|
||||
MatrixLayout::Kind kLayout_,
|
||||
@ -81,7 +112,7 @@ struct WmmaMatrix<GemmOperand::kA, kLayout_, Scalar_, WmmaShape_>
|
||||
WmmaShape_::kH,
|
||||
WmmaShape_::kD,
|
||||
/// The scalar.
|
||||
Scalar_,
|
||||
typename WmmaDataType<Scalar_>::Type,
|
||||
/// The layout.
|
||||
typename WmmaLayout<kLayout_>::Layout> {
|
||||
/// This type.
|
||||
@ -117,7 +148,7 @@ struct WmmaMatrix<GemmOperand::kB, kLayout_, Scalar_, WmmaShape_>
|
||||
WmmaShape_::kH,
|
||||
WmmaShape_::kD,
|
||||
/// The scalar.
|
||||
Scalar_,
|
||||
typename WmmaDataType<Scalar_>::Type,
|
||||
/// The layout.
|
||||
typename WmmaLayout<kLayout_>::Layout> {
|
||||
/// This type.
|
||||
@ -188,6 +219,18 @@ struct WmmaMatrix<GemmOperand::kC, kLayout_, Scalar_, WmmaShape_>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
// WmmaMatrix cannot be used in a Union and thus in cannot be used in our Vector implementation.
|
||||
// The only use of WmmaMatrix in in combination with Vectorize has kLanes == 1. Due to this it is
|
||||
// safe to keep the Vector->Scalar conversion for WmmaMatrix.
|
||||
template <GemmOperand::Kind kOperand_,
|
||||
MatrixLayout::Kind kLayout_,
|
||||
typename Scalar_,
|
||||
typename WmmaShape_>
|
||||
struct Vectorize<WmmaMatrix<kOperand_, kLayout_, Scalar_, WmmaShape_>, 1> {
|
||||
typedef WmmaMatrix<kOperand_, kLayout_, Scalar_, WmmaShape_> Type;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
}
|
||||
|
||||
#endif // defined CUTLASS_USE_WMMA_API
|
||||
|
||||
150
cutlass/zip_fragment.h
Normal file
@ -0,0 +1,150 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Models a pair of fragments
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/shape.h"
|
||||
#include "cutlass/util/cutlass_math.h"
|
||||
#include "cutlass/vector.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* @brief A template defining \ref fragment_concept
|
||||
* @concept{fragment_concept}
|
||||
*/
|
||||
template <typename First_, typename Second_>
|
||||
struct ZipFragment {
|
||||
/// First fragment object
|
||||
typedef First_ First;
|
||||
|
||||
/// Second fragment object
|
||||
typedef Second_ Second;
|
||||
|
||||
/// This class.
|
||||
typedef ZipFragment<First, Second> This_;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// First fragment object
|
||||
First first;
|
||||
|
||||
/// Second fragment object
|
||||
Second second;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default ctor
|
||||
CUTLASS_DEVICE
|
||||
ZipFragment() { }
|
||||
|
||||
/// Copy ctor
|
||||
CUTLASS_DEVICE
|
||||
ZipFragment(First const &_first, Second const &_second): first(_first), second(_second) { }
|
||||
|
||||
/// Clear a fragment.
|
||||
CUTLASS_DEVICE void clear() {
|
||||
first.clear();
|
||||
second.clear();
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to construct a ZipFragment object
|
||||
template <typename First, typename Second>
|
||||
CUTLASS_HOST_DEVICE
|
||||
ZipFragment<First, Second> make_ZipFragment(First const &first, Second const &second) {
|
||||
return ZipFragment<First, Second>(first, second);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Zips two convert operations
|
||||
template <typename First_, typename Second_>
|
||||
struct ZipConvert {
|
||||
/// First convert operator
|
||||
typedef First_ First;
|
||||
|
||||
/// Second convert operator
|
||||
typedef Second_ Second;
|
||||
|
||||
/// Defines the input zip fragment
|
||||
typedef ZipFragment<typename First::InputFragment, typename Second::InputFragment> InputFragment;
|
||||
|
||||
/// Defines the output zip fragment
|
||||
typedef ZipFragment<typename First::OutputFragment, typename Second::OutputFragment>
|
||||
OutputFragment;
|
||||
|
||||
//
|
||||
//
|
||||
//
|
||||
|
||||
/// First transformer
|
||||
First first;
|
||||
|
||||
/// Second transformer
|
||||
Second second;
|
||||
|
||||
//
|
||||
//
|
||||
//
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ZipConvert() {}
|
||||
|
||||
/// Ctor.
|
||||
CUTLASS_DEVICE ZipConvert(First const &_first, Second const &_second): first(_first), second(_second) { }
|
||||
|
||||
/// Transform a fragment.
|
||||
CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
|
||||
first.transform(src.first, dst.first);
|
||||
second.transform(src.second, dst.second);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to construct a ZipConvert object
|
||||
template <typename First, typename Second>
|
||||
CUTLASS_HOST_DEVICE
|
||||
ZipConvert<First, Second> make_ZipConvert(First const &first, Second const &second) {
|
||||
return ZipConvert<First, Second>(first, second);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
77
cutlass/zip_tensor_ref.h
Normal file
@ -0,0 +1,77 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines a structure containing a pair of TensorRef-like objects
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename First_, typename Second_>
|
||||
struct ZipTensorRef {
|
||||
/// First tensor ref
|
||||
typedef First_ First;
|
||||
|
||||
/// Second tensor ref
|
||||
typedef Second_ Second;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// First TensorRef
|
||||
First first;
|
||||
|
||||
/// Second TensorRef
|
||||
Second second;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
ZipTensorRef() {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
ZipTensorRef(First const& _first, Second const& _second) : first(_first), second(_second) {}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Constructs a ZipTensorRef
|
||||
template <typename First, typename Second>
|
||||
CUTLASS_HOST_DEVICE
|
||||
ZipTensorRef<First, Second> make_ZipTensorRef(First const &first, Second const &second) {
|
||||
return ZipTensorRef<First, Second>(first, second);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
287
cutlass/zip_tile_iterator.h
Normal file
@ -0,0 +1,287 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Constructs an iterator that owns two tile iterator instances
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/coord.h"
|
||||
#include "cutlass/zip_tensor_ref.h"
|
||||
#include "cutlass/zip_fragment.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Constructs an iterator from a pair of iterators
|
||||
template <typename First_, typename Second_>
|
||||
class ZipTileIterator {
|
||||
public:
|
||||
/// First iterator type
|
||||
typedef First_ First;
|
||||
|
||||
/// Second iterator type
|
||||
typedef Second_ Second;
|
||||
|
||||
/// Params object
|
||||
struct Params {
|
||||
/// Parameters of first iterator
|
||||
typename First::Params first;
|
||||
|
||||
/// Parameters of second iterator
|
||||
typename Second::Params second;
|
||||
|
||||
/// Constructs a parameters object
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() {}
|
||||
|
||||
/// Constructs a parameters object
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(typename First::Params const &_first, typename Second::Params const &_second)
|
||||
: first(_first), second(_second) {}
|
||||
};
|
||||
|
||||
/// Fragment type
|
||||
typedef ZipFragment<typename First::Fragment, typename Second::Fragment> Fragment;
|
||||
|
||||
/// Predicate vector
|
||||
typedef typename First::PredicateVector PredicateVector;
|
||||
|
||||
/// Index type
|
||||
typedef typename First::Index Index;
|
||||
|
||||
/// Tensor reference
|
||||
typedef ZipTensorRef<
|
||||
typename First::TensorRef,
|
||||
typename Second::TensorRef> TensorRef;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// First iterator
|
||||
First first;
|
||||
|
||||
/// Second iterator
|
||||
Second second;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Default constructor
|
||||
CUTLASS_DEVICE
|
||||
ZipTileIterator() {}
|
||||
|
||||
/// Constructs a zip iterator from params
|
||||
CUTLASS_DEVICE
|
||||
ZipTileIterator(Params const &_params, Coord<3> const &threadblock_offset = make_Coord(0, 0, 0))
|
||||
: first(_params.first, threadblock_offset), second(_params.second, threadblock_offset) {}
|
||||
|
||||
/// Constructs a zip iterator from iterator instances
|
||||
CUTLASS_DEVICE
|
||||
ZipTileIterator(First const &_first, Second const &_second) : first(_first), second(_second) {}
|
||||
|
||||
/// Constructs a zip iterator from iterator instances
|
||||
CUTLASS_DEVICE
|
||||
ZipTileIterator(TensorRef const &ref) : first(ref.first), second(ref.second) {}
|
||||
|
||||
/// Constructs a zip iterator from iterator instances
|
||||
CUTLASS_DEVICE
|
||||
ZipTileIterator(Params const &_params, TensorRef const &ref):
|
||||
first(_params.first, ref.first), second(_params.second, ref.second) {}
|
||||
|
||||
//
|
||||
// Predicate initialization
|
||||
//
|
||||
|
||||
/// Initializes a predicate vector using a RegularTilePredicateFunctor
|
||||
template <
|
||||
/// Predicate iterator
|
||||
typename PredicateIterator>
|
||||
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
|
||||
Coord<3> const &bounds,
|
||||
Coord<3> const &block_offset = make_Coord(0,
|
||||
0,
|
||||
0)) {
|
||||
first.initialize_predicates(predicate_it, bounds, block_offset);
|
||||
}
|
||||
|
||||
/// Initializes a predicate vector using an arbitrary predicate functor
|
||||
template <
|
||||
/// Predicate iterator
|
||||
typename PredicateIterator,
|
||||
/// Functor computing predicates
|
||||
typename PredicateFunctor>
|
||||
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
|
||||
PredicateFunctor const &functor,
|
||||
Coord<3> const &block_offset) {
|
||||
first.initialize_predicates(predicate_it, functor, block_offset);
|
||||
}
|
||||
|
||||
//
|
||||
// No predicates
|
||||
//
|
||||
|
||||
/// Loads a fragment and increments without predicates
|
||||
template <typename Fragment>
|
||||
CUTLASS_DEVICE void load_post_increment(Fragment &fragment) {
|
||||
first.load_post_increment(fragment.first);
|
||||
second.load_post_increment(fragment.second);
|
||||
}
|
||||
|
||||
/// Loads a fragment and increments without predicates
|
||||
template <typename Fragment>
|
||||
CUTLASS_DEVICE void load_post_increment(Fragment &fragment,
|
||||
Coord<4> const &offset) {
|
||||
first.load_post_increment(fragment.first, offset);
|
||||
second.load_post_increment(fragment.second, offset);
|
||||
}
|
||||
|
||||
/// Loads a fragment without predicates
|
||||
template <typename Fragment>
|
||||
CUTLASS_DEVICE void load(Fragment &fragment) const {
|
||||
first.load(fragment.first);
|
||||
second.load(fragment.second);
|
||||
}
|
||||
|
||||
/// Loads a fragment without predicates
|
||||
template <typename Fragment>
|
||||
CUTLASS_DEVICE void load(Fragment &fragment,
|
||||
Coord<4> const &offset) const {
|
||||
first.load(fragment.first, offset);
|
||||
second.load(fragment.second, offset);
|
||||
}
|
||||
|
||||
/// Stores a fragment and increments without predicates
|
||||
template <typename Fragment>
|
||||
CUTLASS_DEVICE void store_post_increment(Fragment const &fragment) {
|
||||
first.store_post_increment(fragment.first);
|
||||
second.store_post_increment(fragment.second);
|
||||
}
|
||||
|
||||
/// Stores a fragment and increments without predicates
|
||||
template <typename Fragment>
|
||||
CUTLASS_DEVICE void store_post_increment(Fragment const &fragment,
|
||||
Coord<4> const &offset) {
|
||||
first.store_post_increment(fragment.first, offset);
|
||||
second.store_post_increment(fragment.second, offset);
|
||||
}
|
||||
|
||||
/// Stores a fragment without predicates
|
||||
template <typename Fragment>
|
||||
CUTLASS_DEVICE void store(Fragment const &fragment) const {
|
||||
first.store(fragment.first);
|
||||
second.store(fragment.second);
|
||||
}
|
||||
|
||||
/// Stores a fragment without predicates
|
||||
template <typename Fragment>
|
||||
CUTLASS_DEVICE void store(Fragment const &fragment,
|
||||
Coord<4> const &offset) const {
|
||||
first.store(fragment.first, offset);
|
||||
second.store(fragment.second, offset);
|
||||
}
|
||||
|
||||
//
|
||||
// With predication
|
||||
//
|
||||
|
||||
/// Loads a fragment and increments, using predicates
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
|
||||
first.load_post_increment(fragment.first, pred_it);
|
||||
second.load_post_increment(fragment.second, pred_it);
|
||||
}
|
||||
|
||||
/// Loads a fragment with predicates
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const {
|
||||
first.load(fragment.first, pred_it);
|
||||
second.load(fragment.second, pred_it);
|
||||
}
|
||||
|
||||
/// Loads a fragment and increments, using predicates
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_DEVICE void store_post_increment(Fragment const &fragment, PredicateIterator pred_it) {
|
||||
first.store_post_increment(fragment.first, pred_it);
|
||||
second.store_post_increment(fragment.second, pred_it);
|
||||
}
|
||||
|
||||
/// Loads a fragment with predicates
|
||||
template <typename Fragment, typename PredicateIterator>
|
||||
CUTLASS_DEVICE void store(Fragment const &fragment, PredicateIterator pred_it) const {
|
||||
first.store(fragment.first, pred_it);
|
||||
second.store(fragment.second, pred_it);
|
||||
}
|
||||
|
||||
//
|
||||
// Advances the iterators
|
||||
//
|
||||
|
||||
/// Increments store iterator to next tile
|
||||
CUTLASS_DEVICE ZipTileIterator &increment(int count = 1) {
|
||||
first.increment(count);
|
||||
second.increment(count);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Increments to next tile
|
||||
CUTLASS_DEVICE ZipTileIterator &operator++() { return increment(); }
|
||||
|
||||
CUTLASS_DEVICE ZipTileIterator &operator+=(int count) { return increment(count); }
|
||||
|
||||
/// Adds a vector offset to the underlying iterators
|
||||
CUTLASS_DEVICE ZipTileIterator &operator+=(Coord<3> const &offset) {
|
||||
first += offset;
|
||||
second += offset;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Increments store iterator to previous tile
|
||||
CUTLASS_DEVICE ZipTileIterator &decrement(int count = 1) {
|
||||
first.decrement(count);
|
||||
second.decrement(count);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Increments to subsequent tile
|
||||
CUTLASS_DEVICE ZipTileIterator &operator--() { return decrement(); }
|
||||
|
||||
/// Decrements to previous tile
|
||||
CUTLASS_DEVICE ZipTileIterator &operator-=(int count) { return decrement(count); }
|
||||
|
||||
/// Adds an offset to both iterators
|
||||
CUTLASS_DEVICE void add_pointer_offset(Index offset) {
|
||||
first.add_pointer_offset(offset);
|
||||
second.add_pointer_offset(offset);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namspace cutlass
|
||||
38
examples/00_basic_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,38 @@
|
||||
# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
set(EXAMPLES_BASIC_CUTLASS_GEMM_SOURCES
|
||||
basic_gemm.cu
|
||||
)
|
||||
|
||||
if (NOT CUTLASS_NATIVE_CUDA)
|
||||
# cuda_add_executable does not take interface include directories into account
|
||||
# Let's fetch them and pass them to CUDA.
|
||||
get_target_property(CUTLASS_INCLUDES CUTLASS INTERFACE_INCLUDE_DIRECTORIES)
|
||||
include_directories("${CUTLASS_INCLUDES}")
|
||||
endif()
|
||||
|
||||
cutlass_add_executable(
|
||||
00_basic_gemm
|
||||
${EXAMPLES_BASIC_CUTLASS_GEMM_SOURCES}
|
||||
)
|
||||
492
examples/00_basic_gemm/basic_gemm.cu
Normal file
@ -0,0 +1,492 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*
|
||||
This example demonstrates how to call a CUTLASS GEMM kernel and provides a naive reference
|
||||
matrix multiply kernel to verify its correctness.
|
||||
|
||||
The CUTLASS Gemm template is instantiated in the function CutlassSgemmNN. This is kernel computes
|
||||
the general matrix product (GEMM) using single-precision floating-point arithmetic and assumes
|
||||
all matrices have column-major layout.
|
||||
|
||||
The threadblock tile size is chosen as 128x128x8 which offers good performance for large matrices.
|
||||
See the CUTLASS Parallel for All blog post for more exposition on the tunable parameters available
|
||||
in CUTLASS.
|
||||
|
||||
https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/
|
||||
|
||||
Aside from defining and launching the SGEMM kernel, this example does not use any other components
|
||||
or utilities within CUTLASS. Such utilities are demonstrated elsewhere in other examples and are
|
||||
prevalent in the CUTLASS unit tests.
|
||||
*/
|
||||
|
||||
// Standard Library includes
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
//
|
||||
// CUTLASS includes needed for single-precision GEMM kernel
|
||||
//
|
||||
|
||||
// Defines cutlass::gemm::Gemm, the generic Gemm computation template class.
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
// Defines cutlass::gemm::SgemmTraits, the structural components for single-precision GEMM
|
||||
#include "cutlass/gemm/sgemm_traits.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// This function defines a CUTLASS GEMM kernel instantiation, constructs its parameters object,
|
||||
// and launches it on the CUDA device.
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Define a CUTLASS GEMM template and launch a GEMM kernel.
|
||||
cudaError_t CutlassSgemmNN(
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
float alpha,
|
||||
float const *A,
|
||||
int lda,
|
||||
float const *B,
|
||||
int ldb,
|
||||
float beta,
|
||||
float *C,
|
||||
int ldc) {
|
||||
|
||||
// Define type definition for single-precision CUTLASS GEMM with column-major
|
||||
// input matrices and 128x128x8 threadblock tile size.
|
||||
//
|
||||
// Note, GemmTraits<> is a generic template defined for various general matrix product
|
||||
// computations within CUTLASS. It is intended to be maximally flexible, and consequently
|
||||
// it contains numerous template arguments.
|
||||
//
|
||||
// To keep the interface manageable, several helpers are defined for plausible compositions
|
||||
// including the following example for single-precision GEMM. Typical values are used as
|
||||
// default template arguments. See `cutlass/gemm/gemm_traits.h` for more details.
|
||||
//
|
||||
typedef cutlass::gemm::SgemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor, // layout of A matrix
|
||||
cutlass::MatrixLayout::kColumnMajor, // layout of B matrix
|
||||
cutlass::Shape<8, 128, 128> // threadblock tile size
|
||||
>
|
||||
GemmTraits;
|
||||
|
||||
// Define a CUTLASS GEMM type from a GemmTraits<> instantiation.
|
||||
typedef cutlass::gemm::Gemm<GemmTraits> Gemm;
|
||||
|
||||
// Construct and initialize CUTLASS GEMM parameters object.
|
||||
//
|
||||
// One of CUTLASS's design patterns is to define parameters objects that are constructible
|
||||
// in host code and passed to kernels by value. These may include pointers, strides, scalars,
|
||||
// and other arguments needed by Gemm and its components.
|
||||
//
|
||||
// The benefits of this pattern are (1.) a structured, composable strategy for passing host-constructible
|
||||
// arguments to kernels and (2.) minimized initialization overhead on kernel entry.
|
||||
//
|
||||
typename Gemm::Params params;
|
||||
|
||||
int result = params.initialize(
|
||||
M, // GEMM M dimension
|
||||
N, // GEMM N dimension
|
||||
K, // GEMM K dimension
|
||||
alpha, // scalar alpha
|
||||
A, // matrix A operand
|
||||
lda,
|
||||
B, // matrix B operand
|
||||
ldb,
|
||||
beta, // scalar beta
|
||||
C, // source matrix C
|
||||
ldc,
|
||||
C, // destination matrix C (may be different memory than source C matrix)
|
||||
ldc
|
||||
);
|
||||
|
||||
if (result) {
|
||||
std::cerr << "Failed to initialize CUTLASS Gemm::Params object." << std::endl;
|
||||
return cudaErrorInvalidValue;
|
||||
}
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
Gemm::launch(params);
|
||||
|
||||
// Return any errors associated with the launch or cudaSuccess if no error.
|
||||
return cudaGetLastError();
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// The source code after this point in the file is generic CUDA using the CUDA Runtime API
|
||||
// and simple CUDA kernels to initialize matrices and compute the general matrix product.
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Kernel to initialize a matrix with small integers.
|
||||
__global__ void InitializeMatrix_kernel(
|
||||
float *matrix,
|
||||
int ldm,
|
||||
int rows,
|
||||
int columns,
|
||||
int seed = 0) {
|
||||
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int j = threadIdx.y + blockIdx.y * blockDim.y;
|
||||
|
||||
if (i < rows && j < columns) {
|
||||
int offset = i + j * ldm;
|
||||
|
||||
// Generate arbitrary elements.
|
||||
int const k = 16807;
|
||||
int const m = 16;
|
||||
float value = float(((offset + seed) * k % m) - m / 2);
|
||||
|
||||
matrix[offset] = value;
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple function to initialize a matrix to arbitrary small integers.
|
||||
cudaError_t InitializeMatrix(float *matrix, int ldm, int rows, int columns, int seed = 0) {
|
||||
|
||||
dim3 block(16, 16);
|
||||
dim3 grid(
|
||||
(rows + block.x - 1) / block.x,
|
||||
(columns + block.y - 1) / block.y
|
||||
);
|
||||
|
||||
InitializeMatrix_kernel<<< grid, block >>>(matrix, ldm, rows, columns, seed);
|
||||
|
||||
return cudaGetLastError();
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Allocates device memory for a matrix then fills with arbitrary small integers.
|
||||
cudaError_t AllocateMatrix(float **matrix, int ldm, int rows, int columns, int seed = 0) {
|
||||
cudaError_t result;
|
||||
|
||||
size_t sizeof_matrix = sizeof(float) * ldm * columns;
|
||||
|
||||
// Allocate device memory.
|
||||
result = cudaMalloc(reinterpret_cast<void **>(matrix), sizeof_matrix);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Failed to allocate matrix: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Clear the allocation.
|
||||
result = cudaMemset(*matrix, 0, sizeof_matrix);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Failed to clear matrix device memory: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Initialize matrix elements to arbitrary small integers.
|
||||
result = InitializeMatrix(*matrix, ldm, rows, columns, seed);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Failed to initialize matrix: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Naive reference GEMM computation.
|
||||
__global__ void ReferenceGemm_kernel(
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
float alpha,
|
||||
float const *A,
|
||||
int lda,
|
||||
float const *B,
|
||||
int ldb,
|
||||
float beta,
|
||||
float *C,
|
||||
int ldc) {
|
||||
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int j = threadIdx.y + blockIdx.y * blockDim.y;
|
||||
|
||||
if (i < M && j < N) {
|
||||
float accumulator = 0;
|
||||
|
||||
for (int k = 0; k < K; ++k) {
|
||||
accumulator += A[i + k * lda] * B[k + j * ldb];
|
||||
}
|
||||
|
||||
C[i + j * ldc] = alpha * accumulator + beta * C[i + j * ldc];
|
||||
}
|
||||
}
|
||||
|
||||
/// Reference GEMM computation.
|
||||
cudaError_t ReferenceGemm(
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
float alpha,
|
||||
float const *A,
|
||||
int lda,
|
||||
float const *B,
|
||||
int ldb,
|
||||
float beta,
|
||||
float *C,
|
||||
int ldc) {
|
||||
|
||||
dim3 block(16, 16);
|
||||
dim3 grid(
|
||||
(M + block.x - 1) / block.x,
|
||||
(N + block.y - 1) / block.y
|
||||
);
|
||||
|
||||
ReferenceGemm_kernel<<< grid, block >>>(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
|
||||
|
||||
return cudaGetLastError();
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Allocate several matrices in GPU device memory and call a single-precision
|
||||
/// CUTLASS GEMM kernel.
|
||||
cudaError_t TestCutlassGemm(int M, int N, int K, float alpha, float beta) {
|
||||
cudaError_t result;
|
||||
|
||||
//
|
||||
// Define several matrices to be used as operands to GEMM kernels.
|
||||
//
|
||||
|
||||
// Compute leading dimensions for each matrix.
|
||||
int lda = M;
|
||||
int ldb = K;
|
||||
int ldc = M;
|
||||
|
||||
// Compute size in bytes of the C matrix.
|
||||
size_t sizeof_C = sizeof(float) * ldc * N;
|
||||
|
||||
// Define pointers to matrices in GPU device memory.
|
||||
float *A;
|
||||
float *B;
|
||||
float *C_cutlass;
|
||||
float *C_reference;
|
||||
|
||||
//
|
||||
// Allocate matrices in GPU device memory with arbitrary seeds.
|
||||
//
|
||||
|
||||
result = AllocateMatrix(&A, lda, M, K, 0);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return result;
|
||||
}
|
||||
|
||||
result = AllocateMatrix(&B, ldb, K, N, 17);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
cudaFree(A);
|
||||
return result;
|
||||
}
|
||||
|
||||
result = AllocateMatrix(&C_cutlass, ldc, M, N, 101);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
cudaFree(A);
|
||||
cudaFree(B);
|
||||
return result;
|
||||
}
|
||||
|
||||
result = AllocateMatrix(&C_reference, ldc, M, N, 101);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
cudaFree(A);
|
||||
cudaFree(B);
|
||||
cudaFree(C_cutlass);
|
||||
return result;
|
||||
}
|
||||
|
||||
result = cudaMemcpy(C_reference, C_cutlass, sizeof_C, cudaMemcpyDeviceToDevice);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Failed to copy C_cutlass matrix to C_reference: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
|
||||
cudaFree(C_reference);
|
||||
cudaFree(C_cutlass);
|
||||
cudaFree(B);
|
||||
cudaFree(A);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// Launch CUTLASS GEMM.
|
||||
//
|
||||
|
||||
result = CutlassSgemmNN(M, N, K, alpha, A, lda, B, ldb, beta, C_cutlass, ldc);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "CUTLASS GEMM kernel failed: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
|
||||
cudaFree(C_reference);
|
||||
cudaFree(C_cutlass);
|
||||
cudaFree(B);
|
||||
cudaFree(A);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// Verify.
|
||||
//
|
||||
|
||||
// Launch reference GEMM
|
||||
result = ReferenceGemm(M, N, K, alpha, A, lda, B, ldb, beta, C_reference, ldc);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Reference GEMM kernel failed: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
|
||||
cudaFree(C_reference);
|
||||
cudaFree(C_cutlass);
|
||||
cudaFree(B);
|
||||
cudaFree(A);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Copy to host and verify equivalence.
|
||||
std::vector<float> host_cutlass(ldc * N, 0);
|
||||
std::vector<float> host_reference(ldc * N, 0);
|
||||
|
||||
result = cudaMemcpy(host_cutlass.data(), C_cutlass, sizeof_C, cudaMemcpyDeviceToHost);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Failed to copy CUTLASS GEMM results: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
|
||||
cudaFree(C_reference);
|
||||
cudaFree(C_cutlass);
|
||||
cudaFree(B);
|
||||
cudaFree(A);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
result = cudaMemcpy(host_reference.data(), C_reference, sizeof_C, cudaMemcpyDeviceToHost);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Failed to copy Reference GEMM results: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
|
||||
cudaFree(C_reference);
|
||||
cudaFree(C_cutlass);
|
||||
cudaFree(B);
|
||||
cudaFree(A);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// Free device memory allocations.
|
||||
//
|
||||
|
||||
cudaFree(C_reference);
|
||||
cudaFree(C_cutlass);
|
||||
cudaFree(B);
|
||||
cudaFree(A);
|
||||
|
||||
//
|
||||
// Test for bit equivalence of results.
|
||||
//
|
||||
|
||||
if (host_cutlass != host_reference) {
|
||||
std::cerr << "CUTLASS results incorrect." << std::endl;
|
||||
|
||||
return cudaErrorUnknown;
|
||||
}
|
||||
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Entry point to basic_gemm example.
|
||||
//
|
||||
// usage:
|
||||
//
|
||||
// 00_basic_gemm <M> <N> <K> <alpha> <beta>
|
||||
//
|
||||
int main(int argc, const char *arg[]) {
|
||||
|
||||
//
|
||||
// Parse the command line to obtain GEMM dimensions and scalar values.
|
||||
//
|
||||
|
||||
// GEMM problem dimensions.
|
||||
int problem[3] = { 128, 128, 128 };
|
||||
|
||||
for (int i = 1; i < argc && i < 4; ++i) {
|
||||
std::stringstream ss(arg[i]);
|
||||
ss >> problem[i - 1];
|
||||
}
|
||||
|
||||
// Scalars used for linear scaling the result of the matrix product.
|
||||
float scalars[2] = { 1, 0 };
|
||||
|
||||
for (int i = 4; i < argc && i < 6; ++i) {
|
||||
std::stringstream ss(arg[i]);
|
||||
ss >> scalars[i - 4];
|
||||
}
|
||||
|
||||
//
|
||||
// Run the CUTLASS GEMM test.
|
||||
//
|
||||
|
||||
cudaError_t result = TestCutlassGemm(
|
||||
problem[0], // GEMM M dimension
|
||||
problem[1], // GEMM N dimension
|
||||
problem[2], // GEMM K dimension
|
||||
scalars[0], // alpha
|
||||
scalars[1] // beta
|
||||
);
|
||||
|
||||
if (result == cudaSuccess) {
|
||||
std::cout << "Passed." << std::endl;
|
||||
}
|
||||
|
||||
// Exit.
|
||||
return result == cudaSuccess ? 0 : -1;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
38
examples/01_tensor_view/CMakeLists.txt
Normal file
@ -0,0 +1,38 @@
|
||||
# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
set(EXAMPLES_TENSOR_VIEW_SOURCES
|
||||
tensor_view.cu
|
||||
)
|
||||
|
||||
if (NOT CUTLASS_NATIVE_CUDA)
|
||||
# cuda_add_executable does not take interface include directories into account
|
||||
# Let's fetch them and pass them to CUDA.
|
||||
get_target_property(CUTLASS_INCLUDES CUTLASS INTERFACE_INCLUDE_DIRECTORIES)
|
||||
include_directories("${CUTLASS_INCLUDES}")
|
||||
endif()
|
||||
|
||||
cutlass_add_executable(
|
||||
01_tensor_view
|
||||
${EXAMPLES_TENSOR_VIEW_SOURCES}
|
||||
)
|
||||
424
examples/01_tensor_view/tensor_view.cu
Normal file
@ -0,0 +1,424 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*
|
||||
This example demonstrates operations using TensorRef<> and TensorView<> as well as their explicit
|
||||
equivalent functionality in CUDA code.
|
||||
|
||||
CUTLASS provides abstractions for interacting with multidimension tensors in device memory.
|
||||
Consequently, we define a hierarchy of pointer-like types for referencing tensors.
|
||||
|
||||
T * - raw pointer to elements of type T
|
||||
|
||||
cutlass::TensorRef<T, Rank> - reference to a tensor of elements of type T and given rank.
|
||||
Includes a mapping function and associated stride vector for
|
||||
accessing elements in linear memory.
|
||||
|
||||
cutlass::TensorView<T, Rank>: - extends TensorRef<> by adding bounds information. This is a
|
||||
public TensorRef<T, Rank> complete mathematical object which may be used as the argument
|
||||
to CUTLASS functions.
|
||||
|
||||
The above provide an identity maping of a logical index space to linear memory. An element
|
||||
at logical coordinate X has an offset computed as follows:
|
||||
|
||||
offset = dot(X, stride)
|
||||
|
||||
where dot() computes the inner product of X and a vector of "strides."
|
||||
|
||||
CUTLASS 1.1 introduces a mapping function and an additional 'rank' to offer a flexible way to
|
||||
map the logical index space of the tensor to memory. The mapping function maps a coordinate
|
||||
of rank R to an index space of rank S. The linear offset is computed as:
|
||||
|
||||
offset = dot( MapFunc(X), stride )
|
||||
|
||||
where stride is a vector of rank S.
|
||||
|
||||
|
||||
The complete template declaration for cutlass::TensorRef<> is as follows.
|
||||
|
||||
template <
|
||||
/// Data type of element stored within tensor
|
||||
typename Storage,
|
||||
|
||||
/// Rank of logical tensor
|
||||
int Rank,
|
||||
|
||||
/// Maps a Coord<Rank> in the logical tensor index space to the internal n-D array
|
||||
typename MapFunc = IdentityTensorMapFunc<Rank>,
|
||||
|
||||
/// Rank of internal n-D array
|
||||
int StorageRank_ = MapFunc::kStorageRank,
|
||||
|
||||
/// Index type used for coordinates
|
||||
typename Index = int,
|
||||
|
||||
/// Index type used for offsets and pointer differences
|
||||
typename LongIndex = long long
|
||||
>
|
||||
class TensorRef;
|
||||
|
||||
|
||||
CUTLASS kernels make extensive use of vectorization of memory accesses for efficiency and
|
||||
correctness. Consequently, we enforce a constraint on the strides used by mapping functions
|
||||
such that:
|
||||
|
||||
1. The "fastest-changing" stride is always 1 thereby mandating that consecutive elements in
|
||||
that rank are consecutive in linear memory.
|
||||
|
||||
2. The fastest changing rank is always last in the stride vector and not explicitly stored.
|
||||
|
||||
Thus, the stride vector used by mapping functions has length of one fewer than the rank of the
|
||||
storage tensor. These constraints are consistent with the BLAS interface of passing matrices as
|
||||
a tuple consisting of a pointer and a "leading dimension." In fact, these are rank=2 tensors
|
||||
whose fastest changing dimension is 1, and the stride vector is of length 1.
|
||||
|
||||
|
||||
A typical mapping function might simply map the rows and columns of a matrix, a rank=2 tensor,
|
||||
to linear memory such that (1.) elements in the same column are consecutive in memory
|
||||
(column-major), or (2.) elements in the same row are consecutive (row-major). These can be
|
||||
accomplished by two different mapping functions whose stride vector is length=2. The first
|
||||
element is the "leading dimension."
|
||||
|
||||
The following mapping functions demonstrates mappings for these canonical matrix layouts. In
|
||||
both cases, the logical index space is referenced by coordinates of the form (row, column).
|
||||
|
||||
// cutlass/matrix_traits.h
|
||||
struct MatrixLayout {
|
||||
|
||||
//
|
||||
// TensorRefMapFunc definitions for common layouts
|
||||
//
|
||||
|
||||
/// Mapping function for row-major matrices
|
||||
struct RowMajor {
|
||||
|
||||
/// Storage rank = 2 implies stride vector: (ldm, 1)
|
||||
static int const kStorageRank = 2;
|
||||
|
||||
/// Maps (row, col) to (row, col)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<kStorageRank> operator()(Coord<2> const &coord) const {
|
||||
return coord;
|
||||
}
|
||||
};
|
||||
|
||||
/// Mapping function for column-major matrices
|
||||
struct ColumnMajor {
|
||||
|
||||
/// Storage rank = 2 implies stride vector: (ldm, 1)
|
||||
static int const kStorageRank = 2;
|
||||
|
||||
/// Maps (row, col) to (col, row)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<kStorageRank> operator()(Coord<2> const &coord) const {
|
||||
return make_Coord(coord[1], coord[0]);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
The requirement that the fastest-changing stride always be of unit size need not be a limitation.
|
||||
To implement "sparse" computations or matrix operations in which matrix elements have arbitrary
|
||||
stride along both row and column, define a mapping function whose storage rank is 3. This permits
|
||||
two elements of the stride vector to have a non-unit value. The map function defined in
|
||||
`cutlass::MatrixTraits::ContiguousLayout` is an example.
|
||||
|
||||
```
|
||||
/// Mapping function for scenario in which layout is row-major or column-major but this information
|
||||
/// is only available at runtime.
|
||||
struct ContiguousLayout {
|
||||
/// Arbitrary storage rank
|
||||
static int const kStorageRank = 3;
|
||||
|
||||
/// Dimension of rows
|
||||
static int const kRow = 0;
|
||||
|
||||
/// Dimension of columns
|
||||
static int const kColumn = 1;
|
||||
|
||||
/// Mapping function defined by runtime variable. Returns coordinates in n-D storage array
|
||||
/// as (matrix row, matrix colum, 0)
|
||||
CUTLASS_HOST_DEVICE
|
||||
Coord<kStorageRank> operator()(MatrixCoord const &coord) const {
|
||||
return make_Coord(coord.row(), coord.column(), 0);
|
||||
}
|
||||
|
||||
/// Helper to construct a stride vector based on contiguous matrix layout and leading dimension
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Coord<kStorageRank> stride(MatrixLayout::Kind layout, int ldm) {
|
||||
if (layout == MatrixLayout::kRowMajor) {
|
||||
return make_Coord(ldm, 1, 1);
|
||||
}
|
||||
return make_Coord(1, ldm, 1);
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
cutlass::TensorView<> extends this concept by including a size vector to specify the bounds of
|
||||
the index space. The value of each coordinate in the size vector defines the half-open range of
|
||||
indices whose smallest value is zero.
|
||||
*/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Standard Library includes
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
//
|
||||
// CUTLASS includes
|
||||
//
|
||||
|
||||
// Defines cutlass::Coord<>
|
||||
#include "cutlass/coord.h"
|
||||
|
||||
// Defines cutlass::TensorRef<>
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
// Defines cutlass::TensorView<>
|
||||
#include "cutlass/tensor_view.h"
|
||||
|
||||
// Defines cutlass::MatrixLayout
|
||||
#include "cutlass/matrix_traits.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// Column-major matrix access
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Define a rank=2 tensor modeling a column-major matrix
|
||||
typedef cutlass::TensorView<
|
||||
int, // storage element is of type int
|
||||
2, // tensor has rank=2 logical index space
|
||||
cutlass::MatrixLayout::ColumnMajor // column-major mapping function
|
||||
> TensorViewColumnMajor;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Kernel to copy a matrix from raw memory into a cutlass::TensorView
|
||||
__global__ void MatrixCopyColumnMajor(
|
||||
TensorViewColumnMajor destination, // destination tensor accessed by TensorView
|
||||
int const *source, // source matrix accessed using cuBLAS-style pointer
|
||||
int ldm) { // and leading dimension
|
||||
|
||||
// Compute unique row and column for each thread
|
||||
int row = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int column = threadIdx.y + blockIdx.y * blockDim.y;
|
||||
|
||||
// Define a coordinate based on the thread's row and column
|
||||
cutlass::Coord<2> coord = cutlass::make_Coord(row, column);
|
||||
|
||||
// Bounds test
|
||||
if (coord < destination.size()) {
|
||||
|
||||
// Access the element
|
||||
destination.at(coord) = source[row + column * ldm];
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Launches kernel MatrixCopyColumnMajor()
|
||||
cudaError_t TestMatrixCopyColumnMajor() {
|
||||
cudaError_t result;
|
||||
|
||||
int const M = 32; // number of rows
|
||||
int const N = 16; // number of columns
|
||||
|
||||
int const ldm = 40; // matrix leading dimension
|
||||
|
||||
//
|
||||
// Allocate source and destination matrices
|
||||
//
|
||||
|
||||
int *Destination;
|
||||
int *Source;
|
||||
|
||||
int const matrix_capacity = ldm * N; // number of elements in memory needed to store matrix
|
||||
size_t const sizeof_matrix = sizeof(int) * matrix_capacity; // size of matrix in bytes
|
||||
|
||||
// Allocate destination and source matrices
|
||||
result = cudaMalloc((void **)&Destination, sizeof_matrix);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Failed to allocate destination matrix on device: " << cudaGetErrorString(result) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
result = cudaMalloc((void **)&Source, sizeof_matrix);
|
||||
if (result != cudaSuccess) {
|
||||
cudaFree(Destination);
|
||||
std::cerr << "Failed to allocate source matrix on device:" << cudaGetErrorString(result) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Clear destination matrix in device memory
|
||||
result = cudaMemset(Destination, 0, sizeof_matrix);
|
||||
if (result != cudaSuccess) {
|
||||
cudaFree(Destination);
|
||||
cudaFree(Source);
|
||||
std::cerr << "Failed to clear destination matrix: " << cudaGetErrorString(result) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// Initialize matrix
|
||||
//
|
||||
|
||||
std::vector<int> source_host(matrix_capacity, 0);
|
||||
|
||||
// Procedurally generate input results using several arbitrary constants.
|
||||
int const magic_row_stride = 2;
|
||||
int const magic_column_stride = 3;
|
||||
|
||||
for (int j = 0; j < N; ++j) {
|
||||
for (int i = 0; i < M; ++i) {
|
||||
source_host.at(i + j * ldm) = i * magic_row_stride + j * magic_column_stride;
|
||||
}
|
||||
}
|
||||
|
||||
// Copy to device memory
|
||||
result = cudaMemcpy(Source, source_host.data(), sizeof_matrix, cudaMemcpyHostToDevice);
|
||||
if (result != cudaSuccess) {
|
||||
cudaFree(Destination);
|
||||
cudaFree(Source);
|
||||
std::cerr << "Failed to copy from host to source matrix: " << cudaGetErrorString(result) << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// Define a TensorView<> pointing to the destination matrix
|
||||
//
|
||||
TensorViewColumnMajor destination_view_device(
|
||||
Destination, // pointer to base of matrix in device memory
|
||||
cutlass::make_Coord(ldm, 1), // stride vector
|
||||
cutlass::make_Coord(M, N) // bounds of matrix
|
||||
);
|
||||
|
||||
//
|
||||
// Launch kernel to copy matrix
|
||||
//
|
||||
|
||||
dim3 block(16, 16);
|
||||
dim3 grid((M + block.x - 1) / block.x, (N + block.y - 1) / block.y);
|
||||
|
||||
MatrixCopyColumnMajor<<< grid, block >>>(destination_view_device, Source, ldm);
|
||||
|
||||
result = cudaGetLastError();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Kernel MatrixCopyColumnMajor() failed: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
|
||||
cudaFree(Destination);
|
||||
cudaFree(Source);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// Copy results to host memory
|
||||
//
|
||||
|
||||
std::vector<int> dest_host(matrix_capacity, 0);
|
||||
|
||||
result = cudaMemcpy(dest_host.data(), Destination, sizeof_matrix, cudaMemcpyDeviceToHost);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Failed to copy destination matrix to host memory: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
|
||||
cudaFree(Destination);
|
||||
cudaFree(Source);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// Verify result
|
||||
//
|
||||
|
||||
// Define a TensorView for use in accessing host memory
|
||||
TensorViewColumnMajor destination_view_host(
|
||||
dest_host.data(), // pointer to base of matrix in host memory
|
||||
cutlass::make_Coord(ldm, 1), // stride vector
|
||||
cutlass::make_Coord(M, N) // bounds of matrix
|
||||
);
|
||||
|
||||
// Verify against procedurally computed results
|
||||
for (int j = 0; j < N; ++j) {
|
||||
for (int i = 0; i < M; ++i) {
|
||||
|
||||
// computed result
|
||||
int expected = i * magic_row_stride + j * magic_column_stride;
|
||||
|
||||
// access data by computing explicit offsets
|
||||
int got_explicit = dest_host.at(i + j * ldm);
|
||||
|
||||
// access data in host memory through a TensorView
|
||||
int got_view = destination_view_host.at(cutlass::make_Coord(i, j));
|
||||
|
||||
if (got_explicit != expected) {
|
||||
|
||||
std::cerr << "Error at element (" << i << ", " << j
|
||||
<< ") accessed through explicitly computed offset - expected: " << expected
|
||||
<< ", got: " << got_explicit << std::endl;
|
||||
|
||||
return cudaErrorUnknown;
|
||||
}
|
||||
|
||||
if (got_view != expected) {
|
||||
|
||||
std::cerr << "Error at element (" << i << ", " << j
|
||||
<< ") accesed through TensorView<> on the host - expected: " << expected
|
||||
<< ", got: " << got_view << std::endl;
|
||||
|
||||
return cudaErrorUnknown;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Entry point for tensor_view example.
|
||||
//
|
||||
// usage:
|
||||
//
|
||||
// 02_tensor_view
|
||||
//
|
||||
int main() {
|
||||
|
||||
cudaError_t result = TestMatrixCopyColumnMajor();
|
||||
|
||||
if (result == cudaSuccess) {
|
||||
std::cout << "Passed" << std::endl;
|
||||
}
|
||||
|
||||
return (result == cudaSuccess ? 0 : -1);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
38
examples/02_cutlass_utilities/CMakeLists.txt
Normal file
@ -0,0 +1,38 @@
|
||||
# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
set(EXAMPLES_CUTLASS_UTILITIES_SOURCES
|
||||
cutlass_utilities.cu
|
||||
)
|
||||
|
||||
if (NOT CUTLASS_NATIVE_CUDA)
|
||||
# cuda_add_executable does not take interface include directories into account
|
||||
# Let's fetch them and pass them to CUDA.
|
||||
get_target_property(CUTLASS_INCLUDES CUTLASS INTERFACE_INCLUDE_DIRECTORIES)
|
||||
include_directories("${CUTLASS_INCLUDES}")
|
||||
endif()
|
||||
|
||||
cutlass_add_executable(
|
||||
02_cutlass_utilities
|
||||
${EXAMPLES_CUTLASS_UTILITIES_SOURCES}
|
||||
)
|
||||
359
examples/02_cutlass_utilities/cutlass_utilities.cu
Normal file
@ -0,0 +1,359 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*
|
||||
This example demonstrates several CUTLASS utilities in the context of a mixed-precision
|
||||
floating-point matrix product computation.
|
||||
|
||||
These utilities are intended to be useful supporting components for managing tensor and matrix
|
||||
memory allocations, initializing and comparing results, and computing reference output.
|
||||
|
||||
CUTLASS utilities are defined in the directory `tools/util`, and definitions appear
|
||||
namespace `cutlass::` or an inner namespace therein. Operations in `cutlass::reference::` have
|
||||
both host-side and device-side implementations, and the choice to use device-side initialization
|
||||
and host-side verification in this example was arbitrary.
|
||||
|
||||
|
||||
cutlass::half_t
|
||||
|
||||
This is a host-only implementation of a half-precision floating-point type. It requires no
|
||||
specialized hardware support from the CPU and emulates arithmetic operations. Device-side code
|
||||
should use CUDA's `half` type.
|
||||
|
||||
|
||||
cutlass::HostMatrix<>
|
||||
|
||||
This template class simplifies the creation of a rank=2 tensor with either a column-major or
|
||||
row-major layout in memory.
|
||||
|
||||
This class offers methods device_view() and host_view() to provide TensorView objects for
|
||||
device- and host-side memory allocations.
|
||||
|
||||
|
||||
cutlass::reference::device::TensorInitialize()
|
||||
|
||||
This template function initializes the elements of a tensor according to either a procedural
|
||||
definition or a random distribution. The function in namespace `cutlass::reference::device::`
|
||||
uses a CUDA kernel to perform this initialization, relying on CURAND to compute random numbers.
|
||||
|
||||
|
||||
cutlass::reference::host::Gemm()
|
||||
|
||||
This template function computes the general matrix product. This template supports unique
|
||||
data types for each matrix operand, the internal accumulation type, and the scalar parameters
|
||||
alpha and beta.
|
||||
|
||||
|
||||
cutlass::reference::host::TensorEquals()
|
||||
|
||||
Compares two tensors of identical rank and returns true if values are bit equivalent.
|
||||
|
||||
*/
|
||||
|
||||
// Standard Library includes
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
// CUTLASS includes needed for mixed-precision GEMM kernel
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/fp16_sgemm_traits.h"
|
||||
|
||||
//
|
||||
// CUTLASS utility includes
|
||||
//
|
||||
|
||||
// Defines operator<<() to write TensorView objects to std::ostream
|
||||
#include "tools/util/tensor_view_io.h"
|
||||
|
||||
// Defines cutlass::HostMatrix<>
|
||||
#include "tools/util/host_matrix.h"
|
||||
|
||||
// Defines cutlass::half_t
|
||||
#include "tools/util/half.h"
|
||||
|
||||
// Defines cutlass::reference::device::TensorInitialize()
|
||||
#include "tools/util/reference/device/tensor_elementwise.h"
|
||||
|
||||
// Defines cutlass::reference::host::TensorEquals()
|
||||
#include "tools/util/reference/host/tensor_elementwise.h"
|
||||
|
||||
// Defines cutlass::reference::host::Gemm()
|
||||
#include "tools/util/reference/host/gemm.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Define a CUTLASS GEMM template and launch a GEMM kernel.
|
||||
cudaError_t Cutlass_FP16_SgemmNN(
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
cutlass::half_t alpha,
|
||||
half const *A,
|
||||
int lda,
|
||||
half const *B,
|
||||
int ldb,
|
||||
cutlass::half_t beta,
|
||||
half *C,
|
||||
int ldc) {
|
||||
|
||||
// Define a CUTLASS Gemm using mixed-precision floating-point.
|
||||
//
|
||||
// A, B, C, D are half-precision. Internal accumulation is in single-precision.
|
||||
//
|
||||
// Note, we use CUDA's `half` type for device-side code including CUTLASS GEMM kernels.
|
||||
//
|
||||
typedef cutlass::gemm::Fp16SgemmSgemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::Shape<16, 128, 128>,
|
||||
half, // A type
|
||||
half, // B type
|
||||
half, // C type
|
||||
half, // D type
|
||||
half // Scalar type: alpha, beta
|
||||
>
|
||||
GemmTraits;
|
||||
|
||||
// Define a CUTLASS GEMM object.
|
||||
typedef cutlass::gemm::Gemm<GemmTraits> Gemm;
|
||||
|
||||
// Construct and initialize CUTLASS GEMM parameters object.
|
||||
typename Gemm::Params params;
|
||||
|
||||
int result = params.initialize(
|
||||
M, // GEMM M dimension
|
||||
N, // GEMM N dimension
|
||||
K, // GEMM K dimension
|
||||
half(float(alpha)), // scalar alpha - This is a legal conversion from cutlass::half_t to CUDA's half.
|
||||
A, // matrix A operand
|
||||
lda,
|
||||
B, // matrix B operand
|
||||
ldb,
|
||||
half(float(beta)), // scalar beta - This is a legal conversion from cutlass::half_t to CUDA's half.
|
||||
C, // source matrix C
|
||||
ldc,
|
||||
C, // destination matrix C (may be different memory than source C matrix)
|
||||
ldc
|
||||
);
|
||||
|
||||
if (result) {
|
||||
std::cerr << "Failed to initialize CUTLASS Gemm::Params object." << std::endl;
|
||||
return cudaErrorInvalidValue;
|
||||
}
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
Gemm::launch(params);
|
||||
|
||||
// Return any errors associated with the launch or cudaSuccess if no error.
|
||||
return cudaGetLastError();
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Allocate several matrices in GPU device memory and call a single-precision
|
||||
/// CUTLASS GEMM kernel.
|
||||
cudaError_t TestCutlassGemm(int M, int N, int K, cutlass::half_t alpha, cutlass::half_t beta) {
|
||||
cudaError_t result;
|
||||
|
||||
//
|
||||
// Construct cutlass::HostMatrix<> using the half-precision host-side type.
|
||||
//
|
||||
// cutlass::HostMatrix<> allocates memory on both the host and device corresponding to rank=2
|
||||
// tensors in column-major layout. Explicit synchronization methods are offered to copy the
|
||||
// tensor to the device or to the host.
|
||||
//
|
||||
|
||||
// M-by-K matrix of cutlass::half_t
|
||||
cutlass::HostMatrix<cutlass::half_t> A(cutlass::MatrixCoord(M, K));
|
||||
|
||||
// K-by-N matrix of cutlass::half_t
|
||||
cutlass::HostMatrix<cutlass::half_t> B(cutlass::MatrixCoord(K, N));
|
||||
|
||||
// M-by-N matrix of cutlass::half_t
|
||||
cutlass::HostMatrix<cutlass::half_t> C_cutlass(cutlass::MatrixCoord(M, N));
|
||||
|
||||
// M-by-N matrix of cutlass::half_t
|
||||
cutlass::HostMatrix<cutlass::half_t> C_reference(cutlass::MatrixCoord(M, N));
|
||||
|
||||
//
|
||||
// Initialize matrices with small, random integers.
|
||||
//
|
||||
|
||||
cutlass::Distribution dist;
|
||||
|
||||
// Uniform random distribution from -4 .. 4. Values are truncated to integers.
|
||||
dist.set_uniform(-4, 4);
|
||||
|
||||
// Arbitrary RNG seed value. Hard-coded for deterministic results.
|
||||
int seed = 2080;
|
||||
|
||||
cutlass::reference::device::TensorInitialize(
|
||||
A.device_view(), // concept: TensorView
|
||||
seed,
|
||||
dist);
|
||||
|
||||
cutlass::reference::device::TensorInitialize(
|
||||
B.device_view(), // concept: TensorView
|
||||
seed * 2,
|
||||
dist);
|
||||
cutlass::reference::device::TensorInitialize(
|
||||
C_cutlass.device_view(), // concept: TensorView
|
||||
seed * 3,
|
||||
dist);
|
||||
|
||||
// Copy C_cutlass into C_reference so the GEMM is correct when beta != 0.
|
||||
cutlass::reference::device::TensorFill(C_reference.device_view(), C_cutlass.device_view());
|
||||
|
||||
// Copy the device-side view into host memory
|
||||
C_reference.sync_host();
|
||||
|
||||
//
|
||||
// Launch the CUTLASS GEMM kernel
|
||||
//
|
||||
|
||||
result = Cutlass_FP16_SgemmNN(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
alpha,
|
||||
A.device_data(),
|
||||
A.leading_dim(),
|
||||
B.device_data(),
|
||||
B.leading_dim(),
|
||||
beta,
|
||||
C_cutlass.device_data(),
|
||||
C_cutlass.leading_dim()
|
||||
);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// Verify the result using a host-side reference
|
||||
//
|
||||
|
||||
// A and B were initialized using device-side procedures. The intent of this example is to
|
||||
// use the host-side reference GEMM, so we must perform a device-to-host copy.
|
||||
A.sync_host();
|
||||
B.sync_host();
|
||||
|
||||
// Copy CUTLASS's GEMM results into host memory.
|
||||
C_cutlass.sync_host();
|
||||
|
||||
// Compute the reference result using the host-side GEMM reference implementation.
|
||||
cutlass::reference::host::Gemm(
|
||||
cutlass::gemm::GemmCoord(K, N, M), // problem size (type: cutlass::gemm::GemmCoord)
|
||||
alpha, // alpha (type: cutlass::half_t)
|
||||
A.host_ref(), // A (concept: TensorRef)
|
||||
B.host_ref(), // B (concept: TensorRef)
|
||||
beta, // beta (type: cutlass::half_t)
|
||||
C_reference.host_ref(), // C (concept: TensorRef)
|
||||
float(0) // Accumulator initial value passed as argument to deduce
|
||||
); // internal accumulation data type as float.
|
||||
|
||||
// Compare reference to computed results.
|
||||
if (!cutlass::reference::host::TensorEquals(C_reference.host_view(), C_cutlass.host_view())) {
|
||||
|
||||
std::cerr << "Error - CUTLASS mixed-precision GEMM kernel differs from reference." << std::endl;
|
||||
|
||||
//
|
||||
// On error, print C_cutlass and C_reference to std::cerr.
|
||||
//
|
||||
// Note, these are matrices of half-precision elements stored in host memory as
|
||||
// arrays of type cutlass::half_t.
|
||||
//
|
||||
|
||||
// Result of CUTLASS mixed-precision GEMM kernel
|
||||
std::cerr << "CUTLASS:\n" << C_cutlass << std::endl;
|
||||
|
||||
// Result of reference computation
|
||||
std::cerr << "Reference:\n" << C_reference << std::endl;
|
||||
|
||||
// Return error code.
|
||||
return cudaErrorUnknown;
|
||||
}
|
||||
|
||||
// Passed error check
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Entry point to cutlass_utilities example.
|
||||
//
|
||||
// usage:
|
||||
//
|
||||
// 01_cutlass_utilities <M> <N> <K> <alpha> <beta>
|
||||
//
|
||||
int main(int argc, const char *arg[]) {
|
||||
|
||||
//
|
||||
// Parse the command line to obtain GEMM dimensions and scalar values.
|
||||
//
|
||||
|
||||
// GEMM problem dimensions: <M> <N> <K>
|
||||
int problem[3] = { 128, 128, 128 };
|
||||
|
||||
for (int i = 1; i < argc && i < 4; ++i) {
|
||||
std::stringstream ss(arg[i]);
|
||||
ss >> problem[i - 1];
|
||||
}
|
||||
|
||||
// Linear scale factors in GEMM. Note, these are half-precision values stored as
|
||||
// cutlass::half_t.
|
||||
//
|
||||
// Values outside the range of IEEE FP16 will overflow to infinity or underflow to zero.
|
||||
//
|
||||
cutlass::half_t scalars[2] = { 1, 0 };
|
||||
|
||||
for (int i = 4; i < argc && i < 6; ++i) {
|
||||
std::stringstream ss(arg[i]);
|
||||
|
||||
ss >> scalars[i - 4]; // lexical cast to cutlass::half_t
|
||||
}
|
||||
|
||||
//
|
||||
// Run the CUTLASS GEMM test.
|
||||
//
|
||||
|
||||
cudaError_t result = TestCutlassGemm(
|
||||
problem[0], // GEMM M dimension
|
||||
problem[1], // GEMM N dimension
|
||||
problem[2], // GEMM K dimension
|
||||
scalars[0], // alpha
|
||||
scalars[1] // beta
|
||||
);
|
||||
|
||||
if (result == cudaSuccess) {
|
||||
std::cout << "Passed." << std::endl;
|
||||
}
|
||||
|
||||
// Exit.
|
||||
return result == cudaSuccess ? 0 : -1;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
38
examples/03_strided_batched_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,38 @@
|
||||
# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
set(EXAMPLES_STRIDED_BATCHED_GEMM_SOURCES
|
||||
strided_batched_gemm.cu
|
||||
)
|
||||
|
||||
if (NOT CUTLASS_NATIVE_CUDA)
|
||||
# cuda_add_executable does not take interface include directories into account
|
||||
# Let's fetch them and pass them to CUDA.
|
||||
get_target_property(CUTLASS_INCLUDES CUTLASS INTERFACE_INCLUDE_DIRECTORIES)
|
||||
include_directories("${CUTLASS_INCLUDES}")
|
||||
endif()
|
||||
|
||||
cutlass_add_executable(
|
||||
03_strided_batched_gemm
|
||||
${EXAMPLES_STRIDED_BATCHED_GEMM_SOURCES}
|
||||
)
|
||||
349
examples/03_strided_batched_gemm/strided_batched_gemm.cu
Normal file
@ -0,0 +1,349 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/sgemm_traits.h"
|
||||
|
||||
/*
|
||||
This example demonstrates how to use cutlass to compute a batched strided gemm.
|
||||
In this example, both A and B matrix are non-transpose and column major matrix
|
||||
batched_C = batched_A x batched_B
|
||||
As an example, matrix C can be seen as
|
||||
-----------------------------------------------------------
|
||||
(0,0,0) | (0,0,1) | (0,0,2) | (1,0,0) | (1,0,1) | (1,0,2) |
|
||||
-----------------------------------------------------------
|
||||
(0,1,0) | (0,1,1) | (0,1,2) | (1,1,0) | (1,1,1) | (1,1,2) |
|
||||
-----------------------------------------------------------
|
||||
(0,2,0) | (0,2,1) | (0,2,2) | (1,2,0) | (1,2,1) | (1,2,2) |
|
||||
-----------------------------------------------------------
|
||||
(0,3,0) | (0,3,1) | (0,3,2) | (1,3,0) | (1,3,1) | (1,3,2) |
|
||||
-----------------------------------------------------------
|
||||
(0,4,0) | (0,4,1) | (0,4,2) | (1,4,0) | (1,4,1) | (1,4,2) |
|
||||
-----------------------------------------------------------
|
||||
(0,5,0) | (0,5,1) | (0,5,2) | (1,5,0) | (1,5,1) | (1,5,2) |
|
||||
-----------------------------------------------------------
|
||||
batch 0 | batch 1
|
||||
where we denote each element with (batch_idx, row_idx, column_idx)
|
||||
In this example, batch size is 2, M is 6 and N is 3
|
||||
The stride (batch_stride_C) between the first element of two batches is ldc * n
|
||||
|
||||
matrix A can be seen as
|
||||
---------------------------------------
|
||||
(0,0,0) | (0,0,1) | (1,0,0) | (1,0,1) |
|
||||
---------------------------------------
|
||||
(0,1,0) | (0,1,1) | (1,1,0) | (1,1,1) |
|
||||
---------------------------------------
|
||||
(0,2,0) | (0,2,1) | (1,2,0) | (1,2,1) |
|
||||
---------------------------------------
|
||||
(0,3,0) | (0,3,1) | (1,3,0) | (1,3,1) |
|
||||
---------------------------------------
|
||||
(0,4,0) | (0,4,1) | (1,4,0) | (1,4,1) |
|
||||
---------------------------------------
|
||||
(0,5,0) | (0,5,1) | (1,5,0) | (1,5,1) |
|
||||
---------------------------------------
|
||||
batch 0 | batch 1
|
||||
, where batch size is 2, M is 6 and K is 2
|
||||
The stride (batch_stride_B) between the first element of two batches is lda * k
|
||||
|
||||
matrix B can be seen as
|
||||
-----------------------------
|
||||
(0,0,0) | (0,0,1) | (0,0,2) |
|
||||
----------------------------- batch 0
|
||||
(0,1,0) | (0,1,1) | (0,1,2) |
|
||||
-------------------------------------
|
||||
(1,0,0) | (1,0,1) | (1,0,2) |
|
||||
----------------------------- batch 1
|
||||
(1,1,0) | (1,1,1) | (1,1,2) |
|
||||
-----------------------------
|
||||
, where the batch size is 2, N is 3 and K is 2
|
||||
The stride (batch_stride_C) between the first element of two batches is k
|
||||
|
||||
|
||||
*/
|
||||
|
||||
cudaError_t cutlass_strided_batched_sgemm(float const *A,
|
||||
int lda,
|
||||
long long int batch_stride_A,
|
||||
float const *B,
|
||||
int ldb,
|
||||
long long int batch_stride_B,
|
||||
float *C,
|
||||
int ldc,
|
||||
long long int batch_stride_C,
|
||||
float alpha,
|
||||
float beta,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
int batch_count) {
|
||||
// create a cutlass traits
|
||||
typedef cutlass::gemm::SgemmTraits<cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor, cutlass::Shape<8, 128, 128> >
|
||||
SgemmTraits;
|
||||
|
||||
// create a CUTLASS GEMM object.
|
||||
typedef cutlass::gemm::Gemm<SgemmTraits> Gemm;
|
||||
|
||||
// Construct and initialize CUTLASS GEMM parameters object.
|
||||
typename Gemm::Params params;
|
||||
|
||||
int result = params.initialize(
|
||||
m, // M dimension for each batch
|
||||
n, // N dimension for each batch
|
||||
k, // K dimension for each batch
|
||||
alpha, // scalar alpha
|
||||
A,
|
||||
lda,
|
||||
batch_stride_A, // distance in memory between the first element of neighboring batch
|
||||
B,
|
||||
ldb,
|
||||
batch_stride_B, // distance in memory between the first element of neighboring batch
|
||||
beta, // scalar beta
|
||||
C, // source matrix C
|
||||
ldc,
|
||||
batch_stride_C, // distance in memory between the first element of neighboring batch
|
||||
C, // destination matrix C (may be different memory than source C matrix)
|
||||
ldc,
|
||||
batch_stride_C, // distance in memory between the first element of neighboring batch
|
||||
batch_count
|
||||
);
|
||||
|
||||
if (result != 0) {
|
||||
std::cerr << "Failed to initialize CUTLASS Gemm::Params object." << std::endl;
|
||||
return cudaErrorInvalidValue;
|
||||
}
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
Gemm::launch(params);
|
||||
result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "kernel launch result = " << result << std::endl;
|
||||
}
|
||||
return cudaGetLastError();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
cudaError_t strided_batched_gemm_nn_reference(std::vector<T> const &A,
|
||||
int lda,
|
||||
long long int batch_stride_A,
|
||||
std::vector<T> const &B,
|
||||
int ldb,
|
||||
long long int batch_stride_B,
|
||||
std::vector<T> &C,
|
||||
int ldc,
|
||||
long long int batch_stride_C,
|
||||
T alpha,
|
||||
T beta,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
int batch_count) {
|
||||
/*
|
||||
strided batched gemm NN
|
||||
*/
|
||||
|
||||
cudaError_t result = cudaSuccess;
|
||||
|
||||
if (A.size() < lda * k * batch_count) {
|
||||
std::cout << "the size of A is too small" << std::endl;
|
||||
return cudaErrorInvalidValue;
|
||||
}
|
||||
if (B.size() < ldb * n) {
|
||||
std::cout << "the size of B is too small" << std::endl;
|
||||
return cudaErrorInvalidValue;
|
||||
}
|
||||
if (C.size() < ldc * n * batch_count) {
|
||||
std::cout << "the size of C is too small" << std::endl;
|
||||
return cudaErrorInvalidValue;
|
||||
}
|
||||
|
||||
for (int batch_idx = 0; batch_idx < batch_count; batch_idx++) {
|
||||
for (int n_idx = 0; n_idx < n; n_idx++) {
|
||||
for (int m_idx = 0; m_idx < m; m_idx++) {
|
||||
T accum = beta * C[batch_idx * batch_stride_C + n_idx * ldc + m_idx];
|
||||
for (int k_idx = 0; k_idx < k; k_idx++) {
|
||||
accum += alpha
|
||||
* A[batch_idx * batch_stride_A + k_idx * lda + m_idx]
|
||||
* B[batch_idx * batch_stride_B + n_idx * ldb + k_idx];
|
||||
}
|
||||
C[batch_idx * batch_stride_C + n_idx * ldc + m_idx] = accum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
int main() {
|
||||
int const m = 16;
|
||||
int const n = 24;
|
||||
int const k = 8;
|
||||
int const batch_count = 3;
|
||||
|
||||
// A, B are non-transpose, column major
|
||||
int const lda = m;
|
||||
int const ldb = k * batch_count;
|
||||
int const ldc = m;
|
||||
|
||||
int const count_A = batch_count * lda * k;
|
||||
int const count_B = ldb * n;
|
||||
int const count_C = batch_count * ldc * n;
|
||||
|
||||
// the memory is batched along K dimension
|
||||
long long int batch_stride_A = static_cast<long long int>(lda) * static_cast<long long int>(k);
|
||||
long long int batch_stride_B = static_cast<long long int>(k);
|
||||
long long int batch_stride_C = static_cast<long long int>(ldc) * static_cast<long long int>(n);
|
||||
|
||||
// alpha and beta
|
||||
float alpha = 1.0f;
|
||||
float beta = 2.0f;
|
||||
|
||||
cudaError_t result = cudaSuccess;
|
||||
|
||||
// allocate the host memory
|
||||
std::vector<float> host_A(count_A);
|
||||
std::vector<float> host_B(count_B);
|
||||
std::vector<float> host_C(count_C);
|
||||
std::vector<float> result_C(count_C);
|
||||
|
||||
// allocate the device memory
|
||||
float *A;
|
||||
float *B;
|
||||
float *C;
|
||||
|
||||
result = cudaMalloc(&A, count_A * sizeof(float));
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMalloc result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
result = cudaMalloc(&B, count_B * sizeof(float));
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMalloc result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
result = cudaMalloc(&C, count_C * sizeof(float));
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMalloc result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// fill A
|
||||
for (int b_idx = 0; b_idx < batch_count; b_idx++) {
|
||||
for (int col_idx = 0; col_idx < k; col_idx++) {
|
||||
for (int row_idx = 0; row_idx < m; row_idx++) {
|
||||
host_A[row_idx + col_idx * lda + b_idx * lda * k] = static_cast<float>(row_idx + col_idx * lda + b_idx * lda * k);
|
||||
}
|
||||
}
|
||||
}
|
||||
// fill B
|
||||
for (int b_idx = 0; b_idx < batch_count; b_idx++) {
|
||||
for (int col_idx = 0; col_idx < n; col_idx++) {
|
||||
for (int row_idx = 0; row_idx < k; row_idx++) {
|
||||
host_B[row_idx + col_idx * ldb + b_idx * k] = static_cast<float>(n + k * ldb + batch_count * k) - static_cast<float>(row_idx + col_idx * ldb + b_idx * k);
|
||||
}
|
||||
}
|
||||
}
|
||||
// fill C
|
||||
for (int b_idx = 0; b_idx < batch_count; b_idx++) {
|
||||
for (int col_idx = 0; col_idx < n; col_idx++) {
|
||||
for (int row_idx = 0; row_idx < m; row_idx++) {
|
||||
host_C[row_idx + col_idx * ldc + b_idx * ldc * n] = 1.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ref memory
|
||||
std::vector<float> ref_A(host_A);
|
||||
std::vector<float> ref_B(host_B);
|
||||
std::vector<float> ref_C(host_C);
|
||||
// copy host memory to device
|
||||
result = cudaMemcpy(A, host_A.data(), count_A * sizeof(float), cudaMemcpyHostToDevice);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMemcpy result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
result = cudaMemcpy(B, host_B.data(), count_B * sizeof(float), cudaMemcpyHostToDevice);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMemcpy result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
result = cudaMemcpy(C, host_C.data(), count_C * sizeof(float), cudaMemcpyHostToDevice);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMemcpy result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
// run cutlass
|
||||
result = cutlass_strided_batched_sgemm(A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
|
||||
alpha, beta, m, n, k, batch_count);
|
||||
if (result != cudaSuccess)
|
||||
return result;
|
||||
|
||||
// copy device memory to host
|
||||
result = cudaMemcpy(result_C.data(), C, count_C * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaMemcpy result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
//compare with reference code
|
||||
result = strided_batched_gemm_nn_reference(ref_A, lda, batch_stride_A, ref_B, ldb, batch_stride_B, ref_C, ldc, batch_stride_C,
|
||||
alpha, beta, m, n, k, batch_count);
|
||||
if (result != 0)
|
||||
return result;
|
||||
|
||||
if (ref_C != result_C) {
|
||||
std::cout << "CUTLASS strided batched gemm does not run correctly" << std::endl;
|
||||
return cudaErrorUnknown;
|
||||
}
|
||||
|
||||
// free memory
|
||||
result = cudaFree(A);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaFree result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
result = cudaFree(B);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaFree result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
result = cudaFree(C);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "cudaFree result = " << result << std::endl;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
if (result == cudaSuccess) {
|
||||
std::cout << "Passed." << std::endl;
|
||||
}
|
||||
|
||||
// Exit.
|
||||
return result == cudaSuccess ? 0 : -1;
|
||||
}
|
||||
38
examples/04_tile_iterator/CMakeLists.txt
Normal file
@ -0,0 +1,38 @@
|
||||
# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
set(EXAMPLES_BASIC_CUTLASS_GEMM_SOURCES
|
||||
tile_iterator.cu
|
||||
)
|
||||
|
||||
if (NOT CUTLASS_NATIVE_CUDA)
|
||||
# cuda_add_executable does not take interface include directories into account
|
||||
# Let's fetch them and pass them to CUDA.
|
||||
get_target_property(CUTLASS_INCLUDES CUTLASS INTERFACE_INCLUDE_DIRECTORIES)
|
||||
include_directories("${CUTLASS_INCLUDES}")
|
||||
endif()
|
||||
|
||||
cutlass_add_executable(
|
||||
04_tile_iterator
|
||||
${EXAMPLES_BASIC_CUTLASS_GEMM_SOURCES}
|
||||
)
|
||||
248
examples/04_tile_iterator/tile_iterator.cu
Normal file
@ -0,0 +1,248 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*
|
||||
This example demonstrates how to use the TileIterator in CUTLASS to load data from addressable
|
||||
memory, and store it back into addressable memory.
|
||||
|
||||
TileIterator is a core concept in CUTLASS that enables efficient loading and storing of data from
|
||||
and to addressable memory. The TileIterator accepts a TileTraits type, which defines the shape of a
|
||||
tile and the distribution of accesses by individual entities, either threads or others.
|
||||
|
||||
In this example, a LoadTileIterator is used to load elements from a tile in global memory, stored in
|
||||
column-major layout, into a fragment, and a corresponding StoreTileIterator is used to store the
|
||||
elements back into global memory (in the same column-major layout).
|
||||
|
||||
https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/
|
||||
|
||||
This example uses CUTLASS utilities to ease the matrix operations.
|
||||
*/
|
||||
|
||||
// Standard Library includes
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
// CUTLASS includes
|
||||
#include "cutlass/tile_iterator.h"
|
||||
#include "cutlass/tile_traits_standard.h"
|
||||
|
||||
//
|
||||
// CUTLASS utility includes
|
||||
//
|
||||
|
||||
// Defines operator<<() to write TensorView objects to std::ostream
|
||||
#include "tools/util/tensor_view_io.h"
|
||||
|
||||
// Defines cutlass::HostMatrix<>
|
||||
#include "tools/util/host_matrix.h"
|
||||
|
||||
// Defines cutlass::reference::device::TensorInitialize()
|
||||
#include "tools/util/reference/device/tensor_elementwise.h"
|
||||
|
||||
// Defines cutlass::reference::host::TensorEquals()
|
||||
#include "tools/util/reference/host/tensor_elementwise.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// This function defines load and store tile iterators to load and store a M-by-K tile, in
|
||||
// column-major layout, from and back into global memory.
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Traits>
|
||||
__global__ void cutlass_tile_iterator_load_store_global(
|
||||
float const *input,
|
||||
float *output,
|
||||
int M,
|
||||
int K) {
|
||||
|
||||
// Define a tile load iterator
|
||||
typedef cutlass::TileLoadIterator<
|
||||
Traits, // the Traits type, defines shape/distribution of accesses
|
||||
float, // elements are of type float
|
||||
cutlass::IteratorAdvance::kH, // post-increment accesses advance in strided (as opposed to
|
||||
// contiguous dimension
|
||||
cutlass::MemorySpace::kGlobal // iterator loads from global memory
|
||||
> TileLoadIterator;
|
||||
|
||||
// Defines a tile store iterator
|
||||
typedef cutlass::TileStoreIterator<
|
||||
Traits, // the Traits type, defines shape/distribution of accesses
|
||||
float, // elements are of type float
|
||||
cutlass::IteratorAdvance::kH, // post-increment accesses advance in strided (as opposed to
|
||||
// contiguous) dimension
|
||||
cutlass::MemorySpace::kGlobal // iterator stores into global memory
|
||||
> TileStoreIterator;
|
||||
|
||||
// Defines a predicate vector for managing statically sized vector of boolean predicates
|
||||
typedef typename TileLoadIterator::PredicateVector PredicateVector;
|
||||
|
||||
// The parameters specified to the iterators. These include the pointer to the source of
|
||||
// addressable memory, and the strides and increments for each of the tile's dimensions
|
||||
typename TileLoadIterator::Params load_params;
|
||||
typename TileStoreIterator::Params store_params;
|
||||
|
||||
// Initializing the parameters for both of the iterators. The TileLoadIterator accesses the
|
||||
// input matrix and TileStoreIterator accesses the output matrix. The strides are set
|
||||
// identically since the data is being stored in the same way as it is loaded (column-major
|
||||
// mapping).
|
||||
load_params.initialize(input, M*K, M, 1);
|
||||
store_params.initialize(output, M*K, M, 1);
|
||||
|
||||
// Constructing the tile load and store iterators, and the predicates vector
|
||||
TileLoadIterator load_iterator(load_params);
|
||||
TileStoreIterator store_iterator(store_params);
|
||||
PredicateVector predicates;
|
||||
|
||||
// Initializing the predicates with bounds set to <1, K, M>. This protects out-of-bounds loads.
|
||||
load_iterator.initialize_predicates(predicates.begin(), cutlass::make_Coord(1, K, M));
|
||||
|
||||
// The fragment in which the elements are loaded into and stored from.
|
||||
typename TileLoadIterator::Fragment fragment;
|
||||
|
||||
// Loading a tile into a fragment and advancing to the next tile's position
|
||||
load_iterator.load_post_increment(fragment, predicates.begin());
|
||||
// Storing a tile from fragment and advancing to the next tile's position
|
||||
store_iterator.store_post_increment(fragment);
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Launches cutlass_tile_iterator_load_store_global kernel
|
||||
cudaError_t test_cutlass_tile_iterator() {
|
||||
cudaError_t result = cudaSuccess;
|
||||
|
||||
// Creating a M-by-K (128-by-8) tile for this example.
|
||||
static int const M = 128;
|
||||
static int const K = 8;
|
||||
// The kernel is launched with 128 threads per thread block.
|
||||
static int const kThreadsPerThreadBlock = 128;
|
||||
// Define the tile type
|
||||
typedef cutlass::Shape<1, 8, 128> Tile;
|
||||
|
||||
// CUTLASS provides a standard TileTraits type, which chooses the 'best' shape to enable warp
|
||||
// raking along the contiguous dimension if possible.
|
||||
typedef cutlass::TileTraitsStandard<Tile, kThreadsPerThreadBlock> Traits;
|
||||
|
||||
// M-by-K input matrix of float
|
||||
cutlass::HostMatrix<float> input(cutlass::MatrixCoord(M, K));
|
||||
|
||||
// M-by-K output matrix of float
|
||||
cutlass::HostMatrix<float> output(cutlass::MatrixCoord(M, K));
|
||||
|
||||
//
|
||||
// Initialize input matrix with linear combination.
|
||||
//
|
||||
|
||||
cutlass::Distribution dist;
|
||||
|
||||
// Linear distribution in column-major format.
|
||||
dist.set_linear(1, 1, M);
|
||||
|
||||
// Arbitrary RNG seed value. Hard-coded for deterministic results.
|
||||
int seed = 2080;
|
||||
|
||||
cutlass::reference::device::TensorInitialize(
|
||||
input.device_view(), // concept: TensorView
|
||||
seed,
|
||||
dist);
|
||||
|
||||
// Initialize output matrix to all zeroes.
|
||||
output.fill(0);
|
||||
|
||||
// Launch kernel to load and store tiles from/to global memory.
|
||||
cutlass_tile_iterator_load_store_global<Traits><<<
|
||||
dim3(1, 1, 1),
|
||||
dim3(kThreadsPerThreadBlock, 1)
|
||||
>>>(input.device_data(), output.device_data(), M, K);
|
||||
|
||||
result = cudaDeviceSynchronize();
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Copy results to host
|
||||
output.sync_host();
|
||||
|
||||
// Verify results
|
||||
for(int i = 0; i < M; ++i) {
|
||||
for(int j = 0; j < K; ++j) {
|
||||
if(output.at(cutlass::make_Coord(i, j)) != float(M*j+i+1)){
|
||||
std::cout << "FAILED: (" << i << ", " << j
|
||||
<< ") -- expected: " << (M*j+i+1)
|
||||
<< ", actual: " << output.at(cutlass::make_Coord(i, j))
|
||||
<< std::endl;
|
||||
result = cudaErrorUnknown;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Entry point to tile_iterator example.
|
||||
//
|
||||
// usage:
|
||||
//
|
||||
// 04_tile_iterator
|
||||
//
|
||||
int main(int argc, const char *arg[]) {
|
||||
|
||||
// Properties of CUDA device
|
||||
cudaDeviceProp device_properties;
|
||||
|
||||
// Assumne the device id is 0.
|
||||
int device_id = 0;
|
||||
|
||||
cudaError_t result = cudaGetDeviceProperties(&device_properties, device_id);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Failed to get device properties: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Run the CUTLASS tile iterator test.
|
||||
//
|
||||
|
||||
result = test_cutlass_tile_iterator();
|
||||
|
||||
if (result == cudaSuccess) {
|
||||
std::cout << "Passed." << std::endl;
|
||||
}
|
||||
|
||||
// Exit.
|
||||
return result == cudaSuccess ? 0 : -1;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
38
examples/05_wmma_gemm/CMakeLists.txt
Normal file
@ -0,0 +1,38 @@
|
||||
# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
set(EXAMPLES_BASIC_CUTLASS_GEMM_SOURCES
|
||||
wmma_gemm.cu
|
||||
)
|
||||
|
||||
if (NOT CUTLASS_NATIVE_CUDA)
|
||||
# cuda_add_executable does not take interface include directories into account
|
||||
# Let's fetch them and pass them to CUDA.
|
||||
get_target_property(CUTLASS_INCLUDES CUTLASS INTERFACE_INCLUDE_DIRECTORIES)
|
||||
include_directories("${CUTLASS_INCLUDES}")
|
||||
endif()
|
||||
|
||||
cutlass_add_executable(
|
||||
05_wmma_gemm
|
||||
${EXAMPLES_BASIC_CUTLASS_GEMM_SOURCES}
|
||||
)
|
||||
353
examples/05_wmma_gemm/wmma_gemm.cu
Normal file
@ -0,0 +1,353 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*
|
||||
This example demonstrates how to call a CUTLASS GEMM kernel using Turing integer WMMA.
|
||||
|
||||
The CUTLASS integer WMMA Gemm template is instantiated in the function Cutlass_S8_WmmagemmNN. This
|
||||
is kernel computes the general matrix product (GEMM) using integer arithmetic accelerated by Turing
|
||||
WMMA and assumes all matrices have column-major layout.
|
||||
|
||||
The threadblock tile size is chosen as 128x128x8 which offers good performance for large matrices.
|
||||
See the CUTLASS Parallel for All blog post for more exposition on the tunable parameters available
|
||||
in CUTLASS.
|
||||
|
||||
https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/
|
||||
|
||||
This example uses CUTLASS utilities to ease the matrix operations.
|
||||
*/
|
||||
|
||||
// Standard Library includes
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
// CUTLASS includes needed for WMMA GEMM kernel
|
||||
#include "cutlass/wmma_matrix.h"
|
||||
|
||||
// This example works only when this MACRO is defined in "cutlass/wmma_matrix.h"
|
||||
#ifdef CUTLASS_USE_SUBBYTE_WMMA
|
||||
|
||||
// Defines cutlass::gemm::Gemm, the generic Gemm computation template class.
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
// Defines cutlass::gemm::WmmaGemmTraits, the structural components for WMMA GEMM
|
||||
#include "cutlass/gemm/wmma_gemm_traits.h"
|
||||
|
||||
//
|
||||
// CUTLASS utility includes
|
||||
//
|
||||
|
||||
// Defines operator<<() to write TensorView objects to std::ostream
|
||||
#include "tools/util/tensor_view_io.h"
|
||||
|
||||
// Defines cutlass::HostMatrix<>
|
||||
#include "tools/util/host_matrix.h"
|
||||
|
||||
// Defines cutlass::reference::device::TensorInitialize()
|
||||
#include "tools/util/reference/device/tensor_elementwise.h"
|
||||
|
||||
// Defines cutlass::reference::host::TensorEquals()
|
||||
#include "tools/util/reference/host/tensor_elementwise.h"
|
||||
|
||||
// Defines cutlass::reference::host::Gemm()
|
||||
#include "tools/util/reference/host/gemm.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// This function defines a CUTLASS GEMM kernel instantiation, constructs its parameters object,
|
||||
// and launches it on the CUDA device.
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Define a CUTLASS GEMM template and launch a GEMM kernel.
|
||||
cudaError_t Cutlass_S8_WmmagemmNN(
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int alpha,
|
||||
signed char const *A,
|
||||
int lda,
|
||||
signed char const *B,
|
||||
int ldb,
|
||||
int beta,
|
||||
int *C,
|
||||
int ldc) {
|
||||
|
||||
// Define type definition for 8-bit signed int WMMA CUTLASS GEMM with column-major
|
||||
// input matrices and 128x128x128 threadblock tile size.
|
||||
//
|
||||
// Note, A and B are 8-bit signed int. C and D are 32-bit int. .
|
||||
//
|
||||
typedef cutlass::gemm::WmmaGemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor, // layout of A matrix
|
||||
cutlass::MatrixLayout::kColumnMajor, // layout of B matrix
|
||||
cutlass::Shape<128, 128, 128>, // threadblock tile size
|
||||
signed char, // A type
|
||||
signed char, // B type
|
||||
int, // D type
|
||||
cutlass::gemm::LinearScaling<int>, // functor to do the math in the epilogue
|
||||
int, // accumulator type
|
||||
cutlass::Shape<128, 32, 32>, // warp tile size
|
||||
cutlass::Shape<16, 16, 16>, // WMMA instruction tile size
|
||||
16, // scalars every time a thread loads from A
|
||||
16 // scalars every time a thread loads from B
|
||||
>
|
||||
GemmTraits;
|
||||
|
||||
// Define a CUTLASS GEMM type from a GemmTraits<> instantiation.
|
||||
typedef cutlass::gemm::Gemm<GemmTraits> Gemm;
|
||||
|
||||
// Construct and initialize CUTLASS GEMM parameters object.
|
||||
typename Gemm::Params params;
|
||||
|
||||
int result = params.initialize(
|
||||
M, // GEMM M dimension
|
||||
N, // GEMM N dimension
|
||||
K, // GEMM K dimension
|
||||
alpha, // scalar alpha
|
||||
A, // matrix A operand
|
||||
lda,
|
||||
B, // matrix B operand
|
||||
ldb,
|
||||
beta, // scalar beta
|
||||
C, // source matrix C
|
||||
ldc,
|
||||
C, // destination matrix C (may be different memory than source C matrix)
|
||||
ldc
|
||||
);
|
||||
|
||||
if (result) {
|
||||
std::cerr << "Failed to initialize CUTLASS Gemm::Params object." << std::endl;
|
||||
return cudaErrorInvalidValue;
|
||||
}
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
Gemm::launch(params);
|
||||
|
||||
// Return any errors associated with the launch or cudaSuccess if no error.
|
||||
return cudaGetLastError();
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Allocate several matrices in GPU device memory and call an integer
|
||||
/// CUTLASS WMMA GEMM kernel.
|
||||
cudaError_t TestCutlassGemm(int M, int N, int K, int alpha, int beta) {
|
||||
cudaError_t result;
|
||||
|
||||
//
|
||||
// Construct cutlass::HostMatrix<> using the integer host-side types.
|
||||
|
||||
// M-by-K matrix of signed char
|
||||
cutlass::HostMatrix<signed char> A(cutlass::MatrixCoord(M, K));
|
||||
|
||||
// K-by-N matrix of signed char
|
||||
cutlass::HostMatrix<signed char> B(cutlass::MatrixCoord(K, N));
|
||||
|
||||
// M-by-N matrix of int
|
||||
cutlass::HostMatrix<int> C_cutlass(cutlass::MatrixCoord(M, N));
|
||||
|
||||
// M-by-N matrix of int
|
||||
cutlass::HostMatrix<int> C_reference(cutlass::MatrixCoord(M, N));
|
||||
|
||||
//
|
||||
// Initialize matrices with small, random integers.
|
||||
//
|
||||
|
||||
cutlass::Distribution dist;
|
||||
|
||||
// Uniform random distribution from -4 .. 4. Values are truncated to integers.
|
||||
dist.set_uniform(-4, 4);
|
||||
|
||||
// Arbitrary RNG seed value. Hard-coded for deterministic results.
|
||||
int seed = 2080;
|
||||
|
||||
cutlass::reference::device::TensorInitialize(
|
||||
A.device_view(), // concept: TensorView
|
||||
seed,
|
||||
dist);
|
||||
|
||||
cutlass::reference::device::TensorInitialize(
|
||||
B.device_view(), // concept: TensorView
|
||||
seed * 2,
|
||||
dist);
|
||||
|
||||
cutlass::reference::device::TensorInitialize(
|
||||
C_cutlass.device_view(), // concept: TensorView
|
||||
seed * 3,
|
||||
dist);
|
||||
|
||||
// Copy C_cutlass into C_reference so the GEMM is correct when beta != 0.
|
||||
cutlass::reference::device::TensorFill(C_reference.device_view(), C_cutlass.device_view());
|
||||
|
||||
// Copy the device-side view into host memory
|
||||
C_reference.sync_host();
|
||||
|
||||
//
|
||||
// Launch the CUTLASS GEMM kernel
|
||||
//
|
||||
|
||||
result = Cutlass_S8_WmmagemmNN(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
alpha,
|
||||
A.device_data(),
|
||||
A.leading_dim(),
|
||||
B.device_data(),
|
||||
B.leading_dim(),
|
||||
beta,
|
||||
C_cutlass.device_data(),
|
||||
C_cutlass.leading_dim()
|
||||
);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return result;
|
||||
}
|
||||
|
||||
//
|
||||
// Verify the result using a host-side reference
|
||||
//
|
||||
|
||||
// A and B were initialized using device-side procedures.
|
||||
A.sync_host();
|
||||
B.sync_host();
|
||||
|
||||
// Copy CUTLASS's GEMM results into host memory.
|
||||
C_cutlass.sync_host();
|
||||
|
||||
// Compute the reference result using the host-side GEMM reference implementation.
|
||||
cutlass::reference::host::Gemm(
|
||||
cutlass::gemm::GemmCoord(K, N, M), // problem size (type: cutlass::gemm::GemmCoord)
|
||||
alpha, // alpha (type: int)
|
||||
A.host_ref(), // A (concept: TensorRef)
|
||||
B.host_ref(), // B (concept: TensorRef)
|
||||
beta, // beta (int)
|
||||
C_reference.host_ref(), // C (concept: TensorRef)
|
||||
int(0) // Accumulator initial value passed as argument to deduce
|
||||
); // internal accumulation data type as int.
|
||||
|
||||
// Compare reference to computed results.
|
||||
if (!cutlass::reference::host::TensorEquals(C_reference.host_view(), C_cutlass.host_view())) {
|
||||
|
||||
std::cerr << "Error - CUTLASS WMMA GEMM kernel differs from reference." << std::endl;
|
||||
|
||||
//
|
||||
// On error, print C_cutlass and C_reference to std::cerr.
|
||||
//
|
||||
|
||||
// Result of CUTLASS WMMA GEMM kernel
|
||||
std::cerr << "CUTLASS:\n" << C_cutlass << std::endl;
|
||||
|
||||
// Result of reference computation
|
||||
std::cerr << "Reference:\n" << C_reference << std::endl;
|
||||
|
||||
// Return error code.
|
||||
return cudaErrorUnknown;
|
||||
}
|
||||
|
||||
// Passed error check
|
||||
return cudaSuccess;
|
||||
}
|
||||
#endif // defined CUTLASS_USE_SUBBYTE_WMMA
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Entry point to wmma_gemm example.
|
||||
//
|
||||
// usage:
|
||||
//
|
||||
// 05_wmma_gemm <M> <N> <K> <alpha> <beta>
|
||||
//
|
||||
int main(int argc, const char *arg[]) {
|
||||
|
||||
#ifdef CUTLASS_USE_SUBBYTE_WMMA
|
||||
// Properties of CUDA device
|
||||
cudaDeviceProp device_properties;
|
||||
|
||||
// Assumne the device id is 0.
|
||||
int device_id = 0;
|
||||
|
||||
cudaError_t result = cudaGetDeviceProperties(&device_properties, device_id);
|
||||
if (result != cudaSuccess) {
|
||||
std::cerr << "Failed to get device properties: "
|
||||
<< cudaGetErrorString(result) << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if ((device_properties.major * 10 + device_properties.minor) < 75) {
|
||||
std::cerr << "This example needs to run on a Turing device." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
//
|
||||
// Parse the command line to obtain GEMM dimensions and scalar values.
|
||||
//
|
||||
|
||||
// GEMM problem dimensions.
|
||||
int problem[3] = { 128, 128, 128 };
|
||||
|
||||
for (int i = 1; i < argc && i < 4; ++i) {
|
||||
std::stringstream ss(arg[i]);
|
||||
ss >> problem[i - 1];
|
||||
}
|
||||
|
||||
// Scalars used for linear scaling the result of the matrix product.
|
||||
int scalars[2] = { 1, 0 };
|
||||
|
||||
for (int i = 4; i < argc && i < 6; ++i) {
|
||||
std::stringstream ss(arg[i]);
|
||||
ss >> scalars[i - 4];
|
||||
}
|
||||
|
||||
//
|
||||
// Run the CUTLASS GEMM test.
|
||||
//
|
||||
|
||||
result = TestCutlassGemm(
|
||||
problem[0], // GEMM M dimension
|
||||
problem[1], // GEMM N dimension
|
||||
problem[2], // GEMM K dimension
|
||||
scalars[0], // alpha
|
||||
scalars[1] // beta
|
||||
);
|
||||
|
||||
if (result == cudaSuccess) {
|
||||
std::cout << "Passed." << std::endl;
|
||||
}
|
||||
|
||||
// Exit.
|
||||
return result == cudaSuccess ? 0 : -1;
|
||||
|
||||
#else
|
||||
std::cerr << "CUTLASS WMMA GEMM targeting Turing Tensor Cores features requires CUDA 10." << std::endl;
|
||||
return -1;
|
||||
#endif // defined CUTLASS_USE_SUBBYTE_WMMA
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
28
examples/CMakeLists.txt
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
# provided that the following conditions are met:
|
||||
# * Redistributions of source code must retain the above copyright notice, this list of
|
||||
# conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
# conditions and the following disclaimer in the documentation and/or other materials
|
||||
# provided with the distribution.
|
||||
# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
# to endorse or promote products derived from this software without specific prior written
|
||||
# permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
add_subdirectory(00_basic_gemm)
|
||||
add_subdirectory(01_tensor_view)
|
||||
add_subdirectory(02_cutlass_utilities)
|
||||
add_subdirectory(03_strided_batched_gemm)
|
||||
add_subdirectory(04_tile_iterator)
|
||||
add_subdirectory(05_wmma_gemm)
|
||||
BIN
media/images/cutlass-threadblock-gemm.png
Normal file
|
After Width: | Height: | Size: 59 KiB |
BIN
media/images/cutlass-tile-iteration.png
Normal file
|
After Width: | Height: | Size: 75 KiB |
BIN
media/images/cutlass-tile-structure.png
Normal file
|
After Width: | Height: | Size: 114 KiB |
BIN
media/images/cutlass-warp-thread-tile-structure.png
Normal file
|
After Width: | Height: | Size: 176 KiB |
|
Before Width: | Height: | Size: 251 KiB After Width: | Height: | Size: 253 KiB |
BIN
media/images/gemm-structural-components.png
Normal file
|
After Width: | Height: | Size: 240 KiB |
@ -34,12 +34,14 @@ set(CUTLASS_PERF_TEST_HEADERS
|
||||
)
|
||||
|
||||
set(CUTLASS_PERF_TEST_SOURCES
|
||||
cutlass_perf_test.cpp
|
||||
cutlass_perf_test.cu
|
||||
gemm/sgemm.cu
|
||||
gemm/dgemm.cu
|
||||
gemm/hgemm.cu
|
||||
gemm/igemm.cu
|
||||
gemm/wmma_gemm.cu
|
||||
gemm/wmma_binary_gemm.cu
|
||||
gemm/wmma_integer_gemm.cu
|
||||
)
|
||||
|
||||
source_group("Source\ Files" FILES ${CUTLASS_PERF_TEST_SOURCES})
|
||||
@ -56,4 +58,6 @@ cutlass_add_executable(
|
||||
${CUTLASS_PERF_TEST_SOURCES}
|
||||
${CUTLASS_PERF_TEST_HEADERS}
|
||||
)
|
||||
CUDA_ADD_CUBLAS_TO_TARGET(cutlass_perf_test)
|
||||
|
||||
target_link_libraries(cutlass_perf_test ${CUBLAS_LIBRARY})
|
||||
|
||||
|
||||
@ -27,19 +27,24 @@
|
||||
\brief CUTLASS Performance Tests
|
||||
*/
|
||||
|
||||
#include <tools/test/perf/testbench_options.h>
|
||||
#include <tools/test/perf/testbench_output.h>
|
||||
#include <vector>
|
||||
#include "tools/test/perf/performance_result.h"
|
||||
#include "tools/test/perf/testbench_configs.h"
|
||||
#include "tools/test/perf/testbench_options.h"
|
||||
#include "tools/test/perf/testbench_output.h"
|
||||
|
||||
#include "tools/test/perf/cutlass_perf_test.h"
|
||||
|
||||
static std::vector<perf::GemmProfileFunc*> GemmProfileFuncs;
|
||||
|
||||
//
|
||||
// Profiling entry points defined in corresponding .cu files
|
||||
//
|
||||
namespace perf {
|
||||
|
||||
int profile_sgemm(TestbenchOutput &output, TestbenchOptions const &options);
|
||||
int profile_dgemm(TestbenchOutput &output, TestbenchOptions const &options);
|
||||
int profile_hgemm(TestbenchOutput &output, TestbenchOptions const &options);
|
||||
int profile_igemm(TestbenchOutput &output, TestbenchOptions const &options);
|
||||
int profile_wmma_gemm(TestbenchOutput &output, TestbenchOptions const &options);
|
||||
void RegisterGemmProfileFunc(GemmProfileFunc * profileFunc) {
|
||||
GemmProfileFuncs.push_back(profileFunc);
|
||||
}
|
||||
|
||||
} // namespace perf
|
||||
|
||||
@ -47,6 +52,22 @@ int profile_wmma_gemm(TestbenchOutput &output, TestbenchOptions const &options);
|
||||
// Executes profiling functionality
|
||||
//
|
||||
|
||||
template <typename Problem>
|
||||
int profile(int (**functions)(perf::TestbenchOutput<Problem> &,
|
||||
perf::TestbenchOptions const &,
|
||||
perf::Config const &),
|
||||
perf::TestbenchOutput<Problem> &output,
|
||||
perf::TestbenchOptions options,
|
||||
int result) {
|
||||
perf::TestbenchConfigs test_configs(options);
|
||||
for (size_t j = 0; !result && j < test_configs.configs.size(); j++) {
|
||||
for (size_t i = 0; !result && functions[i] != 0; ++i) {
|
||||
result = (functions[i])(output, options, test_configs.configs[j]);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Entry point to CUTLASS performance test
|
||||
int main(int argc, const char **argv) {
|
||||
cutlass::CommandLine args(argc, argv);
|
||||
@ -57,20 +78,17 @@ int main(int argc, const char **argv) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
perf::TestbenchOutput output(options);
|
||||
|
||||
int (*profile_gemm[])(perf::TestbenchOutput &, perf::TestbenchOptions const &) = {
|
||||
perf::profile_sgemm,
|
||||
perf::profile_dgemm,
|
||||
perf::profile_hgemm,
|
||||
perf::profile_igemm,
|
||||
perf::profile_wmma_gemm,
|
||||
0};
|
||||
|
||||
int result = 0;
|
||||
for (int i = 0; !result && profile_gemm[i]; ++i) {
|
||||
result = (profile_gemm[i])(output, options);
|
||||
if (args.check_cmd_line_flag("version")) {
|
||||
perf::TestbenchOptions::version(std::cout);
|
||||
std::cout << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
return result;
|
||||
int result = 0;
|
||||
|
||||
std::vector<perf::GemmProfileFunc*> profileFuncs = GemmProfileFuncs;
|
||||
profileFuncs.push_back(0); // Passing as array reference below, so need NULL termination.
|
||||
perf::TestbenchOutput<perf::GemmProblem> output_gemm(options);
|
||||
result = profile(&profileFuncs[0], output_gemm, options, result);
|
||||
return result;
|
||||
}
|
||||
44
tools/test/perf/cutlass_perf_test.h
Normal file
@ -0,0 +1,44 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#pragma diag_suppress boolean_controlling_expr_is_constant
|
||||
#include <gtest/gtest.h>
|
||||
#pragma diag_warning boolean_controlling_expr_is_constant
|
||||
|
||||
#include "tools/test/perf/testbench_output.h"
|
||||
#include "tools/test/perf/gemm/gemm_profiler.h"
|
||||
|
||||
namespace perf {
|
||||
|
||||
typedef int (GemmProfileFunc)(
|
||||
TestbenchOutput <GemmProblem> &output,
|
||||
TestbenchOptions const &options,
|
||||
Config const &config);
|
||||
|
||||
void RegisterGemmProfileFunc(GemmProfileFunc*);
|
||||
|
||||
} // perf
|
||||
121
tools/test/perf/gemm/bmma_gemm.cu
Normal file
@ -0,0 +1,121 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright notice, this list of
|
||||
* conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||||
* conditions and the following disclaimer in the documentation and/or other materials
|
||||
* provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||||
* to endorse or promote products derived from this software without specific prior written
|
||||
* permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||||
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||||
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/// \file {nv-internal-release}
|
||||
|
||||
#if (defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750))
|
||||
#pragma warning( disable : 4503)
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/bmma_gemm_traits.h"
|
||||
#include "tools/test/perf/cutlass_perf_test.h"
|
||||
#include "tools/test/perf/gemm/gemm_profiler.h"
|
||||
#include "tools/test/perf/gemm/cutlass_dispatch.h"
|
||||
#include "tools/test/perf/gemm/gemm_perf_testbed.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Traits>
|
||||
struct BmmaGemmDispatch {
|
||||
|
||||
typedef cutlass::gemm::Gemm<Traits> Gemm;
|
||||
|
||||
typedef typename Gemm::Params Params;
|
||||
|
||||
/// Indicate warp-level GEMM
|
||||
static bool const kThreadMultiplyAdd = false;
|
||||
|
||||
static bool const kRunCuBLAS = false;
|
||||
|
||||
static cutlass::MatrixLayout::Kind const kLayoutA = Traits::kLayoutA;
|
||||
static cutlass::MatrixLayout::Kind const kLayoutB = Traits::kLayoutB;
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Params argument
|
||||
Params params;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
BmmaGemmDispatch() {}
|
||||
|
||||
/// Initializes params object
|
||||
BmmaGemmDispatch(int m, int n, int k, int alpha,
|
||||
cutlass::Vector<cutlass::bin1_t, 32> const* d_a, int lda,
|
||||
cutlass::Vector<cutlass::bin1_t, 32> const* d_b, int ldb, int beta,
|
||||
int const* d_c, int ldc, int* d_d, int ldd) {
|
||||
|
||||
params.initialize(m, n, k * 32, alpha, d_a, lda, d_b, ldb, beta, d_c, ldc, d_d, ldd);
|
||||
}
|
||||
|
||||
/// Initializes params object
|
||||
BmmaGemmDispatch(Params const& _params) : params(_params) {}
|
||||
|
||||
/// Launches kernel
|
||||
cudaError_t operator()() { return Gemm::launch(params); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace perf {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int profile_bmma_gemm(TestbenchOutput<GemmProblem> &output, TestbenchOptions const &options, Config const &config) {
|
||||
typedef perf::GemmProfiler<cutlass::Vector<cutlass::bin1_t, 32>, cutlass::Vector<cutlass::bin1_t, 32>, int, int, int> GemmProfiler;
|
||||
|
||||
int results = 0;
|
||||
|
||||
{
|
||||
|
||||
typedef cutlass::gemm::BmmaGemmTraits<cutlass::Shape<1024, 128, 128>,
|
||||
cutlass::Shape<1024, 32, 32>,
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor>
|
||||
BmmaGemmTraits;
|
||||
|
||||
typedef BmmaGemmDispatch<BmmaGemmTraits> Dispatch;
|
||||
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, "bmma_gemm_tn", options, config);
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct BmmaGemmRegistrar {
|
||||
BmmaGemmRegistrar() { RegisterGemmProfileFunc(profile_bmma_gemm); }
|
||||
};
|
||||
|
||||
volatile BmmaGemmRegistrar _BmmaGemmRegistrar;
|
||||
|
||||
} // namespace perf
|
||||
|
||||
#endif // if (defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750)
|
||||
@ -24,8 +24,8 @@
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/matrix_traits.h>
|
||||
#include <tools/util/type_traits.h>
|
||||
#include "cutlass/matrix_traits.h"
|
||||
#include "tools/util/type_traits.h"
|
||||
|
||||
namespace perf {
|
||||
|
||||
|
||||
@ -32,7 +32,8 @@ template <typename Gemm_,
|
||||
typename ScalarD_,
|
||||
typename Compute_,
|
||||
typename ScalarEpilogue_,
|
||||
bool ThreadMultiplyAdd_>
|
||||
bool ThreadMultiplyAdd_,
|
||||
bool RunCuBLAS_ = true>
|
||||
struct CutlassDispatch {
|
||||
typedef typename Gemm_::Params Params;
|
||||
typedef Gemm_ Gemm;
|
||||
@ -45,6 +46,7 @@ struct CutlassDispatch {
|
||||
typedef ScalarEpilogue_ ScalarEpilogue;
|
||||
|
||||
static bool const kThreadMultiplyAdd = ThreadMultiplyAdd_;
|
||||
static bool const kRunCuBLAS = RunCuBLAS_;
|
||||
|
||||
static cutlass::MatrixLayout::Kind const kLayoutA = Gemm::Traits::kLayoutA;
|
||||
static cutlass::MatrixLayout::Kind const kLayoutB = Gemm::Traits::kLayoutB;
|
||||
@ -60,7 +62,7 @@ struct CutlassDispatch {
|
||||
// Methods
|
||||
//
|
||||
|
||||
CutlassDispatch() {}
|
||||
// CutlassDispatch() {}
|
||||
|
||||
/// Initializes params object
|
||||
CutlassDispatch(Index m,
|
||||
@ -84,33 +86,6 @@ struct CutlassDispatch {
|
||||
|
||||
/// Launches kernel
|
||||
cudaError_t operator()() { return Gemm::launch(params); }
|
||||
|
||||
/// Determines if problem is aligned (assuming no padding)
|
||||
static bool is_problem_aligned(
|
||||
int m,
|
||||
int n,
|
||||
int k) {
|
||||
|
||||
bool aligned = true;
|
||||
|
||||
if (kLayoutA == cutlass::MatrixLayout::kColumnMajor) {
|
||||
aligned = aligned && !(m % Gemm::Traits::GemmConfig::kScalarsPerLdgA);
|
||||
}
|
||||
else {
|
||||
aligned = aligned && !(k % Gemm::Traits::GemmConfig::kScalarsPerLdgA);
|
||||
}
|
||||
|
||||
if (kLayoutB == cutlass::MatrixLayout::kColumnMajor) {
|
||||
aligned = aligned && !(k % Gemm::Traits::GemmConfig::kScalarsPerLdgB);
|
||||
}
|
||||
else {
|
||||
aligned = aligned && !(n % Gemm::Traits::GemmConfig::kScalarsPerLdgB);
|
||||
}
|
||||
|
||||
aligned = aligned && !(m % Gemm::Traits::GemmConfig::kScalarsPerLdgC);
|
||||
|
||||
return aligned;
|
||||
}
|
||||
};
|
||||
|
||||
/// Basic dispatcher inferred from GEMM traits
|
||||
|
||||
@ -23,26 +23,29 @@
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#include <cutlass/gemm/gemm.h>
|
||||
#include <cutlass/gemm/dgemm_traits.h>
|
||||
|
||||
#include <tools/test/perf/gemm/gemm_perf_testbed.h>
|
||||
|
||||
#include <tools/test/perf/gemm/gemm_profiler.h>
|
||||
#include <tools/test/perf/gemm/cutlass_dispatch.h>
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/dgemm_traits.h"
|
||||
|
||||
#include "tools/test/perf/cutlass_perf_test.h"
|
||||
#include "tools/test/perf/gemm/gemm_perf_testbed.h"
|
||||
#include "tools/test/perf/gemm/gemm_profiler.h"
|
||||
#include "tools/test/perf/gemm/cutlass_dispatch.h"
|
||||
#pragma warning( disable : 4503)
|
||||
namespace perf {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int profile_dgemm(TestbenchOutput &output, TestbenchOptions const &options) {
|
||||
|
||||
int profile_dgemm(TestbenchOutput<GemmProblem> &output, TestbenchOptions const &options, Config const &config) {
|
||||
typedef perf::GemmProfiler<double, double, double, double, double> GemmProfiler;
|
||||
|
||||
int results = 0;
|
||||
|
||||
if (!results) {
|
||||
|
||||
|
||||
// compute capability check
|
||||
if (!options.compute_capability(6, 0)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
{
|
||||
typedef cutlass::gemm::DgemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kRowMajor
|
||||
@ -50,11 +53,10 @@ int profile_dgemm(TestbenchOutput &output, TestbenchOptions const &options) {
|
||||
|
||||
typedef typename CutlassDispatchBasic<GemmTraits>::Dispatch Dispatch;
|
||||
|
||||
profile_gemm<Dispatch, GemmProfiler>(output, "dgemm_nt", options);
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, "dgemm_nt", options, config);
|
||||
}
|
||||
|
||||
if (!results) {
|
||||
|
||||
{
|
||||
typedef cutlass::gemm::DgemmTraits<
|
||||
cutlass::MatrixLayout::kColumnMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor
|
||||
@ -62,11 +64,10 @@ int profile_dgemm(TestbenchOutput &output, TestbenchOptions const &options) {
|
||||
|
||||
typedef typename CutlassDispatchBasic<GemmTraits>::Dispatch Dispatch;
|
||||
|
||||
profile_gemm<Dispatch, GemmProfiler>(output, "dgemm_nn", options);
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, "dgemm_nn", options, config);
|
||||
}
|
||||
|
||||
if (!results) {
|
||||
|
||||
{
|
||||
typedef cutlass::gemm::DgemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kColumnMajor
|
||||
@ -74,11 +75,10 @@ int profile_dgemm(TestbenchOutput &output, TestbenchOptions const &options) {
|
||||
|
||||
typedef typename CutlassDispatchBasic<GemmTraits>::Dispatch Dispatch;
|
||||
|
||||
profile_gemm<Dispatch, GemmProfiler>(output, "dgemm_tn", options);
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, "dgemm_tn", options, config);
|
||||
}
|
||||
|
||||
if (!results) {
|
||||
|
||||
{
|
||||
typedef cutlass::gemm::DgemmTraits<
|
||||
cutlass::MatrixLayout::kRowMajor,
|
||||
cutlass::MatrixLayout::kRowMajor
|
||||
@ -86,12 +86,18 @@ int profile_dgemm(TestbenchOutput &output, TestbenchOptions const &options) {
|
||||
|
||||
typedef typename CutlassDispatchBasic<GemmTraits>::Dispatch Dispatch;
|
||||
|
||||
profile_gemm<Dispatch, GemmProfiler>(output, "dgemm_tt", options);
|
||||
results |= profile_gemm<Dispatch, GemmProfiler>(output, "dgemm_tt", options, config);
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
struct DgemmRegistrar {
|
||||
DgemmRegistrar() { RegisterGemmProfileFunc(profile_dgemm); }
|
||||
};
|
||||
|
||||
volatile DgemmRegistrar _DgemmRegistrar;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace perf
|
||||
|
||||